Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions pkg/inventory/server_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package inventory
import (
"context"
"encoding/json"
"fmt"

"github.com/github/github-mcp-server/pkg/octicons"
"github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -133,7 +134,12 @@ func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, hand
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var arguments In
if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil {
return nil, err
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: fmt.Sprintf("invalid arguments: %s", err)},
},
IsError: true,
}, nil
}
resp, _, err := typedHandler(ctx, req, arguments)
return resp, err
Expand All @@ -157,7 +163,12 @@ func NewServerToolWithContextHandler[In any, Out any](tool mcp.Tool, toolset Too
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var arguments In
if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil {
return nil, err
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: fmt.Sprintf("invalid arguments: %s", err)},
},
IsError: true,
}, nil
}
resp, _, err := handler(ctx, req, arguments)
return resp, err
Expand Down
118 changes: 118 additions & 0 deletions pkg/inventory/server_tool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package inventory

import (
"context"
"encoding/json"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewServerTool_InvalidArguments_ReturnsIsError(t *testing.T) {
type expectedArgs struct {
Owner string `json:"owner"`
Repo string `json:"repo"`
}

tool := NewServerTool(
mcp.Tool{Name: "test_tool"},
testToolsetMetadata("test"),
func(deps any) mcp.ToolHandlerFor[expectedArgs, *mcp.CallToolResult] {
return func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, *mcp.CallToolResult, error) {
t.Fatal("handler should not be called with invalid arguments")
return nil, nil, nil
}
},
)

handler := tool.HandlerFunc(nil)

badArgs, _ := json.Marshal(map[string]any{"owner": 12345, "repo": true})
result, err := handler(context.Background(), &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "test_tool",
Arguments: badArgs,
},
})

require.NoError(t, err)
require.NotNil(t, result)
assert.True(t, result.IsError)
assert.Len(t, result.Content, 1)
textContent, ok := result.Content[0].(*mcp.TextContent)
require.True(t, ok)
assert.Contains(t, textContent.Text, "invalid arguments")
}

func TestNewServerToolWithContextHandler_InvalidArguments_ReturnsIsError(t *testing.T) {
type expectedArgs struct {
Query string `json:"query"`
Limit int `json:"limit"`
}

tool := NewServerToolWithContextHandler(
mcp.Tool{Name: "test_context_tool"},
testToolsetMetadata("test"),
func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, any, error) {
t.Fatal("handler should not be called with invalid arguments")
return nil, nil, nil
},
)

handler := tool.HandlerFunc(nil)

result, err := handler(context.Background(), &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "test_context_tool",
Arguments: json.RawMessage(`{not valid json`),
},
})

require.NoError(t, err)
require.NotNil(t, result)
assert.True(t, result.IsError)
assert.Len(t, result.Content, 1)
textContent, ok := result.Content[0].(*mcp.TextContent)
require.True(t, ok)
assert.Contains(t, textContent.Text, "invalid arguments")
}

func TestNewServerTool_ValidArguments_Succeeds(t *testing.T) {
type expectedArgs struct {
Owner string `json:"owner"`
Repo string `json:"repo"`
}

tool := NewServerTool(
mcp.Tool{Name: "test_tool"},
testToolsetMetadata("test"),
func(deps any) mcp.ToolHandlerFor[expectedArgs, *mcp.CallToolResult] {
return func(ctx context.Context, req *mcp.CallToolRequest, args expectedArgs) (*mcp.CallToolResult, *mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "success: " + args.Owner + "/" + args.Repo},
},
}, nil, nil
}
},
)

handler := tool.HandlerFunc(nil)

goodArgs, _ := json.Marshal(map[string]any{"owner": "octocat", "repo": "hello-world"})
result, err := handler(context.Background(), &mcp.CallToolRequest{
Params: &mcp.CallToolParamsRaw{
Name: "test_tool",
Arguments: goodArgs,
},
})

require.NoError(t, err)
require.NotNil(t, result)
assert.False(t, result.IsError)
textContent, ok := result.Content[0].(*mcp.TextContent)
require.True(t, ok)
assert.Equal(t, "success: octocat/hello-world", textContent.Text)
}