Commit 8fa12b0

Adapters loading (ml-explore#902)

Authored by khushgxawni and Awni Hannun.

* Added functionality to load adapters through POST requests so you do not need to restart the server
* ran pre-commit
* nits
* fix test

Co-authored-by: Awni Hannun <[email protected]>
1 parent 85dc76f commit 8fa12b0

File tree

3 files changed: +24 −7 lines


llms/mlx_lm/SERVER.md

Lines changed: 7 additions & 0 deletions
@@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
 - `logprobs`: (Optional) An integer specifying the number of top tokens and
   corresponding log probabilities to return for each output in the generated
   sequence. If set, this can be any value between 1 and 10, inclusive.
+
+- `model`: (Optional) A string path to a local model or Hugging Face repo id.
+  If the path is local, it must be relative to the directory the server was
+  started in.
+
+- `adapters`: (Optional) A string path to low-rank adapters. The path must be
+  relative to the directory the server was started in.
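The practical upshot of the new `adapters` field is that a client can switch adapters per request. A minimal sketch of such a request, assuming a server already running on localhost:8080 and a hypothetical adapter directory adapters/my_adapter relative to where the server was started; the `messages` shape follows the standard chat-completions format rather than anything specific to this diff:

import json
from urllib import request

# "adapters/my_adapter" is a placeholder path; per the docs above, it must be
# relative to the directory the server was started in.
body = json.dumps(
    {
        "model": "default_model",
        "adapters": "adapters/my_adapter",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 100,
    }
).encode("utf-8")

req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=body,
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.load(resp))

Because the server now reloads on demand (see server.py below), the same running process can serve this request and then a follow-up request with a different adapter path.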

llms/mlx_lm/server.py

Lines changed: 16 additions & 6 deletions
@@ -97,8 +97,9 @@ def _validate_model_path(self, model_path: str):
                 "Local models must be relative to the current working dir."
             )
 
-    def load(self, model_path):
-        if self.model_key == model_path:
+    # Added in adapter_path to load dynamically
+    def load(self, model_path, adapter_path=None):
+        if self.model_key == (model_path, adapter_path):
             return self.model, self.tokenizer
 
         # Remove the old model if it exists.
@@ -116,18 +117,22 @@ def load(self, model_path):
         if model_path == "default_model" and self.cli_args.model is not None:
             model, tokenizer = load(
                 self.cli_args.model,
-                adapter_path=self.cli_args.adapter_path,
+                adapter_path=(
+                    adapter_path if adapter_path else self.cli_args.adapter_path
+                ),  # if the user doesn't change the model but adds an adapter path
                 tokenizer_config=tokenizer_config,
             )
         else:
             self._validate_model_path(model_path)
-            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+            model, tokenizer = load(
+                model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
+            )
 
         if self.cli_args.use_default_chat_template:
             if tokenizer.chat_template is None:
                 tokenizer.chat_template = tokenizer.default_chat_template
 
-        self.model_key = model_path
+        self.model_key = (model_path, adapter_path)
         self.model = model
         self.tokenizer = tokenizer
 
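A note on the design: by keying the cached model on the (model_path, adapter_path) pair instead of model_path alone, a request that changes only the adapter forces a reload even though the base model is unchanged. A stripped-down sketch of the pattern, with illustrative names rather than the real ModelProvider:

class CacheSketch:
    """Illustrates the (model, adapter) cache key; hypothetical class."""

    def __init__(self):
        self.model_key = None  # last (model_path, adapter_path) pair loaded
        self.model = None

    def load(self, model_path, adapter_path=None):
        key = (model_path, adapter_path)
        if self.model_key == key:
            return self.model  # hit: same base model and same adapters
        # miss: the base model or the adapter path changed, so reload
        self.model = (model_path, adapter_path)  # stands in for real weights
        self.model_key = key
        return self.model

Two consecutive requests with the same model but different `adapters` values therefore pay the load cost twice, once per adapter.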

@@ -193,6 +198,7 @@ def do_POST(self):
         self.stream = self.body.get("stream", False)
         self.stream_options = self.body.get("stream_options", None)
         self.requested_model = self.body.get("model", "default_model")
+        self.adapter = self.body.get("adapters", None)
         self.max_tokens = self.body.get("max_tokens", 100)
         self.temperature = self.body.get("temperature", 1.0)
         self.top_p = self.body.get("top_p", 1.0)
@@ -204,7 +210,9 @@
 
         # Load the model if needed
         try:
-            self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+            self.model, self.tokenizer = self.model_provider.load(
+                self.requested_model, self.adapter
+            )
         except:
             self._set_completion_headers(404)
             self.end_headers()
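Since the dynamic load happens inside the request handler, a bad adapter path now surfaces to the client as a 404 on that request rather than as a crashed server. A minimal sketch of observing this, assuming the same hypothetical localhost:8080 server and a deliberately invalid path:

import json
from urllib import error, request

req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(
        {"adapters": "no/such/path", "messages": [{"role": "user", "content": "hi"}]}
    ).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
try:
    request.urlopen(req)
except error.HTTPError as e:
    # the except around the load above answers 404 for any failed load
    print(e.code)  # expected: 404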
@@ -278,6 +286,8 @@ def validate_model_parameters(self):
 
         if not isinstance(self.requested_model, str):
             raise ValueError("model must be a string")
+        if self.adapter is not None and not isinstance(self.adapter, str):
+            raise ValueError("adapter must be a string")
 
     def generate_response(
         self,
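One detail worth noting: although the field name `adapters` is plural, the value must be a single string path, so a list is rejected by the new check. A small illustration of the guard in isolation, with hypothetical request bodies and the same logic as above:

def check_adapter(body):
    # mirrors the guard added to validate_model_parameters above
    adapter = body.get("adapters", None)
    if adapter is not None and not isinstance(adapter, str):
        raise ValueError("adapter must be a string")
    return adapter

check_adapter({"adapters": "path/to/adapters"})  # ok: a single string path
check_adapter({})                                # ok: the field is optional
check_adapter({"adapters": ["a1", "a2"]})        # raises ValueError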

llms/tests/test_server.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def __init__(self):
         HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
         self.model, self.tokenizer = load(HF_MODEL_PATH)
 
-    def load(self, model):
+    def load(self, model, adapter=None):
         assert model in ["default_model", "chat_model"]
         return self.model, self.tokenizer
