Skip to content

Commit b115921

Browse files
authored
Merge pull request #65 from krai/llama3.1-8b-cnndm
Add support to run Llama3.1-8B model on CNNDM dataset
2 parents fa84382 + 4da25a8 commit b115921

File tree

4 files changed

+161
-1
lines changed

4 files changed

+161
-1
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import json
2+
3+
from transformers import AutoTokenizer
4+
5+
def get_accuracy_dict(accuracy_dict_full):
    """Filter a full accuracy report down to the headline metrics.

    Parameters
    ----------
    accuracy_dict_full : dict
        Complete accuracy report (e.g. rouge scores, gen_len, ...).

    Returns
    -------
    dict
        Only the "rougeL", "exact_match" and "tokens_per_sample" entries
        that are actually present in the input; missing keys are skipped.
    """
    # Dict comprehension replaces the manual .keys() loop — same key order,
    # same membership filter, idiomatic single pass over items.
    wanted = ("rougeL", "exact_match", "tokens_per_sample")
    return {k: v for k, v in accuracy_dict_full.items() if k in wanted}
11+
12+
def _decode_hex_tokens(hex_str):
    """Decode a hex string of packed little-endian 4-byte token IDs into ints.

    Each token occupies 8 hex characters (4 bytes, little-endian) — this
    matches the "int32" accuracy-log dtype used by the loadgen experiment.
    """
    return [
        int.from_bytes(bytes.fromhex(hex_str[i : i + 8]), byteorder="little")
        for i in range(0, len(hex_str), 8)
    ]


def parse_tokens(
    tokenised_accuracy_log_path: str, output_log_path: str
):
    """Convert an MLPerf tokenised accuracy log into a plain token-ID log.

    Reads the JSON log at *tokenised_accuracy_log_path* (a list of entries,
    each with a hex-encoded "data" field of packed int32 token IDs), decodes
    every entry into a list of integer token IDs, and writes the resulting
    list-of-lists as JSON to *output_log_path*.

    Returns
    -------
    str
        *output_log_path*, for convenient chaining.
    """
    with open(tokenised_accuracy_log_path) as f:
        log = json.load(f)

    # One decoded token list per log entry; decoding shared via helper so
    # the hex-unpacking logic lives in exactly one place.
    output_log = [_decode_hex_tokens(item["data"]) for item in log]

    with open(output_log_path, "w") as f:
        json.dump(output_log, f, indent=2)
    return output_log_path
30+
31+
def detokenise(
    checkpoint_path: str, tokenised_accuracy_log_path: str, output_log_path: str
):
    """Decode an MLPerf tokenised accuracy log back into readable text.

    Loads the HuggingFace tokeniser from *checkpoint_path*, reads the JSON
    accuracy log at *tokenised_accuracy_log_path* (each entry carries token
    IDs packed as little-endian 4-byte values in a hex "data" string),
    replaces each entry's "data" with the decoded text, and writes the new
    log as JSON to *output_log_path*.

    Returns *output_log_path* for convenient chaining.
    """
    tokeniser = AutoTokenizer.from_pretrained(checkpoint_path)

    with open(tokenised_accuracy_log_path) as log_file:
        entries = json.load(log_file)

    decoded_entries = []
    for entry in entries:
        packed = entry["data"]
        # Each token is 4 bytes -> 8 hex characters, little-endian.
        token_ids = [
            int.from_bytes(bytes.fromhex(packed[pos : pos + 8]), byteorder="little")
            for pos in range(0, len(packed), 8)
        ]
        decoded_entries.append({
            "seq_id" : entry["seq_id"],
            "qsl_idx" : entry["qsl_idx"],
            "data": tokeniser.decode(token_ids),
            "token_count" : entry["token_count"]
        })

    with open(output_log_path, "w") as out_file:
        json.dump(decoded_entries, out_file, indent=2)
    return output_log_path
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
{
2+
"_parent_entries": [ [ "^", "byname", "base_loadgen_experiment" ] ],
3+
4+
"transformers_query": [ "python_package", "package_name=transformers", ["desired_python_version", ["^", "kernel_python_major_dot_minor"]] ],
5+
6+
"_BEFORE_CODE_LOADING": [ "^^", "execute", [[
7+
[ "get_kernel" ],
8+
[ "byquery", [[ "^^", "get", "transformers_query" ]] ],
9+
[ "use" ]
10+
]] ],
11+
12+
"desired_python_version": "3.10",
13+
14+
"mlperf_inference_git_entry": [ "^", "byquery", "git_repo,repo_name=mlperf_inference_git" ],
15+
16+
"abs_script_path": [ "^^", "execute", [[
17+
[ "get", "mlperf_inference_git_entry" ],
18+
[ "get_path_of", "llama3_1_8b_cnndm_accuracy_script" ]
19+
]] ],
20+
21+
"accuracy_log_path": ["^^", "get_path", "mlperf_log_accuracy.json"],
22+
23+
"dataset_name": "cnndm",
24+
"model_family": "llama3_1",
25+
"model_variant": "8b",
26+
27+
"dataset_query": [ "downloaded", [ "^^", "substitute", "dataset_name=#{dataset_name}#,model_family=#{model_family}#,variant=#{model_variant}#" ]],
28+
"dataset_entry": [ "^", "byquery", [[ "^^", "get", "dataset_query" ]], {}, ["dataset_query"] ],
29+
30+
"dataset_path": [ "^^", "execute", [[
31+
[ "get", "dataset_entry" ],
32+
[ "get_path" ],
33+
[ "__add__", "/cnn_eval.json" ]
34+
]] ],
35+
36+
"checkpoint_path_query": [ "^^", "substitute", "downloaded,hf_tokeniser,model_family=#{model_family}#,variant=#{model_variant}#" ],
37+
"checkpoint_path": [ "^^", "execute", [[
38+
[ "get_kernel" ],
39+
[ "byquery", [[ "^^", "get", "checkpoint_path_query" ]] ],
40+
[ "get_path" ]
41+
]] ],
42+
43+
"accuracy_log_dtype": "int32",
44+
45+
"extract_accuracy_report": [ "^^", "execute", [[
46+
[ "get_kernel" ],
47+
[ "byname", "python_script" ],
48+
[ "run", [], {
49+
"python_deps": [
50+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=protobuf" ],
51+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=torch" ],
52+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=transformers" ],
53+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=nltk" ],
54+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=rouge_score" ],
55+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=sentencepiece" ],
56+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=pillow" ],
57+
[ "AS^IS", "^^", "python_sync_pip_package", "python_package,package_name=evaluate" ]
58+
],
59+
"abs_script_path": ["^^", "get", "abs_script_path"],
60+
"script_extra_params": [ "^^", "substitute", "--mlperf-accuracy-file #{accuracy_log_path}# --dataset-file #{dataset_path}# --dtype #{accuracy_log_dtype}#" ],
61+
"desired_python_version": ["^", "kernel_python_major_dot_minor"],
62+
"capture_output": true
63+
} ],
64+
0,
65+
[ "func", [ "ufun.rematch", "(\\{.*\\})" ] ],
66+
0,
67+
[ "denumpify_dict" ],
68+
0,
69+
[ "func", "str" ]
70+
]], {} ],
71+
72+
"accuracy_dict_full": [ "^^", "execute", [[
73+
["get", "accuracy_report" ],
74+
0,
75+
[ "func", "eval" ]
76+
]], {} ],
77+
"accuracy_dict": [ "^^", "get_accuracy_dict" ],
78+
"rouge1": [ "^^" , "dig","accuracy_dict.rouge1" ],
79+
"rouge2": [ "^^" , "dig","accuracy_dict.rouge2" ],
80+
"rougeL": [ "^^" , "dig","accuracy_dict.rougeL" ],
81+
"rougeLsum": [ "^^" , "dig","accuracy_dict.rougeLsum" ],
82+
"gen_len": [ "^^" , "dig","accuracy_dict.gen_len" ],
83+
"gen_num": [ "^^" , "dig","accuracy_dict.gen_num" ],
84+
"tokens_per_sample": [ "^^" , "dig","accuracy_dict.tokens_per_sample" ],
85+
86+
"tokenised_accuracy_log_path": [ "^^", "get_path", "mlperf_log_accuracy.json" ],
87+
"output_log_path": [ "^^", "get_path", "detokenised_mlperf_log.json" ],
88+
89+
"detokenised_log": [ "^^", "detokenise" ]
90+
}

data_axs.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"model_pytorch_resnet50": "model_pytorch_resnet50",
4141
"mlperf_power_git_recipe": "mlperf_power_git_recipe",
4242
"dataset_cnndm_mlperf_recipe": "dataset_cnndm_mlperf_recipe",
43+
"dataset_small_llm_cnndm_mlperf_recipe": "dataset_small_llm_cnndm_mlperf_recipe",
4344
"dataset_lambada_recipe": "dataset_lambada_recipe",
4445
"dataset_coco2014_images_recipe": "dataset_coco2014_images_recipe",
4546
"gptj_reference_loadgen": "gptj_reference_loadgen",
@@ -74,7 +75,8 @@
7475
"base_llama3_1_loadgen_experiment": "base_llama3_1_loadgen_experiment",
7576
"dataset_llrg_mlperf_recipe": "dataset_llrg_mlperf_recipe",
7677
"convert_openorca": "convert_openorca",
77-
"quantize_quark_recipe": "quantize_quark_recipe"
78+
"quantize_quark_recipe": "quantize_quark_recipe",
79+
"base_small_llm_loadgen_experiment": "base_small_llm_loadgen_experiment"
7880
},
7981
"repo_name": "axs2mlperf",
8082
"checkout": null,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"_producer_rules": [
3+
[ [ "downloaded", "dataset_name=cnndm", "model_family=llama3_1", "variant=8b" ], [["get_kernel"],["byname","downloader"],["download"]], {
4+
"downloading_tool_query": "shell_tool,can_download_url_from_rclone",
5+
"url": "mlc-inference:mlcommons-inference-wg-public/llama3.1_8b/cnn_eval.json",
6+
"downloading_tool_params": {
7+
"rclone_remote_name": "mlc-inference"
8+
},
9+
"newborn_entry_name": "downloaded_mlc_cnndm_llama3_1_8b",
10+
"file_path": "llama3_1_8b"
11+
}, [] ]
12+
]
13+
}

0 commit comments

Comments
 (0)