Skip to content
Merged
2 changes: 2 additions & 0 deletions go/adk/pkg/a2a/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/a2aproject/a2a-go/a2asrv"
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/auth"
"github.com/kagent-dev/kagent/go/adk/pkg/models"
"github.com/kagent-dev/kagent/go/adk/pkg/session"
"github.com/kagent-dev/kagent/go/adk/pkg/skills"
Expand Down Expand Up @@ -117,6 +118,7 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
sessionID := reqCtx.ContextID

ctx = withBearerToken(ctx)
ctx = auth.WithUserID(ctx, userID)

e.logger.Info("Execute",
"taskID", reqCtx.TaskID,
Expand Down
18 changes: 18 additions & 0 deletions go/adk/pkg/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ import (
"time"
)

type contextKey int

const userIDKey contextKey = iota

// WithUserID returns a copy of ctx that carries the user ID for injection into
// outgoing HTTP requests by TokenRoundTripper.
func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}

func userIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(userIDKey).(string)
return id
}

const kagentTokenPath = "/var/run/secrets/tokens/kagent-token"

// KAgentTokenService reads a k8s token from a file and reloads it periodically
Expand Down Expand Up @@ -61,6 +76,9 @@ func (s *KAgentTokenService) AddHeaders(req *http.Request) {
if token := s.GetToken(); token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
if userID := userIDFromContext(req.Context()); userID != "" {
req.Header.Set("X-User-Id", userID)
}
}

// readToken reads the token from the file
Expand Down
75 changes: 41 additions & 34 deletions go/core/internal/httpserver/auth/proxy_authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,67 +27,74 @@ func NewProxyAuthenticator(userIDClaim string) *ProxyAuthenticator {

func (a *ProxyAuthenticator) Authenticate(ctx context.Context, reqHeaders http.Header, query url.Values) (auth.Session, error) {
authHeader := reqHeaders.Get("Authorization")

// Always read agent identity from X-Agent-Name header (used by agents calling back)
agentID := reqHeaders.Get("X-Agent-Name")

// If we have a Bearer token, parse JWT
if tokenString, ok := strings.CutPrefix(authHeader, "Bearer "); ok {
// Parse JWT without validation (oauth2-proxy or k8s service account already validated)
rawClaims, err := parseJWTPayload(tokenString)
if err != nil {
return nil, ErrUnauthenticated
}
tokenString, ok := strings.CutPrefix(authHeader, "Bearer ")
if !ok {
return nil, ErrUnauthenticated
Copy link
Copy Markdown
Contributor Author

@onematchfox onematchfox Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a change in behaviour as it requires that a Bearer token be provided whereas the original code allowed any caller that knew the X-Agent-Name and X-User-Id to bypass use of a Bearer token. By default any agent created via the Agent CRD will supply a Bearer token - so this would only affect direct consumers of the API that are currently relying on this insecurity.

}

// Parse JWT without validation (oauth2-proxy or k8s service account already validated)
rawClaims, err := parseJWTPayload(tokenString)
if err != nil {
return nil, ErrUnauthenticated
}

userID, _ := rawClaims[a.userIDClaim].(string)
if userID == "" && a.userIDClaim != "sub" {
if agentID != "" {
// Agent call: the Bearer SA token authenticates the pod; the caller's
// identity should be supplied explicitly via X-User-Id / user_id.
// Fall back to the SA sub claim for direct calls to agent pods that
// do not yet propagate the caller identity.
userID := userIDFromRequest(reqHeaders, query)
if userID == "" {
userID, _ = rawClaims["sub"].(string)
}
if userID == "" {
return nil, ErrUnauthenticated
}

return &SimpleSession{
P: auth.Principal{
User: auth.User{ID: userID},
Agent: auth.Agent{ID: agentID},
Claims: rawClaims,
User: auth.User{ID: userID},
Agent: auth.Agent{ID: agentID},
},
authHeader: authHeader,
}, nil
}

// Fall back to service account auth for internal agent-to-controller calls.
// Requires X-Agent-Name to identify the calling agent.
if agentID == "" {
return nil, ErrUnauthenticated
}

// Agents authenticate via user_id query param or X-User-Id header
userID := query.Get("user_id")
if userID == "" {
userID = reqHeaders.Get("X-User-Id")
// Direct user call: identity comes from the OIDC JWT claims.
userID, _ := rawClaims[a.userIDClaim].(string)
if userID == "" && a.userIDClaim != "sub" {
userID, _ = rawClaims["sub"].(string)
}
if userID == "" {
return nil, ErrUnauthenticated
}

return &SimpleSession{
P: auth.Principal{
User: auth.User{
ID: userID,
},
Agent: auth.Agent{
ID: agentID,
},
User: auth.User{ID: userID},
Claims: rawClaims,
},
authHeader: authHeader,
}, nil
}

// userIDFromRequest returns the user identity from the user_id query param or
// X-User-Id header, preferring the query param.
func userIDFromRequest(headers http.Header, query url.Values) string {
if v := query.Get("user_id"); v != "" {
return v
}
return headers.Get("X-User-Id")
}

func (a *ProxyAuthenticator) UpstreamAuth(r *http.Request, session auth.Session, upstreamPrincipal auth.Principal) error {
if simpleSession, ok := session.(*SimpleSession); ok && simpleSession.authHeader != "" {
r.Header.Set("Authorization", simpleSession.authHeader)
if simpleSession, ok := session.(*SimpleSession); ok {
if simpleSession.authHeader != "" {
r.Header.Set("Authorization", simpleSession.authHeader)
}
if userID := simpleSession.P.User.ID; userID != "" {
r.Header.Set("X-User-Id", userID)
}
}
return nil
}
Expand Down
116 changes: 36 additions & 80 deletions go/core/internal/httpserver/auth/proxy_authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,112 +159,63 @@ func TestProxyAuthenticator_Authenticate(t *testing.T) {
}
}

func TestProxyAuthenticator_JWTWithAgentHeader(t *testing.T) {
func TestProxyAuthenticator_AgentCalls(t *testing.T) {
tests := []struct {
name string
claims map[string]any
agentName string
headers map[string]string
queryParams map[string]string
wantUserID string
wantAgentID string
wantErr bool
}{
{
name: "extracts agent identity from header when JWT is present",
claims: map[string]any{
"sub": "system:serviceaccount:kagent:kebab-agent",
"iss": "https://kubernetes.default.svc.cluster.local",
"aud": []any{"kagent"},
name: "agent with SA Bearer token and X-User-Id header uses header identity",
headers: map[string]string{
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
"X-Agent-Name": "kagent/test-agent",
"X-User-Id": "user@example.com",
},
agentName: "kagent__NS__kebab_agent",
wantUserID: "system:serviceaccount:kagent:kebab-agent",
wantAgentID: "kagent__NS__kebab_agent",
wantUserID: "user@example.com",
wantAgentID: "kagent/test-agent",
},
{
name: "works with OIDC JWT and agent header",
claims: map[string]any{
"sub": "user123",
"email": "user@example.com",
name: "agent with SA Bearer token and user_id query param uses query identity",
headers: map[string]string{
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
"X-Agent-Name": "kagent/test-agent",
},
agentName: "kagent__NS__my_agent",
wantUserID: "user123",
wantAgentID: "kagent__NS__my_agent",
},
{
name: "handles JWT without agent header",
claims: map[string]any{
"sub": "user123",
queryParams: map[string]string{
"user_id": "user@example.com",
},
agentName: "",
wantUserID: "user123",
wantAgentID: "",
wantUserID: "user@example.com",
wantAgentID: "kagent/test-agent",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := authimpl.NewProxyAuthenticator("")

headers := http.Header{}
token := createTestJWT(tt.claims)
headers.Set("Authorization", "Bearer "+token)
if tt.agentName != "" {
headers.Set("X-Agent-Name", tt.agentName)
}

session, err := auth.Authenticate(context.Background(), headers, url.Values{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

principal := session.Principal()
if principal.User.ID != tt.wantUserID {
t.Errorf("User.ID = %q, want %q", principal.User.ID, tt.wantUserID)
}
if principal.Agent.ID != tt.wantAgentID {
t.Errorf("Agent.ID = %q, want %q", principal.Agent.ID, tt.wantAgentID)
}
})
}
}

func TestProxyAuthenticator_ServiceAccountFallback(t *testing.T) {
tests := []struct {
name string
headers map[string]string
queryParams map[string]string
wantUserID string
wantAgentID string
wantErr bool
}{
{
name: "authenticates via user_id query param with agent name",
queryParams: map[string]string{
"user_id": "system:serviceaccount:kagent:kebab-agent",
},
name: "agent with no X-User-Id falls back to SA sub claim",
headers: map[string]string{
"X-Agent-Name": "kagent/kebab-agent",
"Authorization": "Bearer " + createTestJWT(map[string]any{"sub": "system:serviceaccount:kagent:test-agent"}),
"X-Agent-Name": "kagent/test-agent",
},
wantUserID: "system:serviceaccount:kagent:kebab-agent",
wantAgentID: "kagent/kebab-agent",
wantErr: false,
wantUserID: "system:serviceaccount:kagent:test-agent",
wantAgentID: "kagent/test-agent",
},
// Error cases.
{
name: "authenticates via X-User-Id header with agent name",
name: "agent without Bearer token is rejected",
headers: map[string]string{
"X-User-Id": "system:serviceaccount:kagent:test-agent",
"X-Agent-Name": "kagent/test-agent",
"X-User-Id": "user@example.com",
},
wantUserID: "system:serviceaccount:kagent:test-agent",
wantAgentID: "kagent/test-agent",
wantErr: false,
wantErr: true,
},
{
name: "returns error when no auth method available",
name: "no token and no X-Agent-Name is rejected",
wantErr: true,
},
{
name: "returns error when no X-Agent-Name header for fallback",
name: "user_id without X-Agent-Name is rejected",
queryParams: map[string]string{
"user_id": "system:serviceaccount:kagent:kebab-agent",
"user_id": "user@example.com",
},
wantErr: true,
},
Expand Down Expand Up @@ -339,4 +290,9 @@ func TestProxyAuthenticator_UpstreamAuth(t *testing.T) {
if got := req.Header.Get("Authorization"); got != authHeader {
t.Errorf("Authorization header = %q, want %q", got, authHeader)
}

// Verify X-User-Id is forwarded so downstream A2A runtimes receive the real user identity
if got := req.Header.Get("X-User-Id"); got != "user123" {
t.Errorf("X-User-Id header = %q, want %q", got, "user123")
}
}
14 changes: 3 additions & 11 deletions python/packages/kagent-adk/src/kagent/adk/_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ async def create_session(
response = await self.client.post(
"/api/sessions",
json=request_data,
headers={"X-User-ID": user_id},
)
response.raise_for_status()

Expand Down Expand Up @@ -88,10 +87,7 @@ async def get_session(
url += "&limit=-1"

# Make API call to get session
response: httpx.Response = await self.client.get(
url,
headers={"X-User-ID": user_id},
)
response: httpx.Response = await self.client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
Expand Down Expand Up @@ -131,7 +127,7 @@ async def get_session(
@override
async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsResponse:
# Make API call to list sessions
response = await self.client.get(f"/api/sessions?user_id={user_id}", headers={"X-User-ID": user_id})
response = await self.client.get(f"/api/sessions?user_id={user_id}")
response.raise_for_status()

data = response.json()
Expand All @@ -151,10 +147,7 @@ def list_sessions_sync(self, *, app_name: str, user_id: str) -> ListSessionsResp
@override
async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None:
# Make API call to delete session
response = await self.client.delete(
f"/api/sessions/{session_id}?user_id={user_id}",
headers={"X-User-ID": user_id},
)
response = await self.client.delete(f"/api/sessions/{session_id}?user_id={user_id}")
response.raise_for_status()

@override
Expand All @@ -172,7 +165,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
response = await self.client.post(
f"/api/sessions/{session.id}/events?user_id={session.user_id}",
json=event_data,
headers={"X-User-ID": session.user_id},
)
response.raise_for_status()

Expand Down
8 changes: 5 additions & 3 deletions python/packages/kagent-adk/src/kagent/adk/_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Optional

import httpx
from kagent.core.a2a import get_request_user_id

KAGENT_TOKEN_PATH = "/var/run/secrets/tokens/kagent-token"
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -35,7 +36,7 @@ def event_hooks(self):
"""Returns a dictionary of event hooks for the application
to use when creating the httpx.AsyncClient.
"""
return {"request": [self._add_bearer_token]}
return {"request": [self._add_headers]}

async def _update_token_loop(self) -> None:
self.token = await self._read_kagent_token()
Expand All @@ -61,12 +62,13 @@ async def _refresh_token(self):
async with self.update_lock:
self.token = token

async def _add_bearer_token(self, request: httpx.Request):
# Your function to generate headers dynamically
async def _add_headers(self, request: httpx.Request):
token = await self._get_token()
headers = {"X-Agent-Name": self.app_name}
if token:
headers["Authorization"] = f"Bearer {token}"
if user_id := get_request_user_id():
headers["X-User-Id"] = user_id
request.headers.update(headers)


Expand Down
Loading
Loading