Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions core/http/endpoints/localai/agent_collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
file, err := c.FormFile("file")
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
Expand Down Expand Up @@ -116,7 +116,7 @@ func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFun
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name"))
entries, err := svc.ListCollectionEntriesForUser(userID, decodedParam(c, "name"))
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand All @@ -139,7 +139,7 @@ func GetCollectionEntryContentEndpoint(app *application.Application) echo.Handle
if err != nil {
entry = entryParam
}
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry)
content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, decodedParam(c, "name"), entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand All @@ -164,7 +164,7 @@ func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults)
results, err := svc.SearchCollectionForUser(userID, decodedParam(c, "name"), payload.Query, payload.MaxResults)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand All @@ -182,7 +182,7 @@ func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil {
if err := svc.ResetCollectionForUser(userID, decodedParam(c, "name")); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
Expand All @@ -202,7 +202,7 @@ func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFun
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry)
remaining, err := svc.DeleteCollectionEntryForUser(userID, decodedParam(c, "name"), payload.Entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand Down Expand Up @@ -230,7 +230,7 @@ func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc
if payload.UpdateInterval < 1 {
payload.UpdateInterval = 60
}
if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
if err := svc.AddCollectionSourceForUser(userID, decodedParam(c, "name"), payload.URL, payload.UpdateInterval); err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
Expand All @@ -250,7 +250,7 @@ func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFu
if err := c.Bind(&payload); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
}
if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil {
if err := svc.RemoveCollectionSourceForUser(userID, decodedParam(c, "name"), payload.URL); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
Expand All @@ -267,7 +267,7 @@ func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.Handle
if err != nil {
entry = entryParam
}
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry)
fpath, err := svc.GetCollectionEntryFilePathForUser(userID, decodedParam(c, "name"), entry)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand All @@ -282,7 +282,7 @@ func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFun
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name"))
sources, err := svc.ListCollectionSourcesForUser(userID, decodedParam(c, "name"))
if err != nil {
if strings.Contains(err.Error(), "not found") {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand Down
49 changes: 49 additions & 0 deletions core/http/endpoints/localai/agent_collections_param_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package localai

import (
"net/http"
"net/http/httptest"

"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

// Regression for #10443: agent/collection names carry a "legacy-api-key:"
// prefix, so the ':' is percent-encoded as %3A in the request path. Echo routes
// such paths via URL.RawPath and stores the path-param value still escaped, so
// handlers must URL-decode it before looking the collection up in the store -
// otherwise the lookup sees "legacy-api-key%3ALiteraryResearch" and 404s.
var _ = Describe("decodedParam", func() {
var e *echo.Echo

BeforeEach(func() {
e = echo.New()
})

// route runs a request through Echo's real router so the path param is
// populated exactly as it would be in production, then returns the decoded
// value the handler would observe.
route := func(rawPath string) string {
var got string
e.GET("/api/agents/collections/:name/upload", func(c echo.Context) error {
got = decodedParam(c, "name")
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, rawPath, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(http.StatusOK))
return got
}

It("decodes a percent-encoded colon in the collection name", func() {
got := route("/api/agents/collections/legacy-api-key%3ALiteraryResearch/upload")
Expect(got).To(Equal("legacy-api-key:LiteraryResearch"))
})

It("leaves an unencoded name untouched", func() {
got := route("/api/agents/collections/PlainCollection/upload")
Expect(got).To(Equal("PlainCollection"))
})
})
41 changes: 29 additions & 12 deletions core/http/endpoints/localai/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"maps"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
Expand Down Expand Up @@ -33,6 +34,22 @@ func getUserID(c echo.Context) string {
return user.ID
}

// decodedParam returns the named path parameter, URL-decoding it.
//
// Echo routes a request via URL.RawPath whenever the path contains
// percent-encoded characters (e.g. %3A for ':'), and in that case stores the
// matched path-param value raw/escaped. Agent and collection names carry a
// "legacy-api-key:" prefix, so the ':' arrives as %3A and the raw param no
// longer matches the stored name. Callers must unescape before lookups.
// Falls back to the raw value if it isn't valid percent-encoding.
func decodedParam(c echo.Context, name string) string {
raw := c.Param(name)
if decoded, err := url.PathUnescape(raw); err == nil {
return decoded
}
return raw
}

// isAdminUser returns true if the authenticated user has admin role.
func isAdminUser(c echo.Context) bool {
user := auth.GetUser(c)
Expand Down Expand Up @@ -127,7 +144,7 @@ func GetAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")

statuses := svc.ListAgentsForUser(userID)
active, exists := statuses[name]
Expand All @@ -142,7 +159,7 @@ func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
var cfg state.AgentConfig
if err := c.Bind(&cfg); err != nil {
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
Expand All @@ -161,7 +178,7 @@ func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
if err := svc.DeleteAgentForUser(userID, name); err != nil {
return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
}
Expand All @@ -173,7 +190,7 @@ func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
cfg := svc.GetAgentConfigForUser(userID, name)
if cfg == nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"})
Expand All @@ -186,7 +203,7 @@ func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil {
if err := svc.PauseAgentForUser(userID, decodedParam(c, "name")); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
Expand All @@ -197,7 +214,7 @@ func ResumeAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil {
if err := svc.ResumeAgentForUser(userID, decodedParam(c, "name")); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
Expand All @@ -208,7 +225,7 @@ func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")

history := svc.GetAgentStatusForUser(userID, name)
if history == nil {
Expand Down Expand Up @@ -241,7 +258,7 @@ func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")

history, err := svc.GetAgentObservablesForUser(userID, name)
if err != nil {
Expand All @@ -261,7 +278,7 @@ func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFun
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
if err := svc.ClearAgentObservablesForUser(userID, name); err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
}
Expand All @@ -273,7 +290,7 @@ func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
var payload struct {
Message string `json:"message"`
}
Expand Down Expand Up @@ -302,7 +319,7 @@ func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")

// Try local SSE manager first
manager := svc.GetSSEManagerForUser(userID, name)
Expand Down Expand Up @@ -334,7 +351,7 @@ func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
svc := app.AgentPoolService()
userID := effectiveUserID(c)
name := c.Param("name")
name := decodedParam(c, "name")
data, err := svc.ExportAgentForUser(userID, name)
if err != nil {
return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
Expand Down
Loading