|
120 | 120 | "source": [ |
121 | 121 | "# Filter the dataset\n", |
122 | 122 | "filtered_data = [\n", |
123 | | - " \"wine review : \" + x[\"country\"] + \" : \" + x[\"province\"] + \" : \" + x[\"variety\"] + \" : \" + x[\"description\"]\n", |
| 123 | + " \"wine review : \"\n", |
| 124 | + " + x[\"country\"]\n", |
| 125 | + " + \" : \"\n", |
| 126 | + " + x[\"province\"]\n", |
| 127 | + " + \" : \"\n", |
| 128 | + " + x[\"variety\"]\n", |
| 129 | + " + \" : \"\n", |
| 130 | + " + x[\"description\"]\n", |
124 | 131 | " for x in wine_data\n", |
125 | 132 | " if x[\"country\"] is not None\n", |
126 | 133 | " and x[\"province\"] is not None\n", |
|
203 | 210 | "outputs": [], |
204 | 211 | "source": [ |
205 | 212 | "# Convert to a TensorFlow Dataset\n",
206 | | - "text_ds = tf.data.Dataset.from_tensor_slices(text_data).batch(BATCH_SIZE).shuffle(1000)" |
| 213 | + "text_ds = (\n", |
| 214 | + " tf.data.Dataset.from_tensor_slices(text_data)\n", |
| 215 | + " .batch(BATCH_SIZE)\n", |
| 216 | + " .shuffle(1000)\n", |
| 217 | + ")" |
207 | 218 | ] |
208 | 219 | }, |
209 | 220 | { |
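The reformat makes the pipeline order easy to see: `.shuffle(1000)` runs after `.batch(BATCH_SIZE)`, so it shuffles whole batches rather than individual examples. For contrast, a sketch of the example-level variant (not what this notebook does; `text_data` and `BATCH_SIZE` are assumed from earlier cells):

```python
import tensorflow as tf

# Shuffle individual examples first, then batch; the notebook instead
# batches first, which shuffles whole batches at a time.
text_ds = (
    tf.data.Dataset.from_tensor_slices(text_data)
    .shuffle(1000)
    .batch(BATCH_SIZE)
)
```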
|
340 | 351 | " m = i >= j - n_src + n_dest\n", |
341 | 352 | " mask = tf.cast(m, dtype)\n", |
342 | 353 | " mask = tf.reshape(mask, [1, n_dest, n_src])\n", |
343 | | - " mult = tf.concat([tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0)\n", |
| 354 | + " mult = tf.concat(\n", |
| 355 | + " [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n", |
| 356 | + " )\n", |
344 | 357 | " return tf.tile(mask, mult)\n", |
345 | 358 | "\n", |
346 | 359 | "\n", |
|
370 | 383 | " self.embed_dim = embed_dim\n", |
371 | 384 | " self.ff_dim = ff_dim\n", |
372 | 385 | " self.dropout_rate = dropout_rate\n", |
373 | | - " self.attn = layers.MultiHeadAttention(num_heads, key_dim, output_shape=embed_dim)\n", |
| 386 | + " self.attn = layers.MultiHeadAttention(\n", |
| 387 | + " num_heads, key_dim, output_shape=embed_dim\n", |
| 388 | + " )\n", |
374 | 389 | " self.dropout_1 = layers.Dropout(self.dropout_rate)\n", |
375 | 390 | " self.ln_1 = layers.LayerNormalization(epsilon=1e-6)\n", |
376 | 391 | " self.ffn_1 = layers.Dense(self.ff_dim, activation=\"relu\")\n", |
|
382 | 397 | " input_shape = tf.shape(inputs)\n", |
383 | 398 | " batch_size = input_shape[0]\n", |
384 | 399 | " seq_len = input_shape[1]\n", |
385 | | - " causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)\n", |
| 400 | + " causal_mask = causal_attention_mask(\n", |
| 401 | + " batch_size, seq_len, seq_len, tf.bool\n", |
| 402 | + " )\n", |
386 | 403 | " attention_output, attention_scores = self.attn(\n", |
387 | | - " inputs, inputs, attention_mask=causal_mask, return_attention_scores=True\n", |
| 404 | + " inputs,\n", |
| 405 | + " inputs,\n", |
| 406 | + " attention_mask=causal_mask,\n", |
| 407 | + " return_attention_scores=True,\n", |
388 | 408 | " )\n", |
389 | 409 | " attention_output = self.dropout_1(attention_output)\n", |
390 | 410 | " out1 = self.ln_1(inputs + attention_output)\n", |
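Note that the block uses post-layer-norm residuals (`ln_1(inputs + attention_output)`), matching the original Transformer layout rather than the pre-norm variant common in more recent GPT-style models.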
|
430 | 450 | " self.max_len = max_len\n", |
431 | 451 | " self.vocab_size = vocab_size\n", |
432 | 452 | " self.embed_dim = embed_dim\n", |
433 | | - " self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)\n", |
| 453 | + " self.token_emb = layers.Embedding(\n", |
| 454 | + " input_dim=vocab_size, output_dim=embed_dim\n", |
| 455 | + " )\n", |
434 | 456 | " self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)\n", |
435 | 457 | "\n", |
436 | 458 | " def call(self, x):\n", |
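The hunk cuts off at `call`. In the pattern the class name implies, the body adds a learned position embedding to the token embedding; a self-contained sketch with invented dimensions:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Standalone demo of token + position embedding (dims invented):
max_len, vocab_size, embed_dim = 8, 100, 4
token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

x = tf.constant([[5, 17, 42]])  # (batch, seq_len)
positions = tf.range(start=0, limit=tf.shape(x)[-1], delta=1)
out = token_emb(x) + pos_emb(positions)  # positions broadcast over batch
print(out.shape)  # (1, 3, 4)
```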
|
469 | 491 | "source": [ |
470 | 492 | "inputs = layers.Input(shape=(None,), dtype=tf.int32)\n", |
471 | 493 | "x = TokenAndPositionEmbedding(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM)(inputs)\n", |
472 | | - "x, attention_scores = TransformerBlock(N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM)(x)\n", |
| 494 | + "x, attention_scores = TransformerBlock(\n", |
| 495 | + " N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM\n", |
| 496 | + ")(x)\n", |
473 | 497 | "outputs = layers.Dense(VOCAB_SIZE, activation=\"softmax\")(x)\n", |
474 | 498 | "gpt = models.Model(inputs=inputs, outputs=[outputs, attention_scores])\n", |
475 | 499 | "gpt.compile(\"adam\", loss=[losses.SparseCategoricalCrossentropy(), None])" |
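Because the model returns `[outputs, attention_scores]` but compiles with `loss=[losses.SparseCategoricalCrossentropy(), None]`, the attention scores are exposed for inspection without contributing to training. A hypothetical shape check (constants as defined earlier in the notebook):

```python
import numpy as np

# Keras MultiHeadAttention scores have shape (batch, heads, query, key):
dummy = np.random.randint(0, VOCAB_SIZE, size=(1, 10))
probs, atts = gpt.predict(dummy, verbose=0)
print(probs.shape)  # (1, 10, VOCAB_SIZE)
print(atts.shape)   # (1, N_HEADS, 10, 10)
```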
|
518 | 542 | "class TextGenerator(callbacks.Callback):\n", |
519 | 543 | " def __init__(self, index_to_word, top_k=10):\n", |
520 | 544 | " self.index_to_word = index_to_word\n", |
521 | | - " self.word_to_index = {word: index for index, word in enumerate(index_to_word)}\n", |
| 545 | + " self.word_to_index = {\n", |
| 546 | + " word: index for index, word in enumerate(index_to_word)\n", |
| 547 | + " }\n", |
522 | 548 | "\n", |
523 | 549 | " def sample_from(self, probs, temperature):\n", |
524 | 550 | " probs = probs ** (1 / temperature)\n", |
525 | 551 | " probs = probs / np.sum(probs)\n", |
526 | 552 | " return np.random.choice(len(probs), p=probs), probs\n", |
527 | 553 | "\n", |
528 | 554 | " def generate(self, start_prompt, max_tokens, temperature):\n", |
529 | | - " start_tokens = [self.word_to_index.get(x, 1) for x in start_prompt.split()]\n", |
| 555 | + " start_tokens = [\n", |
| 556 | + " self.word_to_index.get(x, 1) for x in start_prompt.split()\n", |
| 557 | + " ]\n", |
530 | 558 | " sample_token = None\n", |
531 | 559 | " info = []\n", |
532 | 560 | " while len(start_tokens) < max_tokens and sample_token != 0:\n", |
533 | 561 | " x = np.array([start_tokens])\n", |
534 | 562 | " y, att = self.model.predict(x, verbose=0)\n", |
535 | 563 | " sample_token, probs = self.sample_from(y[0][-1], temperature)\n", |
536 | | - " info.append({\"prompt\": start_prompt, \"word_probs\": probs, \"atts\": att[0, :, -1, :]})\n", |
| 564 | + " info.append(\n", |
| 565 | + " {\n", |
| 566 | + " \"prompt\": start_prompt,\n", |
| 567 | + " \"word_probs\": probs,\n", |
| 568 | + " \"atts\": att[0, :, -1, :],\n", |
| 569 | + " }\n", |
| 570 | + " )\n", |
537 | 571 | " start_tokens.append(sample_token)\n", |
538 | 572 | " start_prompt = start_prompt + \" \" + self.index_to_word[sample_token]\n", |
539 | 573 | " print(f\"\\ngenerated text:\\n{start_prompt}\\n\")\n", |
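`sample_from` reweights the predicted distribution before sampling: temperatures below 1 sharpen it toward the likeliest word, temperatures above 1 flatten it. The same arithmetic on an invented three-word distribution:

```python
import numpy as np

probs = np.array([0.5, 0.3, 0.2])
for temperature in (0.5, 1.0, 2.0):
    p = probs ** (1 / temperature)  # reweight
    p = p / np.sum(p)               # renormalize
    print(temperature, np.round(p, 3))
# 0.5 [0.658 0.237 0.105]
# 1.0 [0.5 0.3 0.2]
# 2.0 [0.415 0.322 0.263]
```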
|
611 | 645 | "def print_probs(info, vocab, top_k=5):\n", |
612 | 646 | " for i in info:\n", |
613 | 647 | " highlighted_text = []\n", |
614 | | - " for word, att_score in zip(i[\"prompt\"].split(), np.mean(i[\"atts\"], axis=0)):\n", |
| 648 | + " for word, att_score in zip(\n", |
| 649 | + " i[\"prompt\"].split(), np.mean(i[\"atts\"], axis=0)\n", |
| 650 | + " ):\n", |
615 | 651 | " highlighted_text.append(\n", |
616 | 652 | " '<span style=\"background-color:rgba(135,206,250,'\n", |
617 | 653 | " + str(att_score / max(np.mean(i[\"atts\"], axis=0)))\n", |
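`print_probs` averages the last query position's attention over heads, then rescales by the maximum so the most attended prompt word gets full highlight opacity. In miniature, with invented numbers:

```python
import numpy as np

# atts: (n_heads, prompt_len), as sliced from att[0, :, -1, :] above.
atts = np.array([[0.1, 0.6, 0.3],
                 [0.3, 0.5, 0.2]])
mean_att = np.mean(atts, axis=0)    # [0.2  0.55 0.25]
alphas = mean_att / mean_att.max()  # [0.364 1.    0.455]
print(alphas)
```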
|
637 | 673 | "metadata": {}, |
638 | 674 | "outputs": [], |
639 | 675 | "source": [ |
640 | | - "info = text_generator.generate(\"wine review : us\", max_tokens=80, temperature=1.0)" |
| 676 | + "info = text_generator.generate(\n", |
| 677 | + " \"wine review : us\", max_tokens=80, temperature=1.0\n", |
| 678 | + ")" |
641 | 679 | ] |
642 | 680 | }, |
643 | 681 | { |
|
647 | 685 | "metadata": {}, |
648 | 686 | "outputs": [], |
649 | 687 | "source": [ |
650 | | - "info = text_generator.generate(\"wine review : italy\", max_tokens=80, temperature=0.5)" |
| 688 | + "info = text_generator.generate(\n", |
| 689 | + " \"wine review : italy\", max_tokens=80, temperature=0.5\n", |
| 690 | + ")" |
651 | 691 | ] |
652 | 692 | }, |
653 | 693 | { |
|
657 | 697 | "metadata": {}, |
658 | 698 | "outputs": [], |
659 | 699 | "source": [ |
660 | | - "info = text_generator.generate(\"wine review : germany\", max_tokens=80, temperature=0.5)\n", |
| 700 | + "info = text_generator.generate(\n", |
| 701 | + " \"wine review : germany\", max_tokens=80, temperature=0.5\n", |
| 702 | + ")\n", |
661 | 703 | "print_probs(info, vocab)" |
662 | 704 | ] |
663 | 705 | }, |
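Taken together, the three prompts above also contrast sampling temperatures: the `us` completion at `temperature=1.0` should read more varied (and more error-prone) than the `italy` and `germany` completions at `0.5`, which stay closer to the model's highest-probability words.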
|