BaselineHF.predict:v11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import weave
import time
import torch
@weave.op()
async def predict(self, inp: str) -> dict:
    """Generate a completion for *inp* with the wrapped HF causal LM.

    Encodes the prompt, runs greedy/default ``generate`` under
    ``torch.no_grad()``, and times the generation step.

    Args:
        inp: The raw prompt text.

    Returns:
        A dict with:
            ``generation_time``: seconds spent inside ``generate``.
            ``output``: the decoded completion string.
            ``total_tokens``: token count of the returned sequence
                (NOTE(review): for decoder-only models this presumably
                includes the prompt tokens — confirm against callers).
    """
    # Keep the encoded tensor under a new name instead of shadowing `inp`.
    input_ids = self.tokenizer.encode(inp, return_tensors="pt").to(self.device)
    # Models without a dedicated pad token warn unless one is supplied;
    # EOS is the conventional substitute.
    pad_token_id = self.tokenizer.eos_token_id
    # NOTE(review): no attention_mask is passed (it was commented out
    # upstream) — fine for a single unpadded sequence, but verify if
    # batched/padded inputs are ever fed in.
    start_time = time.perf_counter()  # monotonic clock; wall-clock can jump
    with torch.no_grad():
        output = self.llm_model.generate(
            input_ids,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=pad_token_id,
        )
    generation_time = time.perf_counter() - start_time
    completion = self.tokenizer.decode(output[0])
    return {
        "generation_time": generation_time,
        "output": completion,
        "total_tokens": len(output[0]),
    }