Skip to content

Commit 39cda49

Browse files
Merge pull request #1182 from utsavtulsyan/feature/azureopenai-managed-identity
feat(dspy): added support for managed identity in AzureOpenAI
2 parents 5ceb906 + 00c7b65 commit 39cda49

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

docs/api/language_model_clients/AzureOpenAI.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,32 @@ class AzureOpenAI(LM):
2929
):
3030
```
3131

32+
**Parameters:**
3233

33-
34-
**Parameters:**
3534
- `api_base` (str): Azure Base URL.
3635
- `api_version` (str): Version identifier for Azure OpenAI API.
3736
- `api_key` (_Optional[str]_, _optional_): API provider authentication token. Retrieves from `AZURE_OPENAI_KEY` environment variable if None.
3837
- `model_type` (_Literal["chat", "text"]_): Specified model type to use, defaults to 'chat'.
38+
- `azure_ad_token_provider` (_Optional[AzureADTokenProvider]_, _optional_): Pass the bearer token provider created by _get_bearer_token_provider()_ when using DefaultAzureCredential, alternative to api token.
3939
- `**kwargs`: Additional language model arguments to pass to the API provider.
4040

4141
### Methods
4242

4343
#### `__call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> List[Dict[str, Any]]`
4444

45-
Retrieves completions from Azure OpenAI Endpoints by calling `request`.
45+
Retrieves completions from Azure OpenAI Endpoints by calling `request`.
4646

4747
Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.
4848

4949
After generation, the completions are post-processed based on the `model_type` parameter. If the parameter is set to 'chat', the generated content look like `choice["message"]["content"]`. Otherwise, the generated text will be `choice["text"]`.
5050

5151
**Parameters:**
52+
5253
- `prompt` (_str_): Prompt to send to Azure OpenAI.
5354
- `only_completed` (_bool_, _optional_): Flag to return only completed responses and ignore completion due to length. Defaults to True.
5455
- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
5556
- `**kwargs`: Additional keyword arguments for completion request.
5657

5758
**Returns:**
59+
5860
- `List[Dict[str, Any]]`: List of completion choices.

dsp/modules/azure_openai.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import json
33
import logging
4-
from typing import Any, Literal, Optional, cast
4+
from typing import Any, Callable, Literal, Optional, cast
55

66
import backoff
77
import openai
@@ -37,6 +37,9 @@ def backoff_hdlr(details):
3737
)
3838

3939

40+
AzureADTokenProvider = Callable[[], str]
41+
42+
4043
class AzureOpenAI(LM):
4144
"""Wrapper around Azure's API for OpenAI.
4245
@@ -57,6 +60,7 @@ def __init__(
5760
api_key: Optional[str] = None,
5861
model_type: Literal["chat", "text"] = "chat",
5962
system_prompt: Optional[str] = None,
63+
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
6064
**kwargs,
6165
):
6266
super().__init__(model)
@@ -75,6 +79,7 @@ def __init__(
7579
openai.api_key = api_key
7680
openai.api_type = "azure"
7781
openai.api_version = api_version
82+
openai.azure_ad_token_provider = azure_ad_token_provider
7883

7984
self.client = None
8085

@@ -83,6 +88,7 @@ def __init__(
8388
azure_endpoint=api_base,
8489
api_key=api_key,
8590
api_version=api_version,
91+
azure_ad_token_provider=azure_ad_token_provider,
8692
)
8793

8894
self.client = client

0 commit comments

Comments
 (0)