Skip to content

Commit d6f18cc

Browse files
committed
adding support for nscale inference provider
1 parent 785835f commit d6f18cc

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class InferenceClient:
133133
path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api)
134134
documentation for details). When passing a URL as `model`, the client will not append any suffix path to it.
135135
provider (`str`, *optional*):
136-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
136+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"fireworks-ai"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`.
137137
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
138138
If model is a URL or `base_url` is passed, then `provider` is not used.
139139
token (`str`, *optional*):

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
2323
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
2424
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
25+
from .nscale import NscaleChatCompletion, NscaleTextToImageTask
2526
from .openai import OpenAIConversationalTask
2627
from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
2728
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
@@ -41,6 +42,7 @@
4142
"hyperbolic",
4243
"nebius",
4344
"novita",
45+
"nscale",
4446
"openai",
4547
"replicate",
4648
"sambanova",
@@ -111,6 +113,10 @@
111113
"conversational": NovitaConversationalTask(),
112114
"text-to-video": NovitaTextToVideoTask(),
113115
},
116+
"nscale": {
117+
"conversational": NscaleChatCompletion(),
118+
"text-to-image": NscaleTextToImageTask(),
119+
},
114120
"openai": {
115121
"conversational": OpenAIConversationalTask(),
116122
},

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"hf-inference": {},
2828
"hyperbolic": {},
2929
"nebius": {},
30+
"nscale": {},
3031
"replicate": {},
3132
"sambanova": {},
3233
"together": {},
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Any, Dict, Optional, Union
2+
import base64
3+
4+
from huggingface_hub.inference._common import RequestParameters, _as_dict
5+
from ._common import (
6+
TaskProviderHelper,
7+
BaseConversationalTask,
8+
filter_none,
9+
)
10+
11+
class NscaleTask(TaskProviderHelper):
12+
def __init__(self, task: str):
13+
super().__init__(provider="nscale", base_url="https://inference.api.nscale.com", task=task)
14+
15+
def _prepare_route(self, mapped_model: str, api_key: str) -> str:
16+
if self.task == "text-to-image":
17+
return "/v1/images/generations"
18+
elif self.task == "conversational":
19+
return "/v1/chat/completions"
20+
raise ValueError(f"Unsupported task '{self.task}' for Nscale API.")
21+
22+
class NscaleChatCompletion(BaseConversationalTask):
23+
def __init__(self):
24+
super().__init__(provider="nscale", base_url="https://inference.api.nscale.com")
25+
26+
class NscaleTextToImageTask(NscaleTask):
27+
def __init__(self):
28+
super().__init__("text-to-image")
29+
30+
def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
31+
# Combine all parameters except inputs and parameters
32+
parameters = filter_none(parameters)
33+
if "width" in parameters and "height" in parameters:
34+
parameters["size"] = f"{parameters.pop('width')}x{parameters.pop('height')}"
35+
if "num_inference_steps" in parameters:
36+
parameters.pop("num_inference_steps")
37+
if "cfg_scale" in parameters:
38+
parameters.pop("cfg_scale")
39+
payload = {
40+
"response_format": "b64_json",
41+
"prompt": inputs,
42+
"model": mapped_model,
43+
**parameters,
44+
}
45+
return payload
46+
47+
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
48+
response_dict = _as_dict(response)
49+
return base64.b64decode(response_dict["data"][0]["b64_json"])
50+

tests/test_inference_providers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from huggingface_hub.inference._providers.hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
3434
from huggingface_hub.inference._providers.nebius import NebiusTextToImageTask
3535
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
36+
from huggingface_hub.inference._providers.nscale import NscaleChatCompletion, NscaleTextToImageTask
3637
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
3738
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
3839
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
@@ -829,6 +830,56 @@ def test_prepare_url_conversational(self):
829830
url = helper._prepare_url("novita_token", "username/repo_name")
830831
assert url == "https://api.novita.ai/v3/openai/chat/completions"
831832

833+
class TestNscaleProvider:
834+
def test_prepare_route_text_to_image(self):
835+
helper = NscaleTextToImageTask()
836+
assert helper._prepare_route("model_name", "api_key") == "/v1/images/generations"
837+
838+
def test_prepare_route_chat_completion(self):
839+
helper = NscaleChatCompletion()
840+
assert helper._prepare_route("model_name", "api_key") == "/v1/chat/completions"
841+
842+
def test_prepare_payload_with_size_conversion(self):
843+
helper = NscaleTextToImageTask()
844+
payload = helper._prepare_payload_as_dict(
845+
"a beautiful landscape",
846+
{
847+
"width": 512,
848+
"height": 512,
849+
},
850+
"stabilityai/stable-diffusion-xl-base-1.0",
851+
)
852+
assert payload == {
853+
"prompt": "a beautiful landscape",
854+
"size": "512x512",
855+
"response_format": "b64_json",
856+
"model": "stabilityai/stable-diffusion-xl-base-1.0",
857+
}
858+
859+
def test_prepare_payload_as_dict(self):
860+
helper = NscaleTextToImageTask()
861+
payload = helper._prepare_payload_as_dict(
862+
"a beautiful landscape",
863+
{
864+
"width": 1024,
865+
"height": 768,
866+
"cfg_scale": 7.5,
867+
"num_inference_steps": 50,
868+
},
869+
"stabilityai/stable-diffusion-xl-base-1.0",
870+
)
871+
assert "width" not in payload
872+
assert "height" not in payload
873+
assert "num_inference_steps" not in payload
874+
assert "cfg_scale" not in payload
875+
assert payload["size"] == "1024x768"
876+
assert payload["model"] == "stabilityai/stable-diffusion-xl-base-1.0"
877+
878+
def test_text_to_image_get_response(self):
879+
helper = NscaleTextToImageTask()
880+
response = helper.get_response({"data": [{"b64_json": base64.b64encode(b"image_bytes").decode()}]})
881+
assert response == b"image_bytes"
882+
832883

833884
class TestOpenAIProvider:
834885
def test_prepare_url(self):

0 commit comments

Comments
 (0)