... (57 unchanged lines) ...

- …

  * large (possibly—not sure we can tokenize for it)

- …
+ …

  GPT_RM = "stacey/model-registry/tinybird"
... (28 unchanged lines) ...

- …
- …
+ …
+ model = load_model_variant("medium", model_type=MODEL_TYPE, run_prefix="resurrect_")

  to_tune = load_model_variant("medium")

... (70 unchanged lines) ...
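`load_model_variant` itself is defined in one of the collapsed regions, so its body is not visible in this diff. As a rough sketch only: a helper with this call signature could resolve a registered checkpoint by alias and fall back to the stock Hugging Face weights. The alias scheme (`run_prefix + size`) and the fallback behavior below are assumptions, not the script's actual logic.

import wandb
from transformers import GPT2LMHeadModel

def load_model_variant(size, model_type=MODEL_TYPE, run_prefix=""):
    # Hypothetical: with a run_prefix, pull a previously fine-tuned checkpoint
    # from the registered model; otherwise start from the stock pretrained weights.
    if run_prefix:
        art = wandb.run.use_artifact(f"{GPT_RM}:{run_prefix}{size}")  # needs an active run
        return GPT2LMHeadModel.from_pretrained(art.download()).to("cuda")
    return GPT2LMHeadModel.from_pretrained(model_type).to("cuda")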
+ NUM_SAMPLES = 10
  NUM_TOKENS = 200

  from transformers import Trainer
  from transformers import TrainingArguments

  training_args = TrainingArguments(

... (52 unchanged lines) ...
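The TrainingArguments call above is collapsed in this view, so the actual hyperparameters are not shown. For orientation only, a minimal configuration for tuning a GPT-2 variant with the Hugging Face Trainer might look like the following; every value here is a placeholder, and train_ds is a hypothetical dataset handle.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="gpt2_tune_out",          # placeholder path
    num_train_epochs=3,                  # placeholder hyperparameters throughout
    per_device_train_batch_size=4,
    logging_steps=50,
    report_to="wandb",                   # stream Trainer metrics to the active W&B run
)
trainer = Trainer(model=to_tune, args=training_args, train_dataset=train_ds)
trainer.train()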
- def gen_results(model, prompt, max_tokens=MAX_TOKENS, num_samples=NUM_SAMPLES, mode="pretty"):
+ def gen_results(model, prompt, max_tokens=MAX_TOKENS, num_samples=NUM_SAMPLES):
      device = "cuda"
      tokenizer = GPT2Tokenizer.from_pretrained(MODEL_TYPE)
      encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
      x = encoded_input['input_ids']
      x = x.expand(num_samples, -1)
      y = model.generate(x, max_new_tokens=max_tokens, do_sample=True, top_k=40)

... (4 unchanged lines) ...

-     …
-         html_result = pretty(prompt, out)
-         data.append([prompt, html_result])
-     else:
-         data.append([prompt, out])
+         data.append([prompt,out])
      return wandb.Table(data=data, columns=["prompt", "response"])
- …

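Pieced together, the simplified generation helper reads roughly as below. The lines between generate and the append (initialising data and decoding each sample) are collapsed in the diff, so that loop and the GPT2Tokenizer import are reconstructed here as assumptions.

import wandb
from transformers import GPT2Tokenizer

def gen_results(model, prompt, max_tokens=MAX_TOKENS, num_samples=NUM_SAMPLES):
    device = "cuda"
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_TYPE)
    encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
    x = encoded_input['input_ids']
    x = x.expand(num_samples, -1)         # repeat the prompt once per sample
    y = model.generate(x, max_new_tokens=max_tokens, do_sample=True, top_k=40)
    data = []
    for sample in y:                      # reconstructed: decode each sampled continuation
        out = tokenizer.decode(sample, skip_special_tokens=True)
        data.append([prompt, out])
    return wandb.Table(data=data, columns=["prompt", "response"])

Dropping the mode="pretty" branch means every row now stores the raw decoded text, which keeps the logged table uniform regardless of how it is rendered later.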
  poem_titles = [
      "The Beach Plum",
      "The Inevitability of Side Effects",
      "The Edge of the World",
      "Chirality",

... (104 unchanged lines) ...

- def infer(prompts, prompt_type, template,run_name="",notes=""):
+ def infer(prompts, prompt_type, template,notes=""):
-     wandb.init(project=PROJECT,name=run_name,job_type="live-explore")
+     wandb.init(project=PROJECT,job_type="live-explore")

      wandb.run.use_artifact("stacey/model-registry/tinybird:best_tuned")

      human_time, timestamp = get_timestamps()
      cfg = {
... (12 unchanged lines) ...

-     wandb.log({f"infer_{i}": table})
+     wandb.log({f"infer_live_{i}": table})
-     …
-     …
-     …
+     for j, row in enumerate(table.data):
+         print(f"Prompt:{row[0]}")
+         print(f"Response:{row[1]}")

      wandb.run.finish()

  live_text = [
      "Title: The Fog. Poem:",
      "Title: The Oncoming Storm. Poem:",

... (unchanged lines not shown) ...

- infer(mon_5, "compose-explore", "Title: X. {Poem,Lyrics,Article}:",run_name="mon_5_test_best_tune",
+ infer(mon_5, "compose-explore", "Title: X. {Poem,Lyrics,Article}:",
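Since run_name was dropped from infer, these calls now rely on W&B's auto-generated run names (wandb.init picks a readable random name when none is given). An illustrative call, with the prompt_type and notes strings as placeholders:

infer(live_text, "live-explore", "Title: X. Poem:", notes="spot check of the best_tuned alias")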
... (unchanged lines not shown) ...

      return wandb.Html(html)

  html='<html><span style="color:#FF0000;background-color:green">H</span><span style="color:#66CC66;background-color:red">el</span><span style="color:#FF9966;background-color:yellow">lo</span></html>'
  wandb.init(project=PROJECT, name="test_html", job_type="test")
  wandb.log({"html" : wandb.Html(html)})
  wandb.run.finish()
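The function that produces that wandb.Html return value is collapsed above. As an illustration of the pattern being tested here (per-token colored spans rendered in the W&B UI), a hypothetical helper might look like this; the function name, the token/color pairing, and the run name are all made up for the example.

import wandb

def colorize(tokens, colors):
    # Wrap each token in a styled span so W&B renders it as rich text.
    spans = [f'<span style="color:{c}">{t}</span>' for t, c in zip(tokens, colors)]
    return wandb.Html("<html>" + "".join(spans) + "</html>")

run = wandb.init(project=PROJECT, name="test_colorize", job_type="test")
run.log({"colored": colorize(["He", "llo"], ["#FF0000", "#66CC66"])})
run.finish()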
... (unchanged lines not shown) ...

- …
- art = wandb.Api().artifact('stacey/wave/gpt2-med-tune-3-compose-long:v0')
+ …

... (unchanged lines not shown) ...
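The removed line referenced a checkpoint artifact directly through the public API. For reference, both common ways of pulling such an artifact are sketched below; which one the script relies on outside the lines shown here is not visible in this diff.

import wandb

# Outside a run, via the public API:
art = wandb.Api().artifact('stacey/wave/gpt2-med-tune-3-compose-long:v0')
ckpt_dir = art.download()

# Inside a run, which also records the dependency in the run's lineage:
run = wandb.init(project=PROJECT, job_type="download")
ckpt_dir = run.use_artifact("stacey/model-registry/tinybird:best_tuned").download()
run.finish()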
  wandb.init(project=PROJECT, name="upload-lyrics-test", job_type="upload")
  t = df.sample(frac=0.06)
  tab = wandb.Table(dataframe=t)
  wandb.run.log({"song_lyrics_005_test" : tab})
  wandb.run.finish()

  wandb.init(project=PROJECT, name="upload-all-788", job_type="upload")

... (30 unchanged lines) ...
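The body of the upload-all-788 run is collapsed above. Assuming the full dataframe is logged the same way as the 6% sample, a minimal version would be the following; the table key is a guess, not the run's actual key.

full_tab = wandb.Table(dataframe=df)
wandb.run.log({"song_lyrics_all": full_tab})
wandb.run.finish()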