Skip to content
131 changes: 99 additions & 32 deletions go/adk/pkg/models/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,34 @@ import (
)

// bedrockToolIDValid matches Bedrock's toolUseId constraint: [a-zA-Z0-9_.:-]+
// bedrockToolNameInvalid matches characters not allowed in Bedrock tool names: [a-zA-Z0-9_-]+
var (
bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`)
bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`)
bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
bedrockToolNameInvalid = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
)

// sanitizeBedrockToolName returns a valid Bedrock tool name.
// Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty.
// nameMap caches original->sanitized so repeated lookups for the same name are
// consistent. counter is incremented only when a synthetic name is needed.
func sanitizeBedrockToolName(name string, nameMap map[string]string, counter *int) string {
if name == "" {
*counter++
return fmt.Sprintf("tool_fn_%d", *counter)
}
if sanitized, ok := nameMap[name]; ok {
return sanitized
}
sanitized := bedrockToolNameInvalid.ReplaceAllString(name, "_")
if sanitized == "" {
*counter++
sanitized = fmt.Sprintf("tool_fn_%d", *counter)
}
nameMap[name] = sanitized
return sanitized
}

// sanitizeBedrockToolID returns a valid Bedrock toolUseId.
// Bedrock requires toolUseId to match [a-zA-Z0-9_.:-]+ and be non-empty.
// idMap caches original→sanitized so FunctionCall and FunctionResponse
Expand Down Expand Up @@ -121,8 +144,32 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques
modelName = req.Model
}

// Convert content to Bedrock messages
messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents)
// Build tool configuration first so nameMap is available for message conversion.
// convertGenaiToolsToBedrock sanitizes tool names and returns the
// original->sanitized mapping so the same mapping can be applied to
// conversation history and reversed when restoring names from responses.
var toolConfig *types.ToolConfiguration
nameMap := make(map[string]string)
if req.Config != nil && len(req.Config.Tools) > 0 {
tools, nm := convertGenaiToolsToBedrock(req.Config.Tools)
nameMap = nm
if len(tools) > 0 {
toolConfig = &types.ToolConfiguration{
Tools: tools,
}
}
}

// Build reverse map for restoring original tool names from Bedrock responses.
reverseNameMap := make(map[string]string, len(nameMap))
for orig, sanitized := range nameMap {
reverseNameMap[sanitized] = orig
}

// Convert content to Bedrock messages.
// nameMap is passed so that any tool call recorded in conversation history
// is written with the sanitized name Bedrock already knows about.
messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap)

// Build inference config
var inferenceConfig *types.InferenceConfiguration
Expand All @@ -147,27 +194,15 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques
})
}

// Build tool configuration
var toolConfig *types.ToolConfiguration
if req.Config != nil && len(req.Config.Tools) > 0 {
tools := convertGenaiToolsToBedrock(req.Config.Tools)
if len(tools) > 0 {
toolConfig = &types.ToolConfiguration{
Tools: tools,
}
}
}

// Build model-specific additional fields (Claude top_k, thinking, etc.)
additionalFields := m.buildAdditionalModelRequestFields()

// Set telemetry attributes
telemetry.SetLLMRequestAttributes(ctx, modelName, req)

if stream {
m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield)
m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield)
} else {
m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield)
m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield)
}
}
}
Expand All @@ -185,7 +220,8 @@ func (m *BedrockModel) buildAdditionalModelRequestFields() document.Interface {

// generateStreaming handles streaming responses from Bedrock ConverseStream.
// It properly handles both text and tool use content blocks during streaming.
func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) {
// reverseNameMap maps sanitized Bedrock tool names back to their original names.
func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) {
output, err := m.Client.ConverseStream(ctx, &bedrockruntime.ConverseStreamInput{
ModelId: aws.String(modelId),
Messages: messages,
Expand Down Expand Up @@ -266,11 +302,17 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok {
blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex)
if tc, ok := toolCalls[blockIdx]; ok {
// Tool use block completed - parse the accumulated JSON and create FunctionCall
// Tool use block completed - parse the accumulated JSON and create FunctionCall.
// Restore the original tool name from the reverse map so the ADK framework
// can dispatch to the correctly registered tool.
originalName := tc.Name
if orig, found := reverseNameMap[tc.Name]; found {
originalName = orig
}
args := tc.parseArgs()
functionCall := &genai.FunctionCall{
ID: tc.ID,
Name: tc.Name,
Name: originalName,
Args: args,
}
completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall})
Expand Down Expand Up @@ -338,7 +380,8 @@ func (tc *streamingToolCall) parseArgs() map[string]any {
}

// generateNonStreaming handles non-streaming responses from Bedrock Converse.
func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) {
// reverseNameMap maps sanitized Bedrock tool names back to their original names.
func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) {
output, err := m.Client.Converse(ctx, &bedrockruntime.ConverseInput{
ModelId: aws.String(modelId),
Messages: messages,
Expand Down Expand Up @@ -366,9 +409,15 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string,
}
// Handle tool use content
if toolUseBlock, ok := block.(*types.ContentBlockMemberToolUse); ok {
// Restore the original tool name so the ADK framework can dispatch
// to the correctly registered tool.
toolName := aws.ToString(toolUseBlock.Value.Name)
if orig, found := reverseNameMap[toolName]; found {
toolName = orig
}
functionCall := &genai.FunctionCall{
ID: aws.ToString(toolUseBlock.Value.ToolUseId),
Name: aws.ToString(toolUseBlock.Value.Name),
Name: toolName,
}
// Convert document.Interface to map using the String() method and JSON parsing
// The document type in AWS SDK implements String() that returns JSON
Expand Down Expand Up @@ -425,7 +474,10 @@ func documentToMap(doc document.Interface) map[string]any {
}

// convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format.
func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.Message, string) {
// nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock.
// Any FunctionCall found in the conversation history is written with the sanitized name so
// that Bedrock can correlate it with the tool spec it already received. A nil nameMap is safe.
func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap map[string]string) ([]types.Message, string) {
var messages []types.Message
var systemInstruction string

Expand Down Expand Up @@ -465,11 +517,17 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M
continue
}

// Handle function call (tool use in Bedrock terminology)
// Handle function call (tool use in Bedrock terminology).
// Use the sanitized name from nameMap so Bedrock can correlate the
// tool call with the tool spec sent in the same request.
if part.FunctionCall != nil {
callName := part.FunctionCall.Name
if sanitized, ok := nameMap[callName]; ok {
callName = sanitized
}
toolUse := types.ToolUseBlock{
ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionCall.ID, idMap, &idCounter)),
Name: aws.String(part.FunctionCall.Name),
Name: aws.String(callName),
Input: document.NewLazyDocument(part.FunctionCall.Args),
}
contentBlocks = append(contentBlocks, &types.ContentBlockMemberToolUse{
Expand Down Expand Up @@ -507,11 +565,16 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M
}

// convertGenaiToolsToBedrock converts genai.Tool to Bedrock Tool format.
func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
// It sanitizes tool names to satisfy Bedrock's [a-zA-Z0-9_-]+ constraint and
// returns the original->sanitized name mapping so callers can apply it to
// conversation history and reverse it when restoring names from responses.
func convertGenaiToolsToBedrock(tools []*genai.Tool) ([]types.Tool, map[string]string) {
if len(tools) == 0 {
return nil
return nil, nil
}

nameMap := make(map[string]string)
nameCounter := 0
var bedrockTools []types.Tool

for _, tool := range tools {
Expand All @@ -525,7 +588,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
}

// Build input schema as JSON document.
// MCP tools and built-in local toolsset ParametersJsonSchema
// MCP tools and built-in local toolsets set ParametersJsonSchema.
var schema map[string]any
if decl.ParametersJsonSchema != nil {
schema = parametersJsonSchemaToMap(decl.ParametersJsonSchema)
Expand All @@ -536,7 +599,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
// then lowercase all type values to match JSON Schema standard.
schema = genaiSchemaToMap(decl.Parameters)
}
// Fallback to empty object if no schema is found
// Fallback to empty object if no schema is found.
if schema == nil {
schema = map[string]any{"type": "object", "properties": map[string]any{}}
}
Expand All @@ -545,8 +608,12 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
Value: document.NewLazyDocument(schema),
}

// Sanitize the tool name: MCP tool names often contain dots, colons,
// or spaces (e.g. "fetch.get_url") that Bedrock rejects.
sanitizedName := sanitizeBedrockToolName(decl.Name, nameMap, &nameCounter)

toolSpec := types.ToolSpecification{
Name: aws.String(decl.Name),
Name: aws.String(sanitizedName),
Description: aws.String(decl.Description),
InputSchema: inputSchema,
}
Expand All @@ -558,7 +625,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
}
}

return bedrockTools
return bedrockTools, nameMap
}

// bedrockStopReasonToGenai maps Bedrock stop reason to genai.FinishReason.
Expand Down
95 changes: 90 additions & 5 deletions go/adk/pkg/models/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents)
msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents, nil)
if len(msgs) != tt.wantMsgCount {
t.Errorf("expected %d messages, got %d", tt.wantMsgCount, len(msgs))
}
Expand All @@ -124,7 +124,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) {
// sources: genai.Schema (declaration-based), map[string]any (MCP), and
// *jsonschema.Schema (functiontool.New).
func TestConvertGenaiToolsToBedrock(t *testing.T) {
extractSchema := func(t *testing.T, tools []types.Tool) map[string]any {
extractSchema := func(t *testing.T, tools []types.Tool, _ map[string]string) map[string]any {
t.Helper()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
Expand Down Expand Up @@ -162,7 +162,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
},
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt1, nm1 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt1, nm1)

props := schema["properties"].(map[string]any)
for prop, want := range map[string]string{"location": "string", "count": "integer", "detailed": "boolean"} {
Expand All @@ -189,7 +190,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
},
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt2, nm2 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt2, nm2)
props, ok := schema["properties"].(map[string]any)
if !ok || len(props) == 0 {
t.Fatalf("expected non-empty properties, got %v", schema["properties"])
Expand All @@ -209,7 +211,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
ParametersJsonSchema: s,
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt3, nm3 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt3, nm3)
props, ok := schema["properties"].(map[string]any)
if !ok || len(props) == 0 {
t.Fatalf("expected non-empty properties (means *jsonschema.Schema was not converted): %v", schema["properties"])
Expand Down Expand Up @@ -310,6 +313,88 @@ func TestSanitizeBedrockToolID(t *testing.T) {
})
}

func TestSanitizeBedrockToolName(t *testing.T) {
tests := []struct {
name string
tool string
want string
}{
{name: "valid name unchanged", tool: "get_weather", want: "get_weather"},
{name: "valid name with hyphen", tool: "fetch-data", want: "fetch-data"},
{name: "dot replaced", tool: "fetch.get_url", want: "fetch_get_url"},
{name: "colon replaced", tool: "filesystem:read_file", want: "filesystem_read_file"},
{name: "space replaced", tool: "my tool", want: "my_tool"},
{name: "multiple invalid chars", tool: "a.b:c d", want: "a_b_c_d"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nameMap := make(map[string]string)
counter := 0
if got := sanitizeBedrockToolName(tt.tool, nameMap, &counter); got != tt.want {
t.Errorf("sanitizeBedrockToolName(%q) = %q, want %q", tt.tool, got, tt.want)
}
})
}

t.Run("empty name gets synthetic", func(t *testing.T) {
nameMap, counter := make(map[string]string), 0
got := sanitizeBedrockToolName("", nameMap, &counter)
if got != "tool_fn_1" {
t.Errorf("expected tool_fn_1, got %q", got)
}
if counter != 1 {
t.Errorf("expected counter=1, got %d", counter)
}
})

t.Run("caching returns same sanitized name", func(t *testing.T) {
nameMap, counter := make(map[string]string), 0
first := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter)
second := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter)
if first != second {
t.Errorf("expected same cached result, got %q and %q", first, second)
}
if counter != 0 {
t.Errorf("expected counter unchanged, got %d", counter)
}
})
}

func TestConvertGenaiToolsToBedrockSanitizesNames(t *testing.T) {
tools := []*genai.Tool{{FunctionDeclarations: []*genai.FunctionDeclaration{
{Name: "fetch.get_url", Description: "Fetch a URL"},
{Name: "filesystem:read_file", Description: "Read a file"},
}}}

bedrockTools, nameMap := convertGenaiToolsToBedrock(tools)
if len(bedrockTools) != 2 {
t.Fatalf("expected 2 tools, got %d", len(bedrockTools))
}

// Verify sanitized names in the Bedrock tool specs.
for i, want := range []string{"fetch_get_url", "filesystem_read_file"} {
tm, ok := bedrockTools[i].(*types.ToolMemberToolSpec)
if !ok {
t.Fatalf("tool %d: expected *types.ToolMemberToolSpec", i)
}
got := ""
if tm.Value.Name != nil {
got = *tm.Value.Name
}
if got != want {
t.Errorf("tool %d: expected name %q, got %q", i, want, got)
}
}

// Verify nameMap contains the mappings.
if nameMap["fetch.get_url"] != "fetch_get_url" {
t.Errorf("nameMap[fetch.get_url] = %q, want fetch_get_url", nameMap["fetch.get_url"])
}
if nameMap["filesystem:read_file"] != "filesystem_read_file" {
t.Errorf("nameMap[filesystem:read_file] = %q, want filesystem_read_file", nameMap["filesystem:read_file"])
}
}

func TestStreamingToolCallParseArgs(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading