-
Notifications
You must be signed in to change notification settings - Fork 1k
Add GitHub Models provider #2114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
DouweM
merged 7 commits into
pydantic:main
from
sgoedecke:sgoedecke/add-github-models-provider
Jul 4, 2025
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
6255fc4
Add GitHub Models provider
sgoedecke 51c0a5e
Rename provider to github and update profile logic
sgoedecke 065dfa8
Split out test case for more explicit coverage
sgoedecke 2304fc6
Add test case for unknown provider
sgoedecke 743374e
Mock out profiles for testing
sgoedecke 594ec23
Remove duplicate import
sgoedecke 3a22e28
Add test case for unknown publisher
sgoedecke File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import os | ||
from typing import overload | ||
|
||
from httpx import AsyncClient as AsyncHTTPClient | ||
|
||
from pydantic_ai.exceptions import UserError | ||
from pydantic_ai.models import cached_async_http_client | ||
from pydantic_ai.profiles import ModelProfile | ||
from pydantic_ai.profiles.cohere import cohere_model_profile | ||
from pydantic_ai.profiles.deepseek import deepseek_model_profile | ||
from pydantic_ai.profiles.grok import grok_model_profile | ||
from pydantic_ai.profiles.meta import meta_model_profile | ||
from pydantic_ai.profiles.mistral import mistral_model_profile | ||
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile | ||
from pydantic_ai.providers import Provider | ||
|
||
# The `openai` package is an optional dependency: this provider speaks the
# OpenAI-compatible GitHub Models API, so the import failure is converted into
# an actionable installation hint rather than a bare ImportError.
try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the GitHub Models provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error
|
||
|
||
class GitHubProvider(Provider[AsyncOpenAI]):
    """Provider for GitHub Models API.

    GitHub Models exposes models from several publishers behind a single
    OpenAI-compatible endpoint. See <https://docs.github.com/en/github-models>
    for more information.
    """

    @property
    def name(self) -> str:
        """The canonical name of this provider."""
        return 'github'

    @property
    def base_url(self) -> str:
        """The GitHub Models OpenAI-compatible inference endpoint."""
        return 'https://models.github.ai/inference'

    @property
    def client(self) -> AsyncOpenAI:
        """The underlying `AsyncOpenAI` client used for requests."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Build a model profile for a GitHub Models model name.

        GitHub Models names look like ``<publisher>/<model>[:<tag>]``; the
        publisher prefix selects a publisher-specific profile when one is known.
        """
        # A bare name with no publisher prefix is assumed to be an OpenAI model.
        if '/' not in model_name:
            return openai_model_profile(model_name)

        publisher, _, remainder = model_name.lower().partition('/')

        publisher_profiles = {
            'xai': grok_model_profile,
            'meta': meta_model_profile,
            'microsoft': openai_model_profile,
            'mistral-ai': mistral_model_profile,
            'cohere': cohere_model_profile,
            'deepseek': deepseek_model_profile,
        }

        base_profile = None
        profile_func = publisher_profiles.get(publisher)
        if profile_func is not None:
            # Drop any `:tag` suffix before building the publisher profile.
            bare_name = remainder.split(':', 1)[0]
            base_profile = profile_func(bare_name)

        # As GitHubProvider is always used with OpenAIModel, which used to
        # unconditionally use OpenAIJsonSchemaTransformer, we need to maintain
        # that behavior unless json_schema_transformer is set explicitly.
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(base_profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new GitHub Models provider.

        Args:
            api_key: The GitHub token to use for authentication. If not provided, the `GITHUB_API_KEY`
                environment variable will be used if available.
            openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        api_key = api_key or os.getenv('GITHUB_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
                ' to use the GitHub Models provider.'
            )

        if openai_client is not None:
            # A pre-configured client wins over everything else.
            self._client = openai_client
            return

        if http_client is None:
            http_client = cached_async_http_client(provider='github')
        self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import re | ||
|
||
import httpx | ||
import pytest | ||
from pytest_mock import MockerFixture | ||
|
||
from pydantic_ai.exceptions import UserError | ||
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer | ||
from pydantic_ai.profiles.cohere import cohere_model_profile | ||
from pydantic_ai.profiles.deepseek import deepseek_model_profile | ||
from pydantic_ai.profiles.grok import grok_model_profile | ||
from pydantic_ai.profiles.meta import meta_model_profile | ||
from pydantic_ai.profiles.mistral import mistral_model_profile | ||
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile | ||
|
||
from ..conftest import TestEnv, try_import | ||
|
||
with try_import() as imports_successful: | ||
import openai | ||
|
||
from pydantic_ai.providers.github import GitHubProvider | ||
|
||
pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') | ||
|
||
|
||
def test_github_provider():
    """An explicit token yields a correctly configured AsyncOpenAI-backed provider."""
    provider = GitHubProvider(api_key='ghp_test_token')

    assert provider.name == 'github'
    assert provider.base_url == 'https://models.github.ai/inference'

    client = provider.client
    assert isinstance(client, openai.AsyncOpenAI)
    assert client.api_key == 'ghp_test_token'
|
||
|
||
def test_github_provider_need_api_key(env: TestEnv) -> None:
    """Without a token (argument or env var), construction fails with a helpful error."""
    env.remove('GITHUB_API_KEY')

    expected_message = re.escape(
        'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
        ' to use the GitHub Models provider.'
    )
    with pytest.raises(UserError, match=expected_message):
        GitHubProvider()
|
||
|
||
def test_github_provider_pass_http_client() -> None:
    """A caller-supplied httpx client must be reused as-is by the provider."""
    http_client = httpx.AsyncClient()
    provider = GitHubProvider(http_client=http_client, api_key='ghp_test_token')
    # `is` rather than `==`: the contract is object identity (the provider must
    # wrap the exact client instance), not mere equality.
    assert provider.client._client is http_client  # type: ignore[reportPrivateUsage]
|
||
|
||
def test_github_pass_openai_client() -> None:
    """A caller-supplied AsyncOpenAI client must be used directly, unmodified."""
    openai_client = openai.AsyncOpenAI(api_key='ghp_test_token')
    provider = GitHubProvider(openai_client=openai_client)
    # `is` rather than `==`: the provider must hold the exact same client object.
    assert provider.client is openai_client
|
||
|
||
def test_github_provider_model_profile(mocker: MockerFixture):
    """model_profile() dispatches on the publisher prefix and falls back to OpenAI defaults.

    Known publishers route to their publisher-specific profile function (with the
    model name lowercased and any `:tag` suffix dropped); bare names go through
    `openai_model_profile`; unknown publisher prefixes get the OpenAI-flavored
    default profile without calling any profile function.
    """
    provider = GitHubProvider(api_key='ghp_test_token')

    ns = 'pydantic_ai.providers.github'
    meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_model_profile_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    mistral_model_profile_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    cohere_model_profile_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile)
    grok_model_profile_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile)
    openai_model_profile_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile)

    meta_profile = provider.model_profile('meta/Llama-3.2-11B-Vision-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.2-11b-vision-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    meta_profile = provider.model_profile('meta/Llama-3.1-405B-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.1-405b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    deepseek_profile = provider.model_profile('deepseek/deepseek-coder')
    deepseek_model_profile_mock.assert_called_with('deepseek-coder')
    assert deepseek_profile is not None
    assert deepseek_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    mistral_profile = provider.model_profile('mistral-ai/mixtral-8x7b-instruct')
    mistral_model_profile_mock.assert_called_with('mixtral-8x7b-instruct')
    assert mistral_profile is not None
    assert mistral_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    cohere_profile = provider.model_profile('cohere/command-r-plus')
    cohere_model_profile_mock.assert_called_with('command-r-plus')
    assert cohere_profile is not None
    assert cohere_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    grok_profile = provider.model_profile('xai/grok-3-mini')
    grok_model_profile_mock.assert_called_with('grok-3-mini')
    assert grok_profile is not None
    assert grok_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    microsoft_profile = provider.model_profile('microsoft/Phi-3.5-mini-instruct')
    openai_model_profile_mock.assert_called_with('phi-3.5-mini-instruct')
    assert microsoft_profile is not None
    assert microsoft_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    unknown_profile = provider.model_profile('some-unknown-model')
    openai_model_profile_mock.assert_called_with('some-unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # An unknown publisher prefix must NOT call any profile function; the
    # provider returns the OpenAI-flavored default profile directly. The mock is
    # reset first because `assert_called_with` only inspects the *last* call —
    # without the reset, the assertion would pass vacuously on the stale call
    # from the bare-name case above.
    openai_model_profile_mock.reset_mock()
    unknown_profile_with_prefix = provider.model_profile('unknown-publisher/some-unknown-model')
    openai_model_profile_mock.assert_not_called()
    assert unknown_profile_with_prefix is not None
    assert unknown_profile_with_prefix.json_schema_transformer == OpenAIJsonSchemaTransformer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.