BaselineHF.predict:v7
import time

import weave


@weave.op()
async def predict(self, inp: str) -> dict:
    # Tokenize the prompt and move the input ids to the model's device.
    input_ids = self.tokenizer.encode(inp, return_tensors="pt").to(self.device)

    # tokenizer.encode() returns only input ids, so no explicit attention_mask is passed;
    # using the EOS token as pad_token_id silences the missing-pad-token warning.
    pad_token_id = self.tokenizer.eos_token_id

    # Time just the generation call.
    start_time = time.time()
    output = self.llm_model.generate(
        input_ids,
        max_new_tokens=20,
        pad_token_id=pad_token_id,
    )
    generation_time = time.time() - start_time

    completion = self.tokenizer.decode(output[0])
    return {
        "generation_time": generation_time,
        "output": completion,
        "total_tokens": len(output[0]),
    }