summaryrefslogtreecommitdiffhomepage
path: root/packages
diff options
context:
space:
mode:
authorDmytro Yankovskyi <[email protected]>2025-06-20 20:57:33 +0200
committerGitHub <[email protected]>2025-06-20 14:57:33 -0400
commit91c4da5dbda320be0b154c37372dc096ca3f15ad (patch)
tree6ca85fd21c0d5a4272ad59788d99f98e1942cd75 /packages
parent2fd0e7dd6b0a67928609a8f2695a4b8f230ae2ab (diff)
downloadopencode-91c4da5dbda320be0b154c37372dc096ca3f15ad.tar.gz
opencode-91c4da5dbda320be0b154c37372dc096ca3f15ad.zip
fix(#243): claude on aws bedrock (#241)
Co-authored-by: Dax Raad <[email protected]>
Diffstat (limited to 'packages')
-rw-r--r--packages/opencode/package.json5
-rw-r--r--packages/opencode/src/provider/provider.ts76
2 files changed, 56 insertions, 25 deletions
diff --git a/packages/opencode/package.json b/packages/opencode/package.json
index 5a409c7f5..25b7f7117 100644
--- a/packages/opencode/package.json
+++ b/packages/opencode/package.json
@@ -12,13 +12,14 @@
"./*": "./src/*.ts"
},
"devDependencies": {
+ "@ai-sdk/amazon-bedrock": "2.2.10",
+ "@ai-sdk/anthropic": "1.2.12",
"@tsconfig/bun": "1.0.7",
"@types/bun": "latest",
"@types/turndown": "5.0.5",
"@types/yargs": "17.0.33",
"typescript": "catalog:",
- "zod-to-json-schema": "3.24.5",
- "@ai-sdk/anthropic": "1.2.12"
+ "zod-to-json-schema": "3.24.5"
},
"dependencies": {
"@clack/prompts": "0.11.0",
diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts
index 1215d29ed..e074e5d2b 100644
--- a/packages/opencode/src/provider/provider.ts
+++ b/packages/opencode/src/provider/provider.ts
@@ -27,9 +27,13 @@ import { TaskTool } from "../tool/task"
export namespace Provider {
const log = Log.create({ service: "provider" })
- type CustomLoader = (
- provider: ModelsDev.Provider,
- ) => Promise<Record<string, any> | false>
+ type CustomLoader = (provider: ModelsDev.Provider) => Promise<
+ | {
+ getModel?: (sdk: any, modelID: string) => Promise<any>
+ options: Record<string, any>
+ }
+ | false
+ >
type Source = "env" | "config" | "custom" | "api"
@@ -44,30 +48,52 @@ export namespace Provider {
}
}
return {
- apiKey: "",
- async fetch(input: any, init: any) {
- const access = await AuthAnthropic.access()
- const headers = {
- ...init.headers,
- authorization: `Bearer ${access}`,
- "anthropic-beta": "oauth-2025-04-20",
- }
- delete headers["x-api-key"]
- return fetch(input, {
- ...init,
- headers,
- })
+ options: {
+ apiKey: "",
+ async fetch(input: any, init: any) {
+ const access = await AuthAnthropic.access()
+ const headers = {
+ ...init.headers,
+ authorization: `Bearer ${access}`,
+ "anthropic-beta": "oauth-2025-04-20",
+ }
+ delete headers["x-api-key"]
+ return fetch(input, {
+ ...init,
+ headers,
+ })
+ },
},
}
},
+ openai: async () => {
+ return {
+ async getModel(sdk: any, modelID: string) {
+ return sdk.responses(modelID)
+ },
+ options: {},
+ }
+ },
"amazon-bedrock": async () => {
- if (!process.env["AWS_PROFILE"]) return false
+ if (!process.env["AWS_PROFILE"]) false
+
+ const region = process.env["AWS_REGION"] ?? "us-east-1"
+
const { fromNodeProviderChain } = await import(
await BunProc.install("@aws-sdk/credential-providers")
)
return {
- region: process.env["AWS_REGION"] ?? "us-east-1",
- credentialProvider: fromNodeProviderChain(),
+ options: {
+ region,
+ credentialProvider: fromNodeProviderChain(),
+ },
+ async getModel(sdk: any, modelID: string) {
+ if (modelID.includes("claude")) {
+ const prefix = region.split("-")[0]
+ modelID = `${prefix}.${modelID}`
+ }
+ return sdk.languageModel(modelID)
+ },
}
},
}
@@ -80,6 +106,7 @@ export namespace Provider {
[providerID: string]: {
source: Source
info: ModelsDev.Provider
+ getModel?: (sdk: any, modelID: string) => Promise<any>
options: Record<string, any>
}
} = {}
@@ -95,6 +122,7 @@ export namespace Provider {
id: string,
options: Record<string, any>,
source: Source,
+ getModel?: (sdk: any, modelID: string) => Promise<any>,
) {
const provider = providers[id]
if (!provider) {
@@ -110,6 +138,7 @@ export namespace Provider {
}
provider.options = mergeDeep(provider.options, options)
provider.source = source
+ provider.getModel = getModel ?? provider.getModel
}
const configProviders = Object.entries(config.provider ?? {})
@@ -173,7 +202,8 @@ export namespace Provider {
for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
if (disabled.has(providerID)) continue
const result = await fn(database[providerID])
- if (result) mergeProvider(providerID, result, "custom")
+ if (result)
+ mergeProvider(providerID, result.options, "custom", result.getModel)
}
// load config
@@ -236,9 +266,9 @@ export namespace Provider {
const sdk = await getSDK(provider.info)
try {
- const language =
- // @ts-expect-error
- "responses" in sdk ? sdk.responses(modelID) : sdk.languageModel(modelID)
+ const language = provider.getModel
+ ? await provider.getModel(sdk, modelID)
+ : sdk.languageModel(modelID)
log.info("found", { providerID, modelID })
s.models.set(key, {
info,