Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sparsezoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# flake8: noqa
# isort: skip_file

from .api import *
from .inference import *
from .model import *
from .objects import *
Expand Down
19 changes: 19 additions & 0 deletions src/sparsezoo/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .graphql import *
from .query_parser import *
from .utils import *
82 changes: 82 additions & 0 deletions src/sparsezoo/api/graphql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, List, Optional

import requests

from sparsezoo.utils import BASE_API_URL

from .query_parser import QueryParser
from .utils import map_keys, to_snake_case


class GraphQLAPI:
def fetch(
self,
operation_body: str,
arguments: Optional[Dict[str, str]] = None,
fields: Optional[List[str]] = None,
url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Fetch data for models via api. Uses graphql convention of post,
not get for requests.
Input args are parsed to make a query body for the api request.
For more details on the appropriate values, please refer to the
url endpoint on the browser

:param operation_body: The data object of interest
:param arguments: Used to filter data object in the backend
:param field: the object's field of interest
"""

response_objects = self.make_request(
operation_body=operation_body,
arguments=arguments,
fields=fields,
url=url,
)

return [
map_keys(dictionary=response_object, mapper=to_snake_case)
for response_object in response_objects
]

def make_request(
self,
operation_body: str,
arguments: Optional[Dict[str, str]] = None,
fields: Optional[List[str]] = None,
url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Given the input args, parse them to a graphql appropriate format
and make an graph post request to get the desired raw response.
Raw response's keys are in camelCase, not snake_case
"""

query = QueryParser(
operation_body=operation_body, arguments=arguments, fields=fields
)

response = requests.post(
url=url or f"{BASE_API_URL}/v2/graphql", json={"query": query.query_body}
)

response.raise_for_status()
response_json = response.json()

return response_json["data"][query.operation_body]
150 changes: 150 additions & 0 deletions src/sparsezoo/api/query_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Optional

from .utils import to_camel_case


DEFAULT_MODELS_FIELDS = ["modelId", "stub"]

DEFAULT_FILES_FIELDS = ["displayName", "fileSize", "modelId", "fileType"]

DEFAULT_TRAINING_RESULTS_FIELDS = [
"datasetName",
"datasetType",
"recordedUnits",
"recordedValue",
]

DEFAULT_BENCHMARK_RESULTS_FIELDS = [
"batchSize",
"deviceInfo",
"numCores",
"recordedUnits",
"recordedValue",
]

DEPRECATED_STUB_ARGS_MAPPER = {"sub_domain": "task", "dataset": "source_dataset"}
DEFAULT_FIELDS = {
"models": DEFAULT_MODELS_FIELDS,
"files": DEFAULT_FILES_FIELDS,
"trainingResults": DEFAULT_TRAINING_RESULTS_FIELDS,
"benchmarkResults": DEFAULT_BENCHMARK_RESULTS_FIELDS,
}

QUERY_BODY = """
{{
{operation_body} {arguments}
{{
{fields}
}}
}}
"""


class QueryParser:
"""Parse the class input arg fields to be used for graphql post requests"""

def __init__(
self,
operation_body: str,
arguments: Optional[Dict[str, str]] = None,
fields: Optional[List[str]] = None,
):
self._operation_body = operation_body
self._arguments = arguments
self._fields = fields
self._query_body = None

self._parse()

def _parse(self):
"""Parse to a string compatible with graphql requst body"""

self._parse_operation_body()
self._parse_arguments()
self._parse_fields()
self._build_query_body()

def _parse_operation_body(self) -> None:
self._operation_body = to_camel_case(self._operation_body)

def _parse_arguments(self) -> None:
"""Transform deprecated stub args and convert to camel case"""
parsed_arguments = ""
arguments = self.arguments or {}

for argument, value in arguments.items():
if value is not None:
contemporary_key = DEPRECATED_STUB_ARGS_MAPPER.get(argument, argument)
camel_case_key = to_camel_case(contemporary_key)

# single, double quotes matters
if isinstance(value, str):
parsed_arguments += f'{camel_case_key}: "{value}",'
else:
parsed_arguments += f"{camel_case_key}: {value},"

if parsed_arguments:
parsed_arguments = "(" + parsed_arguments + ")"
self._arguments = parsed_arguments

def _parse_fields(self) -> None:
fields = self.fields or DEFAULT_FIELDS.get(self.operation_body)
self.fields = " ".join(map(to_camel_case, fields))

def _build_query_body(self) -> None:
self.query_body = QUERY_BODY.format(
operation_body=self.operation_body,
arguments=self.arguments,
fields=self.fields,
)

@property
def operation_body(self) -> str:
"""Return the query operation body"""
return self._operation_body

@operation_body.setter
def operation_body(self, operation_body: str) -> None:
self._operation_body = operation_body

@property
def arguments(self) -> str:
"""Return the query arguments"""
return self._arguments

@arguments.setter
def arguments(self, arguments: str) -> None:
self._operation_body = arguments

@property
def fields(self) -> str:
"""Return the query fields"""
return self._fields

@fields.setter
def fields(self, fields: str) -> None:
self._fields = fields

@property
def query_body(self) -> str:
"""Return the query body"""
return self._query_body

@query_body.setter
def query_body(self, query_body: str) -> None:
self._query_body = query_body
38 changes: 38 additions & 0 deletions src/sparsezoo/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict


def to_camel_case(string: str):
"Convert string to camel case"
components = string.split("_")
return components[0] + "".join(word.title() for word in components[1:])


def to_snake_case(string: str):
"Convert string to snake case"
return "".join(
[
"_" + character.lower() if character.isupper() else character
for character in string
]
).lstrip("_")


def map_keys(
dictionary: Dict[str, str], mapper: Callable[[str], str]
) -> Dict[str, str]:
"""Given a dictionary, update its key to a given mapper callable"""
return {mapper(key): value for key, value in dictionary.items()}
4 changes: 3 additions & 1 deletion src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,9 @@ def initialize_model_from_stub(
path = os.path.join(SAVE_DIR, model_id)
if not files:
raise ValueError(f"No files found for given stub {stub}")

url = os.path.dirname(files[0].get("url"))

return files, path, url, validation_results, size

@staticmethod
Expand Down Expand Up @@ -543,7 +545,7 @@ def _file_from_files(
elif len(files_found) == 1:
return files_found[0]

elif display_name == "model.onnx" and len(files_found) == 2:
elif display_name == "model.onnx":
# `model.onnx` file may be found twice:
# - directly in the root directory
# - inside `deployment` directory
Expand Down
16 changes: 12 additions & 4 deletions src/sparsezoo/model/result_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ class ModelResult(BaseModel):
Base class to store common result information
"""

result_type: str = Field(
description="A string representing the type of "
"result ex `training`, `inference`, etc"
)
recorded_value: float = Field(description="The float value of the result")
recorded_units: str = Field(description="The unit in which result is specified")

Expand All @@ -36,6 +32,12 @@ class ValidationResult(ModelResult):
A class holding information for validation results
"""

result_type: str = Field(
description="A string representing the type of "
"result ex `training`, `inference`, etc",
default="inference",
)

dataset_type: str = Field(
description="A string representing the type of "
"dataset used ex. `upstream`, `downstream`"
Expand All @@ -50,6 +52,12 @@ class ThroughputResults(ModelResult):
A class holding information for throughput based results
"""

result_type: str = Field(
description="A string representing the type of "
"result ex `training`, `inference`, etc",
default="training",
)

device_info: str = Field(description="The device current result was measured on")
num_cores: int = Field(
description="Number of cores used while measuring " "this result"
Expand Down
Loading