Skip to content

Add GitHub Models provider #2114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/models/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,33 @@ agent = Agent(model)
...
```

### GitHub Models

To use [GitHub Models](https://docs.github.com/en/github-models), you'll need a GitHub personal access token with the `models: read` permission.

Once you have the token, you can use it with the [`GitHubProvider`][pydantic_ai.providers.github.GitHubProvider]:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.github import GitHubProvider

model = OpenAIModel(
'xai/grok-3-mini', # GitHub Models uses prefixed model names
provider=GitHubProvider(api_key='your-github-token'),
)
agent = Agent(model)
...
```

You can also set the `GITHUB_API_KEY` environment variable:

```bash
export GITHUB_API_KEY='your-github-token'
```

GitHub Models supports various model families with different prefixes. You can see the full list on the [GitHub Marketplace](https://github.com/marketplace?type=models) or the public [catalog endpoint](https://models.github.ai/catalog/models).

### Perplexity

Follow the Perplexity [getting started](https://docs.perplexity.ai/guides/getting-started)
Expand Down
12 changes: 11 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,17 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
from .cohere import CohereModel

return CohereModel(model_name, provider=provider)
elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku'):
elif provider in (
'openai',
'deepseek',
'azure',
'openrouter',
'grok',
'fireworks',
'together',
'heroku',
'github',
):
from .openai import OpenAIModel

return OpenAIModel(model_name, provider=provider)
Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def __init__(
self,
model_name: OpenAIModelName,
*,
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
provider: Literal[
'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
]
| Provider[AsyncOpenAI] = 'openai',
profile: ModelProfileSpec | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
from .heroku import HerokuProvider

return HerokuProvider
elif provider == 'github':
from .github import GitHubProvider

return GitHubProvider
else: # pragma: no cover
raise ValueError(f'Unknown provider: {provider}')

Expand Down
112 changes: 112 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations as _annotations

import os
from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.providers import Provider

try:
from openai import AsyncOpenAI
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `openai` package to use the GitHub Models provider, '
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
) from _import_error


class GitHubProvider(Provider[AsyncOpenAI]):
    """Provider for GitHub Models API.

    GitHub Models provides access to various AI models through an OpenAI-compatible API.
    See <https://docs.github.com/en/github-models> for more information.
    """

    @property
    def name(self) -> str:
        """Canonical provider name."""
        return 'github'

    @property
    def base_url(self) -> str:
        """Base URL of the GitHub Models inference endpoint."""
        return 'https://models.github.ai/inference'

    @property
    def client(self) -> AsyncOpenAI:
        """The configured `AsyncOpenAI` client used to talk to GitHub Models."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Resolve a model profile for a GitHub Models model name.

        GitHub Models names are usually publisher-prefixed (e.g. `xai/grok-3-mini`);
        a name without a `/` is assumed to be an OpenAI model.
        """
        # No publisher prefix: treat as a plain OpenAI model.
        if '/' not in model_name:
            return openai_model_profile(model_name)

        publisher, remainder = model_name.lower().split('/', 1)
        # Drop any ':tag' suffix from the model name.
        bare_name = remainder.split(':', 1)[0]

        # Built inside the method so the module-level factories stay patchable in tests.
        profile_factories = {
            'xai': grok_model_profile,
            'meta': meta_model_profile,
            'microsoft': openai_model_profile,
            'mistral-ai': mistral_model_profile,
            'cohere': cohere_model_profile,
            'deepseek': deepseek_model_profile,
        }
        factory = profile_factories.get(publisher)
        base_profile = factory(bare_name) if factory is not None else None

        # As GitHubProvider is always used with OpenAIModel, which used to unconditionally use
        # OpenAIJsonSchemaTransformer, we maintain that behavior unless json_schema_transformer
        # is set explicitly by the underlying profile.
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(base_profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new GitHub Models provider.

        Args:
            api_key: The GitHub token to use for authentication. If not provided, the `GITHUB_API_KEY`
                environment variable will be used if available.
            openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        resolved_key = api_key or os.getenv('GITHUB_API_KEY')
        # A pre-built OpenAI client carries its own credentials; a token is only required otherwise.
        if openai_client is None and not resolved_key:
            raise UserError(
                'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
                ' to use the GitHub Models provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        else:
            client = http_client if http_client is not None else cached_async_http_client(provider='github')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=resolved_key, http_client=client)
8 changes: 8 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
'bedrock',
'BedrockConverseModel',
),
(
'GITHUB_API_KEY',
'github:xai/grok-3-mini',
'xai/grok-3-mini',
'github',
'github',
'OpenAIModel',
),
]


Expand Down
112 changes: 112 additions & 0 deletions tests/providers/test_github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
import openai

from pydantic_ai.providers.github import GitHubProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_github_provider():
    """The provider exposes the expected name, base URL, and a configured OpenAI client."""
    p = GitHubProvider(api_key='ghp_test_token')
    assert p.name == 'github'
    assert p.base_url == 'https://models.github.ai/inference'
    client = p.client
    assert isinstance(client, openai.AsyncOpenAI)
    assert client.api_key == 'ghp_test_token'


def test_github_provider_need_api_key(env: TestEnv) -> None:
    """Constructing the provider with no token anywhere raises a UserError."""
    env.remove('GITHUB_API_KEY')
    expected_message = re.escape(
        'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
        ' to use the GitHub Models provider.'
    )
    with pytest.raises(UserError, match=expected_message):
        GitHubProvider()


def test_github_provider_pass_http_client() -> None:
    """A caller-supplied httpx client is wired straight into the underlying OpenAI client."""
    custom_http_client = httpx.AsyncClient()
    provider = GitHubProvider(api_key='ghp_test_token', http_client=custom_http_client)
    assert provider.client._client == custom_http_client  # type: ignore[reportPrivateUsage]


def test_github_pass_openai_client() -> None:
    """A pre-built AsyncOpenAI client is used as-is, bypassing client construction."""
    custom_openai_client = openai.AsyncOpenAI(api_key='ghp_test_token')
    assert GitHubProvider(openai_client=custom_openai_client).client == custom_openai_client


def test_github_provider_model_profile(mocker: MockerFixture):
    """`model_profile` dispatches on the publisher prefix and always applies an OpenAI-compatible transformer.

    Each profile function is wrapped (not replaced) so the real profile is still built while
    the call arguments can be asserted.
    """
    provider = GitHubProvider(api_key='ghp_test_token')

    ns = 'pydantic_ai.providers.github'
    meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_model_profile_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    mistral_model_profile_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    cohere_model_profile_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile)
    grok_model_profile_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile)
    openai_model_profile_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile)

    # Publisher prefix is stripped and the model name lower-cased before dispatch.
    meta_profile = provider.model_profile('meta/Llama-3.2-11B-Vision-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.2-11b-vision-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    meta_profile = provider.model_profile('meta/Llama-3.1-405B-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.1-405b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    deepseek_profile = provider.model_profile('deepseek/deepseek-coder')
    deepseek_model_profile_mock.assert_called_with('deepseek-coder')
    assert deepseek_profile is not None
    assert deepseek_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    mistral_profile = provider.model_profile('mistral-ai/mixtral-8x7b-instruct')
    mistral_model_profile_mock.assert_called_with('mixtral-8x7b-instruct')
    assert mistral_profile is not None
    assert mistral_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    cohere_profile = provider.model_profile('cohere/command-r-plus')
    cohere_model_profile_mock.assert_called_with('command-r-plus')
    assert cohere_profile is not None
    assert cohere_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    grok_profile = provider.model_profile('xai/grok-3-mini')
    grok_model_profile_mock.assert_called_with('grok-3-mini')
    assert grok_profile is not None
    assert grok_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    microsoft_profile = provider.model_profile('microsoft/Phi-3.5-mini-instruct')
    openai_model_profile_mock.assert_called_with('phi-3.5-mini-instruct')
    assert microsoft_profile is not None
    assert microsoft_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # A name with no publisher prefix is treated as an OpenAI model.
    unknown_profile = provider.model_profile('some-unknown-model')
    openai_model_profile_mock.assert_called_with('some-unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # An unknown publisher consults NO profile function (the previous assertion only passed
    # because of the earlier call); the OpenAI transformer default still applies.
    calls_before = openai_model_profile_mock.call_count
    unknown_profile_with_prefix = provider.model_profile('unknown-publisher/some-unknown-model')
    assert openai_model_profile_mock.call_count == calls_before
    assert unknown_profile_with_prefix is not None
    assert unknown_profile_with_prefix.json_schema_transformer == OpenAIJsonSchemaTransformer
2 changes: 2 additions & 0 deletions tests/providers/test_provider_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic_ai.providers.cohere import CohereProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.fireworks import FireworksProvider
from pydantic_ai.providers.github import GitHubProvider
from pydantic_ai.providers.google_gla import GoogleGLAProvider
from pydantic_ai.providers.google_vertex import GoogleVertexProvider
from pydantic_ai.providers.grok import GrokProvider
Expand All @@ -44,6 +45,7 @@
('fireworks', FireworksProvider, 'FIREWORKS_API_KEY'),
('together', TogetherProvider, 'TOGETHER_API_KEY'),
('heroku', HerokuProvider, 'HEROKU_INFERENCE_KEY'),
('github', GitHubProvider, 'GITHUB_API_KEY'),
]

if not imports_successful():
Expand Down