diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 67a05fd6c..2a601c319 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -220,9 +220,15 @@ func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock { type partialMock struct { t *testing.T - expectedPath string - expectedQueryParams map[string]string - expectedRequestBody any + expectedPath string + expectedQueryParams map[string]string + expectedRequestBody any + expectedHeaderContains map[string]string +} + +func (p *partialMock) withHeaders(headers map[string]string) *partialMock { + p.expectedHeaderContains = headers + return p } func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc { @@ -247,6 +253,12 @@ func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) } + if p.expectedHeaderContains != nil { + for k, v := range p.expectedHeaderContains { + require.Contains(p.t, r.Header.Get(k), v, "expected header %q to contain %q", k, v) + } + } + responseHandler(w, r) } } diff --git a/pkg/github/minimal_types.go b/pkg/github/minimal_types.go index a8757c51c..b1e7c2357 100644 --- a/pkg/github/minimal_types.go +++ b/pkg/github/minimal_types.go @@ -51,6 +51,22 @@ type MinimalSearchRepositoriesResult struct { Items []MinimalRepository `json:"items"` } +// MinimalCodeSearchResult is the trimmed output type for code search results. +type MinimalCodeSearchResult struct { + TotalCount int `json:"total_count"` + IncompleteResults bool `json:"incomplete_results"` + Items []MinimalCodeResult `json:"items"` +} + +// MinimalCodeResult is the trimmed output type for a single code search hit. +type MinimalCodeResult struct { + Name string `json:"name"` + Path string `json:"path"` + SHA string `json:"sha"` + Repository string `json:"repository"` + TextMatches []*github.TextMatch `json:"text_matches,omitempty"` +} + // MinimalCommitAuthor represents commit author information. type MinimalCommitAuthor struct { Name string `json:"name,omitempty"` diff --git a/pkg/github/search.go b/pkg/github/search.go index 500921376..8edfc948a 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -270,8 +270,9 @@ func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool { } opts := &github.SearchOptions{ - Sort: sort, - Order: order, + Sort: sort, + Order: order, + TextMatch: true, ListOptions: github.ListOptions{ PerPage: pagination.PerPage, Page: pagination.Page, @@ -301,7 +302,27 @@ func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool { return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search code", resp, body), nil, nil } - r, err := json.Marshal(result) + minimalItems := make([]MinimalCodeResult, 0, len(result.CodeResults)) + for _, code := range result.CodeResults { + item := MinimalCodeResult{ + Name: code.GetName(), + Path: code.GetPath(), + SHA: code.GetSHA(), + TextMatches: code.TextMatches, + } + if code.Repository != nil { + item.Repository = code.Repository.GetFullName() + } + minimalItems = append(minimalItems, item) + } + + minimalResult := &MinimalCodeSearchResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalItems, + } + + r, err := json.Marshal(minimalResult) if err != nil { return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index eb5d98075..0c4a30c32 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -430,22 +430,35 @@ func Test_SearchCode(t *testing.T) { IncompleteResults: github.Ptr(false), CodeResults: []*github.CodeResult{ { - Name: github.Ptr("file1.go"), - Path: github.Ptr("path/to/file1.go"), - SHA: github.Ptr("abc123def456"), - HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file1.go"), - Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + Name: github.Ptr("file1.go"), + Path: github.Ptr("path/to/file1.go"), + SHA: github.Ptr("abc123def456"), + Repository: &github.Repository{ + Name: github.Ptr("repo"), + FullName: github.Ptr("owner/repo"), + }, + TextMatches: []*github.TextMatch{ + { + Fragment: github.Ptr("func main() { fmt.Println(\"hello\") }"), + }, + }, }, { - Name: github.Ptr("file2.go"), - Path: github.Ptr("path/to/file2.go"), - SHA: github.Ptr("def456abc123"), - HTMLURL: github.Ptr("https://github.com/owner/repo/blob/main/path/to/file2.go"), - Repository: &github.Repository{Name: github.Ptr("repo"), FullName: github.Ptr("owner/repo")}, + Name: github.Ptr("file2.go"), + Path: github.Ptr("path/to/file2.go"), + SHA: github.Ptr("def456abc123"), + Repository: &github.Repository{ + Name: github.Ptr("repo"), + FullName: github.Ptr("owner/repo"), + }, }, }, } + textMatchAcceptHeader := map[string]string{ + "Accept": "text-match", + } + tests := []struct { name string mockedClient *http.Client @@ -463,7 +476,7 @@ func Test_SearchCode(t *testing.T) { "order": "desc", "page": "1", "per_page": "30", - }).andThen( + }).withHeaders(textMatchAcceptHeader).andThen( mockResponse(t, http.StatusOK, mockSearchResult), ), }), @@ -484,7 +497,7 @@ func Test_SearchCode(t *testing.T) { "q": "fmt.Println language:go", "page": "1", "per_page": "30", - }).andThen( + }).withHeaders(textMatchAcceptHeader).andThen( mockResponse(t, http.StatusOK, mockSearchResult), ), }), @@ -537,22 +550,28 @@ func Test_SearchCode(t *testing.T) { require.NoError(t, err) require.False(t, result.IsError) - // Parse the result and get the text content if no error textContent := getTextResult(t, result) - // Unmarshal and verify the result - var returnedResult github.CodeSearchResult + var returnedResult MinimalCodeSearchResult err = json.Unmarshal([]byte(textContent.Text), &returnedResult) require.NoError(t, err) - assert.Equal(t, *tc.expectedResult.Total, *returnedResult.Total) - assert.Equal(t, *tc.expectedResult.IncompleteResults, *returnedResult.IncompleteResults) - assert.Len(t, returnedResult.CodeResults, len(tc.expectedResult.CodeResults)) - for i, code := range returnedResult.CodeResults { - assert.Equal(t, *tc.expectedResult.CodeResults[i].Name, *code.Name) - assert.Equal(t, *tc.expectedResult.CodeResults[i].Path, *code.Path) - assert.Equal(t, *tc.expectedResult.CodeResults[i].SHA, *code.SHA) - assert.Equal(t, *tc.expectedResult.CodeResults[i].HTMLURL, *code.HTMLURL) - assert.Equal(t, *tc.expectedResult.CodeResults[i].Repository.FullName, *code.Repository.FullName) + assert.Equal(t, *tc.expectedResult.Total, returnedResult.TotalCount) + assert.Equal(t, *tc.expectedResult.IncompleteResults, returnedResult.IncompleteResults) + assert.Len(t, returnedResult.Items, len(tc.expectedResult.CodeResults)) + for i, code := range returnedResult.Items { + assert.Equal(t, tc.expectedResult.CodeResults[i].GetName(), code.Name) + assert.Equal(t, tc.expectedResult.CodeResults[i].GetPath(), code.Path) + assert.Equal(t, tc.expectedResult.CodeResults[i].GetSHA(), code.SHA) + assert.Equal(t, tc.expectedResult.CodeResults[i].Repository.GetFullName(), code.Repository) + } + + // Verify text matches are included when present + if len(tc.expectedResult.CodeResults[0].TextMatches) > 0 { + require.NotEmpty(t, returnedResult.Items[0].TextMatches) + assert.Equal(t, + tc.expectedResult.CodeResults[0].TextMatches[0].GetFragment(), + returnedResult.Items[0].TextMatches[0].GetFragment(), + ) } }) }