diff options
| author | YJG <[email protected]> | 2025-04-28 10:42:57 -0300 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-04-28 15:42:57 +0200 |
| commit | 805aeff83cad4c17e25acdd671d2731be104b3e0 (patch) | |
| tree | e2bdcbda42858a9b159301d3253929b37ed39f84 /internal | |
| parent | bce2ec5c10c1895a80fae48d315b132341b7dc96 (diff) | |
| download | opencode-805aeff83cad4c17e25acdd671d2731be104b3e0.tar.gz opencode-805aeff83cad4c17e25acdd671d2731be104b3e0.zip | |
feat: add azure openai models (#74)
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/config/config.go | 11 | ||||
| -rw-r--r-- | internal/llm/models/azure.go | 157 | ||||
| -rw-r--r-- | internal/llm/models/models.go | 1 | ||||
| -rw-r--r-- | internal/llm/provider/azure.go | 47 | ||||
| -rw-r--r-- | internal/llm/provider/provider.go | 5 |
5 files changed, 221 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 4864ef18a..b3a9861e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -272,6 +272,15 @@ func setProviderDefaults() { viper.SetDefault("agents.title.model", models.BedrockClaude37Sonnet) return } + + if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" { + // api-key may be empty when using Entra ID credentials – that's okay + viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY")) + viper.SetDefault("agents.coder.model", models.AzureGPT41) + viper.SetDefault("agents.task.model", models.AzureGPT41Mini) + viper.SetDefault("agents.title.model", models.AzureGPT41Mini) + return + } } // hasAWSCredentials checks if AWS credentials are available in the environment. @@ -506,6 +515,8 @@ func getProviderAPIKey(provider models.ModelProvider) string { return os.Getenv("GEMINI_API_KEY") case models.ProviderGROQ: return os.Getenv("GROQ_API_KEY") + case models.ProviderAzure: + return os.Getenv("AZURE_OPENAI_API_KEY") case models.ProviderBedrock: if hasAWSCredentials() { return "aws-credentials-available" diff --git a/internal/llm/models/azure.go b/internal/llm/models/azure.go new file mode 100644 index 000000000..6b7bac3a0 --- /dev/null +++ b/internal/llm/models/azure.go @@ -0,0 +1,157 @@ +package models + +const ProviderAzure ModelProvider = "azure" + +const ( + AzureGPT41 ModelID = "azure.gpt-4.1" + AzureGPT41Mini ModelID = "azure.gpt-4.1-mini" + AzureGPT41Nano ModelID = "azure.gpt-4.1-nano" + AzureGPT45Preview ModelID = "azure.gpt-4.5-preview" + AzureGPT4o ModelID = "azure.gpt-4o" + AzureGPT4oMini ModelID = "azure.gpt-4o-mini" + AzureO1 ModelID = "azure.o1" + AzureO1Mini ModelID = "azure.o1-mini" + AzureO3 ModelID = "azure.o3" + AzureO3Mini ModelID = "azure.o3-mini" + AzureO4Mini ModelID = "azure.o4-mini" +) + +var AzureModels = map[ModelID]Model{ + AzureGPT41: { + ID: AzureGPT41, + Name: "Azure OpenAI – GPT 4.1", + Provider: ProviderAzure, + APIModel: "gpt-4.1", + CostPer1MIn: OpenAIModels[GPT41].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT41].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT41].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT41].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT41].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT41].DefaultMaxTokens, + }, + AzureGPT41Mini: { + ID: AzureGPT41Mini, + Name: "Azure OpenAI – GPT 4.1 mini", + Provider: ProviderAzure, + APIModel: "gpt-4.1-mini", + CostPer1MIn: OpenAIModels[GPT41Mini].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT41Mini].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT41Mini].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT41Mini].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT41Mini].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT41Mini].DefaultMaxTokens, + }, + AzureGPT41Nano: { + ID: AzureGPT41Nano, + Name: "Azure OpenAI – GPT 4.1 nano", + Provider: ProviderAzure, + APIModel: "gpt-4.1-nano", + CostPer1MIn: OpenAIModels[GPT41Nano].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT41Nano].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT41Nano].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT41Nano].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT41Nano].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT41Nano].DefaultMaxTokens, + }, + AzureGPT45Preview: { + ID: AzureGPT45Preview, + Name: "Azure OpenAI – GPT 4.5 preview", + Provider: ProviderAzure, + APIModel: "gpt-4.5-preview", + CostPer1MIn: OpenAIModels[GPT45Preview].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT45Preview].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT45Preview].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT45Preview].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT45Preview].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT45Preview].DefaultMaxTokens, + }, + AzureGPT4o: { + ID: AzureGPT4o, + Name: "Azure OpenAI – GPT-4o", + Provider: ProviderAzure, + APIModel: "gpt-4o", + CostPer1MIn: OpenAIModels[GPT4o].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT4o].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT4o].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT4o].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT4o].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT4o].DefaultMaxTokens, + }, + AzureGPT4oMini: { + ID: AzureGPT4oMini, + Name: "Azure OpenAI – GPT-4o mini", + Provider: ProviderAzure, + APIModel: "gpt-4o-mini", + CostPer1MIn: OpenAIModels[GPT4oMini].CostPer1MIn, + CostPer1MInCached: OpenAIModels[GPT4oMini].CostPer1MInCached, + CostPer1MOut: OpenAIModels[GPT4oMini].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[GPT4oMini].CostPer1MOutCached, + ContextWindow: OpenAIModels[GPT4oMini].ContextWindow, + DefaultMaxTokens: OpenAIModels[GPT4oMini].DefaultMaxTokens, + }, + AzureO1: { + ID: AzureO1, + Name: "Azure OpenAI – O1", + Provider: ProviderAzure, + APIModel: "o1", + CostPer1MIn: OpenAIModels[O1].CostPer1MIn, + CostPer1MInCached: OpenAIModels[O1].CostPer1MInCached, + CostPer1MOut: OpenAIModels[O1].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[O1].CostPer1MOutCached, + ContextWindow: OpenAIModels[O1].ContextWindow, + DefaultMaxTokens: OpenAIModels[O1].DefaultMaxTokens, + CanReason: OpenAIModels[O1].CanReason, + }, + AzureO1Mini: { + ID: AzureO1Mini, + Name: "Azure OpenAI – O1 mini", + Provider: ProviderAzure, + APIModel: "o1-mini", + CostPer1MIn: OpenAIModels[O1Mini].CostPer1MIn, + CostPer1MInCached: OpenAIModels[O1Mini].CostPer1MInCached, + CostPer1MOut: OpenAIModels[O1Mini].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[O1Mini].CostPer1MOutCached, + ContextWindow: OpenAIModels[O1Mini].ContextWindow, + DefaultMaxTokens: OpenAIModels[O1Mini].DefaultMaxTokens, + CanReason: OpenAIModels[O1Mini].CanReason, + }, + AzureO3: { + ID: AzureO3, + Name: "Azure OpenAI – O3", + Provider: ProviderAzure, + APIModel: "o3", + CostPer1MIn: OpenAIModels[O3].CostPer1MIn, + CostPer1MInCached: OpenAIModels[O3].CostPer1MInCached, + CostPer1MOut: OpenAIModels[O3].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[O3].CostPer1MOutCached, + ContextWindow: OpenAIModels[O3].ContextWindow, + DefaultMaxTokens: OpenAIModels[O3].DefaultMaxTokens, + CanReason: OpenAIModels[O3].CanReason, + }, + AzureO3Mini: { + ID: AzureO3Mini, + Name: "Azure OpenAI – O3 mini", + Provider: ProviderAzure, + APIModel: "o3-mini", + CostPer1MIn: OpenAIModels[O3Mini].CostPer1MIn, + CostPer1MInCached: OpenAIModels[O3Mini].CostPer1MInCached, + CostPer1MOut: OpenAIModels[O3Mini].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[O3Mini].CostPer1MOutCached, + ContextWindow: OpenAIModels[O3Mini].ContextWindow, + DefaultMaxTokens: OpenAIModels[O3Mini].DefaultMaxTokens, + CanReason: OpenAIModels[O3Mini].CanReason, + }, + AzureO4Mini: { + ID: AzureO4Mini, + Name: "Azure OpenAI – O4 mini", + Provider: ProviderAzure, + APIModel: "o4-mini", + CostPer1MIn: OpenAIModels[O4Mini].CostPer1MIn, + CostPer1MInCached: OpenAIModels[O4Mini].CostPer1MInCached, + CostPer1MOut: OpenAIModels[O4Mini].CostPer1MOut, + CostPer1MOutCached: OpenAIModels[O4Mini].CostPer1MOutCached, + ContextWindow: OpenAIModels[O4Mini].ContextWindow, + DefaultMaxTokens: OpenAIModels[O4Mini].DefaultMaxTokens, + CanReason: OpenAIModels[O4Mini].CanReason, + }, +} diff --git a/internal/llm/models/models.go b/internal/llm/models/models.go index 1bc02c49d..bad0ebdaa 100644 --- a/internal/llm/models/models.go +++ b/internal/llm/models/models.go @@ -76,4 +76,5 @@ func init() { maps.Copy(SupportedModels, OpenAIModels) maps.Copy(SupportedModels, GeminiModels) maps.Copy(SupportedModels, GroqModels) + maps.Copy(SupportedModels, AzureModels) } diff --git a/internal/llm/provider/azure.go b/internal/llm/provider/azure.go new file mode 100644 index 000000000..6368a181c --- /dev/null +++ b/internal/llm/provider/azure.go @@ -0,0 +1,47 @@ +package provider + +import ( + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" + "github.com/openai/openai-go/option" +) + +type azureClient struct { + *openaiClient +} + +type AzureClient ProviderClient + +func newAzureClient(opts providerClientOptions) AzureClient { + + endpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // ex: https://foo.openai.azure.com + apiVersion := os.Getenv("AZURE_OPENAI_API_VERSION") // ex: 2025-04-01-preview + + if endpoint == "" || apiVersion == "" { + return &azureClient{openaiClient: newOpenAIClient(opts).(*openaiClient)} + } + + reqOpts := []option.RequestOption{ + azure.WithEndpoint(endpoint, apiVersion), + } + + if opts.apiKey != "" || os.Getenv("AZURE_OPENAI_API_KEY") != "" { + key := opts.apiKey + if key == "" { + key = os.Getenv("AZURE_OPENAI_API_KEY") + } + reqOpts = append(reqOpts, azure.WithAPIKey(key)) + } else if cred, err := azidentity.NewDefaultAzureCredential(nil); err == nil { + reqOpts = append(reqOpts, azure.WithTokenCredential(cred)) + } + + base := &openaiClient{ + providerOptions: opts, + client: openai.NewClient(reqOpts...), + } + + return &azureClient{openaiClient: base} +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 00b7b2978..737b6fb00 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -115,6 +115,11 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption options: clientOptions, client: newOpenAIClient(clientOptions), }, nil + case models.ProviderAzure: + return &baseProvider[AzureClient]{ + options: clientOptions, + client: newAzureClient(clientOptions), + }, nil case models.ProviderMock: // TODO: implement mock client for test panic("not implemented") |
