diff options
Diffstat (limited to 'packages')
| -rw-r--r-- | packages/opencode/package.json | 5 | ||||
| -rw-r--r-- | packages/opencode/src/provider/provider.ts | 76 |
2 files changed, 56 insertions, 25 deletions
diff --git a/packages/opencode/package.json b/packages/opencode/package.json index 5a409c7f5..25b7f7117 100644 --- a/packages/opencode/package.json +++ b/packages/opencode/package.json @@ -12,13 +12,14 @@ "./*": "./src/*.ts" }, "devDependencies": { + "@ai-sdk/amazon-bedrock": "2.2.10", + "@ai-sdk/anthropic": "1.2.12", "@tsconfig/bun": "1.0.7", "@types/bun": "latest", "@types/turndown": "5.0.5", "@types/yargs": "17.0.33", "typescript": "catalog:", - "zod-to-json-schema": "3.24.5", - "@ai-sdk/anthropic": "1.2.12" + "zod-to-json-schema": "3.24.5" }, "dependencies": { "@clack/prompts": "0.11.0", diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 1215d29ed..e074e5d2b 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -27,9 +27,13 @@ import { TaskTool } from "../tool/task" export namespace Provider { const log = Log.create({ service: "provider" }) - type CustomLoader = ( - provider: ModelsDev.Provider, - ) => Promise<Record<string, any> | false> + type CustomLoader = (provider: ModelsDev.Provider) => Promise< + | { + getModel?: (sdk: any, modelID: string) => Promise<any> + options: Record<string, any> + } + | false + > type Source = "env" | "config" | "custom" | "api" @@ -44,30 +48,52 @@ export namespace Provider { } } return { - apiKey: "", - async fetch(input: any, init: any) { - const access = await AuthAnthropic.access() - const headers = { - ...init.headers, - authorization: `Bearer ${access}`, - "anthropic-beta": "oauth-2025-04-20", - } - delete headers["x-api-key"] - return fetch(input, { - ...init, - headers, - }) + options: { + apiKey: "", + async fetch(input: any, init: any) { + const access = await AuthAnthropic.access() + const headers = { + ...init.headers, + authorization: `Bearer ${access}`, + "anthropic-beta": "oauth-2025-04-20", + } + delete headers["x-api-key"] + return fetch(input, { + ...init, + headers, + }) + }, }, } }, + openai: async () => { + return { + async getModel(sdk: any, modelID: string) { + return sdk.responses(modelID) + }, + options: {}, + } + }, "amazon-bedrock": async () => { - if (!process.env["AWS_PROFILE"]) return false + if (!process.env["AWS_PROFILE"]) false + + const region = process.env["AWS_REGION"] ?? "us-east-1" + const { fromNodeProviderChain } = await import( await BunProc.install("@aws-sdk/credential-providers") ) return { - region: process.env["AWS_REGION"] ?? "us-east-1", - credentialProvider: fromNodeProviderChain(), + options: { + region, + credentialProvider: fromNodeProviderChain(), + }, + async getModel(sdk: any, modelID: string) { + if (modelID.includes("claude")) { + const prefix = region.split("-")[0] + modelID = `${prefix}.${modelID}` + } + return sdk.languageModel(modelID) + }, } }, } @@ -80,6 +106,7 @@ export namespace Provider { [providerID: string]: { source: Source info: ModelsDev.Provider + getModel?: (sdk: any, modelID: string) => Promise<any> options: Record<string, any> } } = {} @@ -95,6 +122,7 @@ export namespace Provider { id: string, options: Record<string, any>, source: Source, + getModel?: (sdk: any, modelID: string) => Promise<any>, ) { const provider = providers[id] if (!provider) { @@ -110,6 +138,7 @@ export namespace Provider { } provider.options = mergeDeep(provider.options, options) provider.source = source + provider.getModel = getModel ?? provider.getModel } const configProviders = Object.entries(config.provider ?? {}) @@ -173,7 +202,8 @@ export namespace Provider { for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) { if (disabled.has(providerID)) continue const result = await fn(database[providerID]) - if (result) mergeProvider(providerID, result, "custom") + if (result) + mergeProvider(providerID, result.options, "custom", result.getModel) } // load config @@ -236,9 +266,9 @@ export namespace Provider { const sdk = await getSDK(provider.info) try { - const language = - // @ts-expect-error - "responses" in sdk ? sdk.responses(modelID) : sdk.languageModel(modelID) + const language = provider.getModel + ? await provider.getModel(sdk, modelID) + : sdk.languageModel(modelID) log.info("found", { providerID, modelID }) s.models.set(key, { info, |
