Skip to content

Commit cc84eb2

Browse files
add dtype support
1 parent 2d6cbf6 commit cc84eb2

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

docext/app/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def main(
270270
max_img_size: int,
271271
concurrency_limit: int,
272272
share: bool,
273+
dtype: str,
273274
):
274275
vllm_server = None
275276
if model_name.startswith("hosted_vllm/") and (
@@ -290,6 +291,7 @@ def main(
290291
gpu_memory_utilization=gpu_memory_utilization,
291292
max_num_imgs=max_num_imgs,
292293
vllm_start_timeout=vllm_start_timeout,
294+
dtype=dtype,
293295
)
294296
vllm_server.run_in_background()
295297

@@ -356,6 +358,7 @@ def docext_app():
356358
args.max_img_size,
357359
args.concurrency_limit,
358360
args.share,
361+
args.dtype,
359362
)
360363

361364

docext/app/args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,10 @@ def parse_args():
7979
default=1,
8080
help="Maximum number of concurrent PDF to markdown conversion requests. Higher values allow more users to process documents simultaneously but require more memory and compute resources.",
8181
)
82+
parser.add_argument(
83+
"--dtype",
84+
type=str,
85+
default="bfloat16",
86+
help="Data type for the model. Can be 'bfloat16' or 'float16'.",
87+
)
8288
return parser.parse_args()

docext/core/vllm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
gpu_memory_utilization: float = 0.98,
2121
max_num_imgs: int = 5,
2222
vllm_start_timeout: int = 300,
23+
dtype: str = "bfloat16",
2324
):
2425
self.host = host
2526
self.port = port
@@ -30,12 +31,18 @@ def __init__(
3031
self.server_process = None
3132
self.url = f"http://{self.host}:{self.port}/v1/models"
3233
self.vllm_start_timeout = vllm_start_timeout
34+
self.dtype = dtype
35+
assert self.dtype in [
36+
"bfloat16",
37+
"float16",
38+
], "Invalid dtype. Must be 'bfloat16' or 'float16'."
3339

3440
def start_server(self):
3541
"""Start the vLLM server in a background thread."""
3642
logger.info("Starting vLLM server...")
3743
# Command to start the vLLM server
3844
is_awq = "awq" in self.model_name.lower()
45+
dtype = dtype if not is_awq else "float16"
3946
command = [
4047
"vllm",
4148
"serve",
@@ -45,7 +52,7 @@ def start_server(self):
4552
"--port",
4653
str(self.port),
4754
"--dtype",
48-
"bfloat16" if not is_awq else "float16",
55+
dtype,
4956
"--limit-mm-per-prompt",
5057
f"image={self.max_num_imgs},video=0",
5158
"--served-model-name",

0 commit comments

Comments
 (0)