diff --git a/packages/ai-gateway-provider/src/index.ts b/packages/ai-gateway-provider/src/index.ts index f4ede9450..42b9c0869 100644 --- a/packages/ai-gateway-provider/src/index.ts +++ b/packages/ai-gateway-provider/src/index.ts @@ -118,6 +118,17 @@ export class AiGatewayChatLanguageModel implements LanguageModelV3 { } else { delete req.request.headers[authHeader]; } + + // Strip companion auth headers (e.g. AWS SigV4 headers for Bedrock) + if (providerConfig.extraStripHeaders) { + for (const header of providerConfig.extraStripHeaders) { + if ("delete" in req.request.headers) { + req.request.headers.delete(header); + } else { + delete req.request.headers[header]; + } + } + } } return { diff --git a/packages/ai-gateway-provider/src/providers.ts b/packages/ai-gateway-provider/src/providers.ts index e32604153..ab9d37e66 100644 --- a/packages/ai-gateway-provider/src/providers.ts +++ b/packages/ai-gateway-provider/src/providers.ts @@ -1,4 +1,12 @@ -export const providers = [ +type ProviderConfig = { + name: string; + regex: RegExp; + transformEndpoint: (url: string) => string; + headerKey?: string; + extraStripHeaders?: string[]; +}; + +export const providers: ProviderConfig[] = [ { name: "openai", regex: /^https:\/\/api\.openai\.com\//, @@ -82,6 +90,20 @@ export const providers = [ regex: /^https:\/\/openrouter\.ai\/api\//, transformEndpoint: (url: string) => url.replace(/^https:\/\/openrouter\.ai\/api\//, ""), }, + { + name: "aws-bedrock", + regex: /^https:\/\/bedrock-runtime\.([a-z0-9-]+)\.amazonaws\.com\//, + transformEndpoint: (url: string) => { + const match = url.match( + /^https:\/\/bedrock-runtime\.([a-z0-9-]+)\.amazonaws\.com\/(.*)/, + ); + if (!match || !match[1] || !match[2]) { + throw new Error("Failed to parse AWS Bedrock endpoint URL."); + } + return `bedrock-runtime/${match[1]}/${match[2]}`; + }, + extraStripHeaders: ["x-amz-date", "x-amz-security-token", "x-amz-content-sha256"], + }, { name: "compat", regex: /^https:\/\/gateway\.ai\.cloudflare\.com\/v1\/compat\//, diff --git a/packages/ai-gateway-provider/src/providers/amazon-bedrock.ts b/packages/ai-gateway-provider/src/providers/amazon-bedrock.ts index 607f736ca..ab9502182 100644 --- a/packages/ai-gateway-provider/src/providers/amazon-bedrock.ts +++ b/packages/ai-gateway-provider/src/providers/amazon-bedrock.ts @@ -1,5 +1,20 @@ import { createAmazonBedrock as createAmazonBedrockOriginal } from "@ai-sdk/amazon-bedrock"; -import { authWrapper } from "../auth"; +import { CF_TEMP_TOKEN } from "../auth"; -export const createAmazonBedrock = (...args: Parameters) => - authWrapper(createAmazonBedrockOriginal)(...args); +export const createAmazonBedrock = (...args: Parameters) => { + let [config] = args; + if (config === undefined) { + config = { region: "us-east-1", accessKeyId: CF_TEMP_TOKEN, secretAccessKey: CF_TEMP_TOKEN }; + } else { + if (config.region === undefined) { + config.region = "us-east-1"; + } + if (config.accessKeyId === undefined) { + config.accessKeyId = CF_TEMP_TOKEN; + } + if (config.secretAccessKey === undefined) { + config.secretAccessKey = CF_TEMP_TOKEN; + } + } + return createAmazonBedrockOriginal(config); +}; diff --git a/packages/ai-gateway-provider/test/endpoint.test.ts b/packages/ai-gateway-provider/test/endpoint.test.ts index 408ba1295..04cab1b89 100644 --- a/packages/ai-gateway-provider/test/endpoint.test.ts +++ b/packages/ai-gateway-provider/test/endpoint.test.ts @@ -52,6 +52,16 @@ const testCases = [ name: "azure-openai", url: "https://myresource.openai.azure.com/openai/deployments/mydeployment/chat/completions?api-version=2024-02-15-preview", }, + { + expected: "bedrock-runtime/us-east-1/model/anthropic.claude-sonnet-4-5-20250929-v1:0/invoke", + name: "aws-bedrock", + url: "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-5-20250929-v1:0/invoke", + }, + { + expected: "bedrock-runtime/us-east-1/model/anthropic.claude-sonnet-4-5-20250929-v1%3A0/converse", + name: "aws-bedrock", + url: "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-5-20250929-v1%3A0/converse", + }, ]; describe("ProvidersConfigs endpoint parsing", () => {