import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch.utils.data import DataLoader, SequentialSampler

from accelerate import Accelerator
from accelerate.logging import get_logger

import transformers
from transformers import (
    HfArgumentParser,
    AutoTokenizer,
    LlamaConfig,
    MistralConfig,
    GemmaConfig,
    Qwen2Config,
    set_seed,
)

from llm2vec_wrapper import LLM2VecWrapper as LLM2Vec
from dataset.utils import load_dataset

from tqdm import tqdm

# Suppress unnecessary warnings
transformers.logging.set_verbosity_error()

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = get_logger(__name__, log_level="INFO")

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to use for testing.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    peft_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The PEFT model checkpoint to add on top of the base model."},
    )
    pooling_mode: Optional[str] = field(
        default="mean",
        metadata={
            "help": "The pooling mode to use in the model.",
            "choices": ["mean", "weighted_mean", "eos_token"],
        },
    )
    max_seq_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
        },
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": "Override the default `torch.dtype` and load the model under this dtype.",
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to the test data input.
    """

    test_dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the test dataset to use."},
    )
    test_file_path: Optional[str] = field(
        default=None,
        metadata={"help": "The input test data file or folder."},
    )
    dataframe_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the dataframe file for the test dataset."},
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker testing, truncate the number of test examples to this value if set."
        },
    )

@dataclass
class CustomArguments:
    """
    Custom arguments for the testing script.
    """

    batch_size: int = field(
        default=16,
        metadata={"help": "Batch size for testing."},
    )
    output_dir: Optional[str] = field(
        default="./test_results",
        metadata={"help": "Directory to save test results."},
    )

def prepare_for_tokenization(model, text, pooling_mode="mean"):
    """Wrap raw text in the chat template expected by the underlying model,
    and append the model's EOS token when pooling on it."""
    if model.config._name_or_path in [
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Llama-3.2-3B",
        "meta-llama/Llama-3.2-1B",
        "meta-llama/Llama-3.2-8B",
    ]:
        text = (
            "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
        )
        return text
    if model.config._name_or_path in [
        "mistralai/Mistral-7B-Instruct-v0.2",
        "meta-llama/Llama-2-7b-chat-hf",
    ]:
        text = "[INST] " + text.strip() + " [/INST]"
    if model.config._name_or_path in [
        "google/gemma-2-9b-it",
    ]:
        text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
    if model.config._name_or_path in [
        "Qwen/Qwen2-1.5B-Instruct",
        "Qwen/Qwen2-7B-Instruct",
    ]:
        text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
    if pooling_mode == "eos_token":
        if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
            text = text.strip() + "<|end_of_text|>"
        elif isinstance(model.config, LlamaConfig) or isinstance(
            model.config, MistralConfig
        ):
            text = text.strip() + " </s>"
        elif isinstance(model.config, GemmaConfig):
            text = text.strip() + "<eos>"
        elif isinstance(model.config, Qwen2Config):
            text = text.strip() + "<|endoftext|>"
    return text

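# Example: for "mistralai/Mistral-7B-Instruct-v0.2" with pooling_mode="eos_token",
# prepare_for_tokenization(model, "What is attention?") returns
# "[INST] What is attention? [/INST] </s>".
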
def main():
    # Initialize Accelerator
    accelerator = Accelerator()
    logger.info("Initialized Accelerator.")

    # Argument parsing: either a single JSON config file or CLI flags
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, CustomArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, custom_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, custom_args = parser.parse_args_into_dataclasses()

    # Set seed for reproducibility
    set_seed(42)

    # Load tokenizer
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Load model with the PEFT adapter merged into the base weights
    logger.info("Loading the LLM2Vec model with PEFT adapter for testing...")
    model = LLM2Vec.from_pretrained(
        base_model_name_or_path=model_args.model_name_or_path,
        enable_bidirectional=False,  # Typically not needed for testing
        peft_model_name_or_path=model_args.peft_model_name_or_path,
        merge_peft=True,
        pooling_mode=model_args.pooling_mode,
        max_length=model_args.max_seq_length,
        # "auto" is passed through as a string; any other choice is resolved
        # to the corresponding torch dtype.
        torch_dtype=(
            model_args.torch_dtype
            if model_args.torch_dtype == "auto"
            else getattr(torch, model_args.torch_dtype)
        ),
    )

    model.to(accelerator.device)
    model.eval()
    logger.info("Model loaded and set to evaluation mode.")

    # Load test dataset
    logger.info("Loading test dataset...")
    test_dataset = load_dataset(
        dataset_name=data_args.test_dataset_name,
        split="test",
        file_path=data_args.test_file_path,
        dataframe_path=data_args.dataframe_path,
    )

    if data_args.max_test_samples is not None:
        test_dataset = test_dataset.select(range(data_args.max_test_samples))
        logger.info(f"Truncated test dataset to {data_args.max_test_samples} samples.")

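    # The collator below assumes each dataset example is a dict with at least
    # a "text" (str) and a "label" (int) field.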
    # Prepare data collator
    def collate_fn(batch):
        texts = [
            prepare_for_tokenization(model.model, example["text"], pooling_mode=model_args.pooling_mode)
            for example in batch
        ]
        inputs = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=model_args.max_seq_length,
            return_tensors="pt",
        )
        labels = torch.tensor([example["label"] for example in batch])
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": labels,
        }

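    # Decoder-only backbones are commonly batched with left padding at
    # inference time; if predictions look degenerate, consider setting
    # tokenizer.padding_side = "left" before building the DataLoader.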
    test_dataloader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=custom_args.batch_size,
        collate_fn=collate_fn,
    )

    # Prepare dataloader with accelerator
    test_dataloader = accelerator.prepare(test_dataloader)
    logger.info("Test DataLoader prepared.")

    # Evaluation loop
    logger.info("Starting evaluation...")
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Testing"):
            inputs = {
                "input_ids": batch["input_ids"].to(accelerator.device),
                "attention_mask": batch["attention_mask"].to(accelerator.device),
            }
            labels = batch["labels"].to(accelerator.device)

            # Assumes the wrapped backbone exposes classification logits of
            # shape (batch_size, num_labels); argmax picks the predicted class.
            outputs = model.model(**inputs)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)

            all_predictions.extend(accelerator.gather(predictions).cpu().numpy())
            all_labels.extend(accelerator.gather(labels).cpu().numpy())

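    # Note: under multi-process evaluation, accelerator.gather can include
    # padded duplicates from the final uneven batch; accelerator.gather_for_metrics
    # trims these automatically and may be preferable here.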
    # Compute metrics
    from sklearn.metrics import accuracy_score, f1_score

    accuracy = accuracy_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions, average="weighted")

    logger.info(f"Test Accuracy: {accuracy:.4f}")
    logger.info(f"Test F1 Score: {f1:.4f}")

    # Save results
    os.makedirs(custom_args.output_dir, exist_ok=True)
    results_path = os.path.join(custom_args.output_dir, "test_results.txt")
    with open(results_path, "w") as f:
        f.write(f"Test Accuracy: {accuracy:.4f}\n")
        f.write(f"Test F1 Score: {f1:.4f}\n")

    logger.info(f"Test results saved to {results_path}")

if __name__ == "__main__":
    main()
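
# Example invocation (hypothetical paths and dataset names; adjust to your setup):
#   accelerate launch test.py \
#     --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
#     --peft_model_name_or_path ./peft_checkpoint \
#     --test_dataset_name my_dataset \
#     --batch_size 16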