@@ -83,7 +83,7 @@ def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: l
8383 Avaliable few shot options to start with:
8484 - ex_add: pointwise addition
8585 - ex_fuse_gelu: fused gelu
86- - ex_fuse_mnist2 : fused convolutions and relus
86+ - ex_mnist2 : fused convolutions and relus
8787 - ex_tiled_matmul: tiled matrix multiplication
8888 """
8989 prompt = PROBLEM_STATEMENT_CLEANED
@@ -107,13 +107,13 @@ def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: l
107107 example_fuse_gelu_desc = "This given architecture is for a fused gelu: "
108108
109109 # k = 3
110- example_fuse_mnist2 = read_file (
110+ example_mnist2 = read_file (
111111 os .path .join (REPO_TOP_PATH , "src/prompts/few_shot/model_ex_mnist2.py" )
112112 )
113- example_fuse_mnist2_new = read_file (
113+ example_mnist2_new = read_file (
114114 os .path .join (REPO_TOP_PATH , "src/prompts/few_shot/model_new_ex_mnist2.py" )
115115 )
116- exmaple_fuse_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: "
116+ exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: "
117117
118118 # k = 4
119119 example_tiled_matmul = read_file (
@@ -127,14 +127,14 @@ def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: l
127127
128128 examples = []
129129 for s in shots :
130- if s not in ["ex_add" , "ex_fuse_gelu" , "ex_fuse_mnist2 " , "ex_tiled_matmul" ]:
130+ if s not in ["ex_add" , "ex_fuse_gelu" , "ex_mnist2 " , "ex_tiled_matmul" ]:
131131 raise ValueError (f"Invalid shot: { s } " )
132132 elif s == "ex_add" :
133133 examples .append ((example_add , example_add_new , example_add_desc ))
134134 elif s == "ex_fuse_gelu" :
135135 examples .append ((example_fuse_gelu , example_fuse_gelu_new , example_fuse_gelu_desc ))
136- elif s == "ex_fuse_mnist2 " :
137- examples .append ((example_fuse_mnist2 , example_fuse_mnist2_new , exmaple_fuse_mnist2_desc ))
136+ elif s == "ex_mnist2 " :
137+ examples .append ((example_mnist2 , example_mnist2_new , exmaple_mnist2_desc ))
138138 elif s == "ex_tiled_matmul" :
139139 examples .append ((example_tiled_matmul , example_tiled_matmul_new , example_tiled_matmul_desc ))
140140
@@ -171,7 +171,7 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) ->
171171 Generate a prompt with a CoT example following a template
172172 Avaliable CoT examples:
173173 - ex_fuse_gelu: fused gelu
174- - ex_fuse_mnist2 : fused convolutions and relus
174+ - ex_mnist2 : fused convolutions and relus
175175 - ex_tiled_matmul: tiled matrix multiplication
176176 """
177177
@@ -184,7 +184,7 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) ->
184184
185185 prompt = PROBLEM_STATEMENT_CLEANED
186186
187- assert cot_example in ["ex_fuse_gelu" , "ex_fuse_mnist2 " , "ex_tiled_matmul" ]
187+ assert cot_example in ["ex_fuse_gelu" , "ex_mnist2 " , "ex_tiled_matmul" ]
188188
189189 # k = 2
190190 example_fuse_gelu = read_file (
@@ -199,16 +199,16 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) ->
199199 example_fuse_gelu_desc = "This given architecture is for a fused gelu: "
200200
201201 # k = 3
202- example_fuse_mnist2 = read_file (
202+ example_mnist2 = read_file (
203203 os .path .join (REPO_TOP_PATH , "src/prompts/few_shot/model_ex_mnist2.py" )
204204 )
205- example_fuse_mnist2_cot = read_file (
205+ example_mnist2_cot = read_file (
206206 os .path .join (REPO_TOP_PATH , "src/prompts/cot/model_cot_mnist2.py" )
207207 )
208- example_fuse_mnist2_new = read_file (
208+ example_mnist2_new = read_file (
209209 os .path .join (REPO_TOP_PATH , "src/prompts/few_shot/model_new_ex_mnist2.py" )
210210 )
211- exmaple_fuse_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: "
211+ exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: "
212212
213213 # k = 4
214214 example_tiled_matmul = read_file (
@@ -228,16 +228,18 @@ def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) ->
228228 cot = example_fuse_gelu_cot
229229 kernel = example_fuse_gelu_new
230230 desc = example_fuse_gelu_desc
231- case "ex_fuse_mnist2 " :
232- base = example_fuse_mnist2
233- cot = example_fuse_mnist2_cot
234- kernel = example_fuse_mnist2_new
235- desc = exmaple_fuse_mnist2_desc
231+ case "ex_mnist2 " :
232+ base = example_mnist2
233+ cot = example_mnist2_cot
234+ kernel = example_mnist2_new
235+ desc = exmaple_mnist2_desc
236236 case "ex_tiled_matmul" :
237237 base = example_tiled_matmul
238238 cot = example_tiled_matmul_cot
239239 kernel = example_tiled_matmul_new
240240 desc = example_tiled_matmul_desc
241+ case _:
242+ raise ValueError (f"Invalid CoT example: { cot_example } not found in CoT examples" )
241243
242244 # construct example with
243245 # NOTE: we only do one example with CoT for now
0 commit comments