 # from peft import PeftConfig, PeftModel
 # from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, AutoConfig
+import os
 from typing import Literal, Optional
 
 from dsp.modules.lm import LM
 
 # from dsp.modules.finetuning.finetune_hf import preprocess_prompt
 
+
 def openai_to_hf(**kwargs):
     hf_kwargs = {}
     for k, v in kwargs.items():
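For context, openai_to_hf translates OpenAI-style sampling kwargs into their Hugging Face generate() equivalents before they reach the model; the loop body itself is elided by the hunk below. A hedged sketch of a typical call (the mappings named in the comment are assumptions, not the elided code):

    # Hypothetical call: OpenAI-style names in, HF generate-style names out
    # (e.g. max_tokens -> max_new_tokens, n -> num_return_sequences).
    generate_kwargs = openai_to_hf(max_tokens=150, n=2, temperature=0.7)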
@@ -26,8 +28,19 @@ def openai_to_hf(**kwargs):
 
 
 class HFModel(LM):
-    def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool = False,
-                 hf_device_map: Literal["auto", "balanced", "balanced_low_0", "sequential"] = "auto"):
+    def __init__(
+        self,
+        model: str,
+        checkpoint: Optional[str] = None,
+        is_client: bool = False,
+        hf_device_map: Literal[
+            "auto",
+            "balanced",
+            "balanced_low_0",
+            "sequential",
+        ] = "auto",
+        token: Optional[str] = None,
+    ):
         """wrapper for Hugging Face models
 
         Args:
@@ -42,6 +55,10 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
         self.provider = "hf"
         self.is_client = is_client
         self.device_map = hf_device_map
+
+        hf_autoconfig_kwargs = dict(token=token or os.environ.get("HF_TOKEN"))
+        hf_autotokenizer_kwargs = hf_autoconfig_kwargs.copy()
+        hf_automodel_kwargs = hf_autoconfig_kwargs.copy()
         if not self.is_client:
             try:
                 import torch
@@ -52,40 +69,68 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
                 ) from exc
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             try:
-                architecture = AutoConfig.from_pretrained(model).__dict__["architectures"][0]
-                self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or ("T5WithLMHeadModel" in architecture)
+                architecture = AutoConfig.from_pretrained(
+                    model,
+                    **hf_autoconfig_kwargs,
+                ).__dict__["architectures"][0]
+                self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or (
+                    "T5WithLMHeadModel" in architecture
+                )
                 self.decoder_only_model = ("CausalLM" in architecture) or ("GPT2LMHeadModel" in architecture)
-                assert self.encoder_decoder_model or self.decoder_only_model, f"Unknown HuggingFace model class: {model}"
-                self.tokenizer = AutoTokenizer.from_pretrained(model if checkpoint is None else checkpoint)
+                assert (
+                    self.encoder_decoder_model or self.decoder_only_model
+                ), f"Unknown HuggingFace model class: {model}"
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    model if checkpoint is None else checkpoint,
+                    **hf_autotokenizer_kwargs,
+                )
 
                 self.rationale = True
                 AutoModelClass = AutoModelForSeq2SeqLM if self.encoder_decoder_model else AutoModelForCausalLM
                 if checkpoint:
                     # with open(os.path.join(checkpoint, '..', 'compiler_config.json'), 'r') as f:
                     #     config = json.load(f)
-                    self.rationale = False # config['rationale']
+                    self.rationale = False  # config['rationale']
                     # if config['peft']:
                     #     peft_config = PeftConfig.from_pretrained(checkpoint)
                     #     self.model = AutoModelClass.from_pretrained(peft_config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=hf_device_map)
                     #     self.model = PeftModel.from_pretrained(self.model, checkpoint)
                     # else:
                     if self.device_map:
-                        self.model = AutoModelClass.from_pretrained(checkpoint, device_map=self.device_map)
+                        self.model = AutoModelClass.from_pretrained(
+                            checkpoint,
+                            device_map=self.device_map,
+                            **hf_automodel_kwargs,
+                        )
                     else:
-                        self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
+                        self.model = AutoModelClass.from_pretrained(
+                            checkpoint,
+                            **hf_automodel_kwargs,
+                        ).to(self.device)
                 else:
                     if self.device_map:
-                        self.model = AutoModelClass.from_pretrained(model, device_map=self.device_map)
+                        self.model = AutoModelClass.from_pretrained(
+                            model,
+                            device_map=self.device_map,
+                            **hf_automodel_kwargs,
+                        )
                     else:
-                        self.model = AutoModelClass.from_pretrained(model).to(self.device)
+                        self.model = AutoModelClass.from_pretrained(
+                            model,
+                            **hf_automodel_kwargs,
+                        ).to(self.device)
                 self.drop_prompt_from_output = False
             except ValueError:
                 self.model = AutoModelForCausalLM.from_pretrained(
                     model if checkpoint is None else checkpoint,
                     device_map=self.device_map,
+                    **hf_automodel_kwargs,
                 )
                 self.drop_prompt_from_output = True
-                self.tokenizer = AutoTokenizer.from_pretrained(model)
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    model,
+                    **hf_autotokenizer_kwargs,
+                )
                 self.drop_prompt_from_output = True
         self.history = []
 
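The architecture sniffing reformatted above drives the choice between AutoModelForSeq2SeqLM and AutoModelForCausalLM: the first entry of the config's architectures list is matched against substrings such as "ConditionalGeneration" and "CausalLM". A standalone illustration with an assumed model name (transformers config objects also expose architectures as an attribute, so the __dict__ access is not strictly needed):

    from transformers import AutoConfig

    # "T5ForConditionalGeneration" contains "ConditionalGeneration",
    # so HFModel would take the encoder-decoder (Seq2Seq) branch.
    arch = AutoConfig.from_pretrained("google/flan-t5-base").architectures[0]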
@@ -111,7 +156,7 @@ def _generate(self, prompt, **kwargs):
         # print(prompt)
         if isinstance(prompt, dict):
             try:
-                prompt = prompt['messages'][0]['content']
+                prompt = prompt["messages"][0]["content"]
             except (KeyError, IndexError, TypeError):
                 print("Failed to extract 'content' from the prompt.")
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
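The dict branch above unwraps chat-style payloads before tokenization; only ["messages"][0]["content"] is read. A minimal illustration of the accepted shape (the "role" key is an assumption about the caller, not something this code inspects):

    # Hypothetical chat-style payload; _generate extracts the inner string.
    prompt = {"messages": [{"role": "user", "content": "What is the capital of France?"}]}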
@@ -121,10 +166,7 @@ def _generate(self, prompt, **kwargs):
         if self.drop_prompt_from_output:
             input_length = inputs.input_ids.shape[1]
             outputs = outputs[:, input_length:]
-        completions = [
-            {"text": c}
-            for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        ]
+        completions = [{"text": c} for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
         response = {
             "prompt": prompt,
             "choices": completions,