@@ -259,12 +259,12 @@ def evaluate_iter(model, tokenizer, data, iter_nums, epoch, args, tb_writer=None
259259 model .eval ()
260260 l2v = LLM2Vec (model .text .model , tokenizer , pooling_mode = "mean" , max_length = 512 ) #TODO: modify this
261261 print ('evaluating retrieval' )
262- with torch .no_grad ():
263- retrieval_zero_shot_metrics = retrieval_eval (model , l2v , data , epoch , args )
264- metrics .update (retrieval_zero_shot_metrics )
265- zero_shot_metrics = zero_shot_eval (model , l2v , data , epoch , args )
266- metrics .update (zero_shot_metrics )
267- print (zero_shot_metrics )
262+ # with torch.no_grad():
263+ # retrieval_zero_shot_metrics = retrieval_eval(model, l2v, data, epoch, args)
264+ # metrics.update(retrieval_zero_shot_metrics)
265+ # zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
266+ # metrics.update(zero_shot_metrics)
267+ # print(zero_shot_metrics)
268268 autocast = get_autocast (args .precision )
269269 cast_dtype = get_cast_dtype (args .precision )
270270 if 'val' in data :
@@ -356,9 +356,9 @@ def evaluate(model, tokenizer, data, epoch, args, tb_writer=None):
356356 l2v = LLM2Vec (model .text .model , tokenizer , pooling_mode = "mean" , max_length = 512 ) #TODO: modify this
357357 retrieval_zero_shot_metrics = retrieval_eval (model , l2v , data , epoch , args )
358358 metrics .update (retrieval_zero_shot_metrics )
359- zero_shot_metrics = zero_shot_eval (model , l2v , data , epoch , args )
360- metrics .update (zero_shot_metrics )
361- print (zero_shot_metrics )
359+ # zero_shot_metrics = zero_shot_eval(model, l2v, data, epoch, args)
360+ # metrics.update(zero_shot_metrics)
361+ # print(zero_shot_metrics)
362362 autocast = get_autocast (args .precision )
363363 cast_dtype = get_cast_dtype (args .precision )
364364
0 commit comments