@@ -97,8 +97,9 @@ def _validate_model_path(self, model_path: str):
9797 "Local models must be relative to the current working dir."
9898 )
9999
100- def load (self , model_path ):
101- if self .model_key == model_path :
100+ # Added in adapter_path to load dynamically
101+ def load (self , model_path , adapter_path = None ):
102+ if self .model_key == (model_path , adapter_path ):
102103 return self .model , self .tokenizer
103104
104105 # Remove the old model if it exists.
@@ -116,18 +117,22 @@ def load(self, model_path):
116117 if model_path == "default_model" and self .cli_args .model is not None :
117118 model , tokenizer = load (
118119 self .cli_args .model ,
119- adapter_path = self .cli_args .adapter_path ,
120+ adapter_path = (
121+ adapter_path if adapter_path else self .cli_args .adapter_path
122+ ), # if the user doesn't change the model but adds an adapter path
120123 tokenizer_config = tokenizer_config ,
121124 )
122125 else :
123126 self ._validate_model_path (model_path )
124- model , tokenizer = load (model_path , tokenizer_config = tokenizer_config )
127+ model , tokenizer = load (
128+ model_path , adapter_path = adapter_path , tokenizer_config = tokenizer_config
129+ )
125130
126131 if self .cli_args .use_default_chat_template :
127132 if tokenizer .chat_template is None :
128133 tokenizer .chat_template = tokenizer .default_chat_template
129134
130- self .model_key = model_path
135+ self .model_key = ( model_path , adapter_path )
131136 self .model = model
132137 self .tokenizer = tokenizer
133138
@@ -193,6 +198,7 @@ def do_POST(self):
193198 self .stream = self .body .get ("stream" , False )
194199 self .stream_options = self .body .get ("stream_options" , None )
195200 self .requested_model = self .body .get ("model" , "default_model" )
201+ self .adapter = self .body .get ("adapters" , None )
196202 self .max_tokens = self .body .get ("max_tokens" , 100 )
197203 self .temperature = self .body .get ("temperature" , 1.0 )
198204 self .top_p = self .body .get ("top_p" , 1.0 )
@@ -204,7 +210,9 @@ def do_POST(self):
204210
205211 # Load the model if needed
206212 try :
207- self .model , self .tokenizer = self .model_provider .load (self .requested_model )
213+ self .model , self .tokenizer = self .model_provider .load (
214+ self .requested_model , self .adapter
215+ )
208216 except :
209217 self ._set_completion_headers (404 )
210218 self .end_headers ()
@@ -278,6 +286,8 @@ def validate_model_parameters(self):
278286
279287 if not isinstance (self .requested_model , str ):
280288 raise ValueError ("model must be a string" )
289+ if self .adapter is not None and not isinstance (self .adapter , str ):
290+ raise ValueError ("adapter must be a string" )
281291
282292 def generate_response (
283293 self ,
0 commit comments