New report
Created on February 16 | Last edited on February 16
... 57 unchanged lines ...

  * large (possibly—not sure we can tokenize for it)

  GPT_RM = "stacey/model-registry/tinybird"
... 28 unchanged lines ...

+ model = load_model_variant("medium", model_type=MODEL_TYPE, run_prefix="resurrect_")

  to_tune = load_model_variant("medium")
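load_model_variant itself is not shown in this diff, so the following is only a sketch of what such a helper might do, assuming it pulls a GPT-2 checkpoint out of the GPT_RM registry path defined above; the artifact alias, the fallback model_type, and the way run_prefix is used are all guesses.

import wandb
from transformers import GPT2LMHeadModel

def load_model_variant(size, model_type="gpt2-medium", run_prefix=""):
    # Hypothetical body: assumes an active wandb run and size-named aliases
    # (e.g. "medium") on the registered model artifact.
    if run_prefix:
        wandb.run.name = f"{run_prefix}{size}"   # assumed: prefix tags the run name
    art = wandb.run.use_artifact(f"{GPT_RM}:{size}")   # e.g. tinybird:medium (assumed alias)
    model_dir = art.download()
    # model_type mirrors the call above; assumed to name the matching HF checkpoint
    return GPT2LMHeadModel.from_pretrained(model_dir).to("cuda")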
... 70 unchanged lines ...

+ NUM_SAMPLES = 10
  NUM_TOKENS = 200

  from transformers import Trainer
  from transformers import TrainingArguments

  training_args = TrainingArguments(
... 52 unchanged lines ...
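The body of the TrainingArguments call and the Trainer setup are collapsed above; purely as an illustration (none of these values come from the diff, and train_dataset is assumed to exist), a minimal continuation could look like:

# Illustrative placeholders only, not the settings hidden in the collapsed lines.
training_args = TrainingArguments(
    output_dir="gpt2-tune-output",        # placeholder path
    num_train_epochs=1,
    per_device_train_batch_size=2,
    logging_steps=50,
    report_to="wandb",                    # stream Trainer metrics to W&B
)
trainer = Trainer(model=to_tune, args=training_args, train_dataset=train_dataset)
trainer.train()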
- def gen_results(model, prompt, max_tokens=MAX_TOKENS, num_samples=NUM_SAMPLES, mode="pretty"):
+ def gen_results(model, prompt, max_tokens=MAX_TOKENS, num_samples=NUM_SAMPLES):
      device = "cuda"
      tokenizer = GPT2Tokenizer.from_pretrained(MODEL_TYPE)
      encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
      x = encoded_input['input_ids']
      x = x.expand(num_samples, -1)
      y = model.generate(x, max_new_tokens=max_tokens, do_sample=True, top_k=40)
  ... 4 unchanged lines ...
-         html_result = pretty(prompt, out)
-         data.append([prompt, html_result])
-     else:
-         data.append([prompt, out])
+         data.append([prompt, out])
      return wandb.Table(data=data, columns=["prompt", "response"])

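As a usage sketch (not part of the diff): gen_results returns a wandb.Table, so a run can log the sampled generations directly. PROJECT, MODEL_TYPE, and a loaded model are assumed from earlier cells, and the job_type string is illustrative.

# Sketch only: assumes `model`, PROJECT, and MODEL_TYPE are defined above.
import wandb

wandb.init(project=PROJECT, job_type="sample-explore")   # illustrative job_type
table = gen_results(model, "Title: Chirality. Poem:", max_tokens=100, num_samples=5)
wandb.log({"sample_poems": table})
wandb.run.finish()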
  poem_titles = [
      "The Beach Plum",
      "The Inevitability of Side Effects",
      "The Edge of the World",
      "Chirality",
... 104 unchanged lines ...

- def infer(prompts, prompt_type, template,run_name="",notes=""):
-     wandb.init(project=PROJECT,name=run_name,job_type="live-explore")
+ def infer(prompts, prompt_type, template,notes=""):
+     wandb.init(project=PROJECT,job_type="live-explore")
      wandb.run.use_artifact("stacey/model-registry/tinybird:best_tuned")

      human_time, timestamp = get_timestamps()
      cfg = {
  ...
-     wandb.log({f"infer_{i}": table})
+     wandb.log({f"infer_live_{i}": table})
+     for j, row in enumerate(table.data):
+         print(f"Prompt:{row[0]}")
+         print(f"Response:{row[1]}")

      wandb.run.finish()
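get_timestamps() is referenced above but never defined in the visible part of the diff; a minimal hypothetical version, assuming it just returns a readable time string plus a raw timestamp for run naming and config logging:

# Hypothetical helper, not shown in the diff; the shape of the return values is assumed.
import time

def get_timestamps():
    now = time.time()
    human_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now))
    return human_time, int(now)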
  live_text = [
      "Title: The Fog. Poem:",
      "Title: The Oncoming Storm. Poem:",
  ...
- infer(mon_5, "compose-explore", "Title: X. {Poem,Lyrics,Article}:",run_name="mon_5_test_best_tune",
+ infer(mon_5, "compose-explore", "Title: X. {Poem,Lyrics,Article}:",
... 55 unchanged lines ...
      return wandb.Html(html)

  html = '<html><span style="color:#FF0000;background-color:green">H</span><span style="color:#66CC66;background-color:red">el</span><span style="color:#FF9966;background-color:yellow">lo</span></html>'
  wandb.init(project=PROJECT, name="test_html", job_type="test")
  wandb.log({"html" : wandb.Html(html)})
  wandb.run.finish()
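The hard-coded HTML above is just a rendering test; the same idea generalizes to a helper that wraps generated text segments in colored spans before logging, along the lines of this sketch (the helper name and the input format are made up, not from the diff):

# Hypothetical helper: builds the kind of colored-span HTML tested above.
import wandb

def color_spans(segments):
    # segments: list of (text, css_color) pairs, e.g. [("Hello", "#FF0000")]
    spans = "".join(f'<span style="color:{color}">{text}</span>' for text, color in segments)
    return wandb.Html(f"<html>{spans}</html>")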
- art = wandb.Api().artifact('stacey/wave/gpt2-med-tune-3-compose-long:v0')
... 77 unchanged lines ...

  wandb.init(project=PROJECT, name="upload-lyrics-test", job_type="upload")
  t = df.sample(frac=0.06)
  tab = wandb.Table(dataframe=t)
  wandb.run.log({"song_lyrics_005_test" : tab})
  wandb.run.finish()
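The run above only logs a 6% sample of the lyrics dataframe; as an optional follow-up sketch (the artifact and table names are placeholders, not from the diff), the full table could also be versioned as an artifact so the sample stays traceable:

# Sketch only: placeholder names, assumes `df` and PROJECT from the cells above.
wandb.init(project=PROJECT, name="upload-lyrics-artifact", job_type="upload")
art = wandb.Artifact("song_lyrics", type="dataset")
art.add(wandb.Table(dataframe=df), "full_lyrics")
wandb.run.log_artifact(art)
wandb.run.finish()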
  wandb.init(project=PROJECT, name="upload-all-788", job_type="upload")
... 30 unchanged lines ...