Skip to content

Commit f49de47

Browse files
committed
formatting
1 parent 440081c commit f49de47

File tree

1 file changed

+57
-15
lines changed

1 file changed

+57
-15
lines changed

notebooks/09_transformer/gpt/gpt.ipynb

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,14 @@
120120
"source": [
121121
"# Filter the dataset\n",
122122
"filtered_data = [\n",
123-
" \"wine review : \" + x[\"country\"] + \" : \" + x[\"province\"] + \" : \" + x[\"variety\"] + \" : \" + x[\"description\"]\n",
123+
" \"wine review : \"\n",
124+
" + x[\"country\"]\n",
125+
" + \" : \"\n",
126+
" + x[\"province\"]\n",
127+
" + \" : \"\n",
128+
" + x[\"variety\"]\n",
129+
" + \" : \"\n",
130+
" + x[\"description\"]\n",
124131
" for x in wine_data\n",
125132
" if x[\"country\"] is not None\n",
126133
" and x[\"province\"] is not None\n",
@@ -203,7 +210,11 @@
203210
"outputs": [],
204211
"source": [
205212
"# Convert to a Tensorflow Dataset\n",
206-
"text_ds = tf.data.Dataset.from_tensor_slices(text_data).batch(BATCH_SIZE).shuffle(1000)"
213+
"text_ds = (\n",
214+
" tf.data.Dataset.from_tensor_slices(text_data)\n",
215+
" .batch(BATCH_SIZE)\n",
216+
" .shuffle(1000)\n",
217+
")"
207218
]
208219
},
209220
{
@@ -340,7 +351,9 @@
340351
" m = i >= j - n_src + n_dest\n",
341352
" mask = tf.cast(m, dtype)\n",
342353
" mask = tf.reshape(mask, [1, n_dest, n_src])\n",
343-
" mult = tf.concat([tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0)\n",
354+
" mult = tf.concat(\n",
355+
" [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n",
356+
" )\n",
344357
" return tf.tile(mask, mult)\n",
345358
"\n",
346359
"\n",
@@ -370,7 +383,9 @@
370383
" self.embed_dim = embed_dim\n",
371384
" self.ff_dim = ff_dim\n",
372385
" self.dropout_rate = dropout_rate\n",
373-
" self.attn = layers.MultiHeadAttention(num_heads, key_dim, output_shape=embed_dim)\n",
386+
" self.attn = layers.MultiHeadAttention(\n",
387+
" num_heads, key_dim, output_shape=embed_dim\n",
388+
" )\n",
374389
" self.dropout_1 = layers.Dropout(self.dropout_rate)\n",
375390
" self.ln_1 = layers.LayerNormalization(epsilon=1e-6)\n",
376391
" self.ffn_1 = layers.Dense(self.ff_dim, activation=\"relu\")\n",
@@ -382,9 +397,14 @@
382397
" input_shape = tf.shape(inputs)\n",
383398
" batch_size = input_shape[0]\n",
384399
" seq_len = input_shape[1]\n",
385-
" causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)\n",
400+
" causal_mask = causal_attention_mask(\n",
401+
" batch_size, seq_len, seq_len, tf.bool\n",
402+
" )\n",
386403
" attention_output, attention_scores = self.attn(\n",
387-
" inputs, inputs, attention_mask=causal_mask, return_attention_scores=True\n",
404+
" inputs,\n",
405+
" inputs,\n",
406+
" attention_mask=causal_mask,\n",
407+
" return_attention_scores=True,\n",
388408
" )\n",
389409
" attention_output = self.dropout_1(attention_output)\n",
390410
" out1 = self.ln_1(inputs + attention_output)\n",
@@ -430,7 +450,9 @@
430450
" self.max_len = max_len\n",
431451
" self.vocab_size = vocab_size\n",
432452
" self.embed_dim = embed_dim\n",
433-
" self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)\n",
453+
" self.token_emb = layers.Embedding(\n",
454+
" input_dim=vocab_size, output_dim=embed_dim\n",
455+
" )\n",
434456
" self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)\n",
435457
"\n",
436458
" def call(self, x):\n",
@@ -469,7 +491,9 @@
469491
"source": [
470492
"inputs = layers.Input(shape=(None,), dtype=tf.int32)\n",
471493
"x = TokenAndPositionEmbedding(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM)(inputs)\n",
472-
"x, attention_scores = TransformerBlock(N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM)(x)\n",
494+
"x, attention_scores = TransformerBlock(\n",
495+
" N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM\n",
496+
")(x)\n",
473497
"outputs = layers.Dense(VOCAB_SIZE, activation=\"softmax\")(x)\n",
474498
"gpt = models.Model(inputs=inputs, outputs=[outputs, attention_scores])\n",
475499
"gpt.compile(\"adam\", loss=[losses.SparseCategoricalCrossentropy(), None])"
@@ -518,22 +542,32 @@
518542
"class TextGenerator(callbacks.Callback):\n",
519543
" def __init__(self, index_to_word, top_k=10):\n",
520544
" self.index_to_word = index_to_word\n",
521-
" self.word_to_index = {word: index for index, word in enumerate(index_to_word)}\n",
545+
" self.word_to_index = {\n",
546+
" word: index for index, word in enumerate(index_to_word)\n",
547+
" }\n",
522548
"\n",
523549
" def sample_from(self, probs, temperature):\n",
524550
" probs = probs ** (1 / temperature)\n",
525551
" probs = probs / np.sum(probs)\n",
526552
" return np.random.choice(len(probs), p=probs), probs\n",
527553
"\n",
528554
" def generate(self, start_prompt, max_tokens, temperature):\n",
529-
" start_tokens = [self.word_to_index.get(x, 1) for x in start_prompt.split()]\n",
555+
" start_tokens = [\n",
556+
" self.word_to_index.get(x, 1) for x in start_prompt.split()\n",
557+
" ]\n",
530558
" sample_token = None\n",
531559
" info = []\n",
532560
" while len(start_tokens) < max_tokens and sample_token != 0:\n",
533561
" x = np.array([start_tokens])\n",
534562
" y, att = self.model.predict(x, verbose=0)\n",
535563
" sample_token, probs = self.sample_from(y[0][-1], temperature)\n",
536-
" info.append({\"prompt\": start_prompt, \"word_probs\": probs, \"atts\": att[0, :, -1, :]})\n",
564+
" info.append(\n",
565+
" {\n",
566+
" \"prompt\": start_prompt,\n",
567+
" \"word_probs\": probs,\n",
568+
" \"atts\": att[0, :, -1, :],\n",
569+
" }\n",
570+
" )\n",
537571
" start_tokens.append(sample_token)\n",
538572
" start_prompt = start_prompt + \" \" + self.index_to_word[sample_token]\n",
539573
" print(f\"\\ngenerated text:\\n{start_prompt}\\n\")\n",
@@ -611,7 +645,9 @@
611645
"def print_probs(info, vocab, top_k=5):\n",
612646
" for i in info:\n",
613647
" highlighted_text = []\n",
614-
" for word, att_score in zip(i[\"prompt\"].split(), np.mean(i[\"atts\"], axis=0)):\n",
648+
" for word, att_score in zip(\n",
649+
" i[\"prompt\"].split(), np.mean(i[\"atts\"], axis=0)\n",
650+
" ):\n",
615651
" highlighted_text.append(\n",
616652
" '<span style=\"background-color:rgba(135,206,250,'\n",
617653
" + str(att_score / max(np.mean(i[\"atts\"], axis=0)))\n",
@@ -637,7 +673,9 @@
637673
"metadata": {},
638674
"outputs": [],
639675
"source": [
640-
"info = text_generator.generate(\"wine review : us\", max_tokens=80, temperature=1.0)"
676+
"info = text_generator.generate(\n",
677+
" \"wine review : us\", max_tokens=80, temperature=1.0\n",
678+
")"
641679
]
642680
},
643681
{
@@ -647,7 +685,9 @@
647685
"metadata": {},
648686
"outputs": [],
649687
"source": [
650-
"info = text_generator.generate(\"wine review : italy\", max_tokens=80, temperature=0.5)"
688+
"info = text_generator.generate(\n",
689+
" \"wine review : italy\", max_tokens=80, temperature=0.5\n",
690+
")"
651691
]
652692
},
653693
{
@@ -657,7 +697,9 @@
657697
"metadata": {},
658698
"outputs": [],
659699
"source": [
660-
"info = text_generator.generate(\"wine review : germany\", max_tokens=80, temperature=0.5)\n",
700+
"info = text_generator.generate(\n",
701+
" \"wine review : germany\", max_tokens=80, temperature=0.5\n",
702+
")\n",
661703
"print_probs(info, vocab)"
662704
]
663705
},

0 commit comments

Comments
 (0)