BaselineHF.predict:v13
import weave
import time
import torch
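
# predict() generates a completion either with a llama.cpp model
# (self.llm_model) or with a Hugging Face tokenizer/model pair
# (self.tokenizer / self.device), timing the generation call.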
use_llamacpp = True  # True: llama.cpp path; False: Hugging Face path


@weave.op()
async def predict(self, inp: str) -> str:
    if use_llamacpp:
        # llama.cpp path: the model callable returns an OpenAI-style dict.
        start_time = time.time()
        output = self.llm_model(
            inp, max_tokens=2, stop=["Q:", "\n"],
        )
        generation_time = time.time() - start_time
        completion = output["choices"][0]["text"]
        n_tokens = output["usage"]["total_tokens"]
    else:
        # Hugging Face path: tokenize the prompt and move it to the model's device.
        inp = self.tokenizer.encode(inp, return_tensors="pt").to(self.device)
        # if dtype:
        #     inp = inp.to(dtype)
        pad_token_id = self.tokenizer.eos_token_id
        start_time = time.time()