Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func New(cfg Config) (Agent, error) {
beforeAgentCallbacks: cfg.BeforeAgentCallbacks,
run: cfg.Run,
afterAgentCallbacks: cfg.AfterAgentCallbacks,
onAgentErrorCallbacks: cfg.OnAgentErrorCallbacks,
State: agentinternal.State{
AgentType: agentinternal.TypeCustomAgent,
},
Expand Down Expand Up @@ -104,6 +105,10 @@ type Config struct {
// created from the content or error of that callback and the remaining
// callbacks will be skipped.
AfterAgentCallbacks []AfterAgentCallback

// OnAgentErrorCallbacks is a list of callbacks that are called sequentially
// when the agent encounters an error during its run.
OnAgentErrorCallbacks []OnAgentErrorCallback
}

// Artifacts interface provides methods to work with artifacts of the current
Expand Down Expand Up @@ -136,6 +141,9 @@ type BeforeAgentCallback func(CallbackContext) (*genai.Content, error)
// BeforeAgentCallbacks returned non-nil results.
type AfterAgentCallback func(CallbackContext) (*genai.Content, error)

// OnAgentErrorCallback is a function that is called when the agent encounters an error.
type OnAgentErrorCallback func(CallbackContext, error) (*genai.Content, error)

type agent struct {
agentinternal.State

Expand All @@ -145,6 +153,7 @@ type agent struct {
beforeAgentCallbacks []BeforeAgentCallback
run func(InvocationContext) iter.Seq2[*session.Event, error]
afterAgentCallbacks []AfterAgentCallback
onAgentErrorCallbacks []OnAgentErrorCallback
}

func (a *agent) Name() string {
Expand Down Expand Up @@ -195,6 +204,20 @@ func (a *agent) Run(ctx InvocationContext) iter.Seq2[*session.Event, error] {
}

for event, err := range a.run(ctx) {
if err != nil {
content, callbackErr := runOnAgentErrorCallbacks(ctx, err)
if callbackErr != nil {
err = callbackErr
} else if content != nil {
event = session.NewEvent(ctx.InvocationID())
event.LLMResponse = model.LLMResponse{
Content: content,
}
event.Author = a.Name()
event.Branch = ctx.Branch()
err = nil
}
}
if event != nil && event.Author == "" {
event.Author = getAuthorForEvent(ctx, event)
}
Expand Down Expand Up @@ -304,6 +327,40 @@ func runBeforeAgentCallbacks(ctx InvocationContext) (*session.Event, error) {
return nil, nil
}

// runOnAgentErrorCallbacks calls onAgentErrorCallbacks when agent encounters an error.
func runOnAgentErrorCallbacks(ctx InvocationContext, err error) (*genai.Content, error) {
agent := ctx.Agent()
pluginManager := pluginManagerFromContext(ctx)

callbackCtx := &callbackContext{
Context: ctx,
invocationContext: ctx,
actions: &session.EventActions{StateDelta: make(map[string]any), ArtifactDelta: make(map[string]int64)},
}

if pluginManager != nil {
content, callbackErr := pluginManager.RunOnAgentErrorCallback(callbackCtx, err)
if callbackErr != nil {
return nil, fmt.Errorf("failed to run plugin on agent error callback: %w", callbackErr)
}
if content != nil {
return content, nil
}
}

for _, callback := range agent.internal().onAgentErrorCallbacks {
content, callbackErr := callback(callbackCtx, err)
if callbackErr != nil {
return nil, fmt.Errorf("failed to run on agent error callback: %w", callbackErr)
}
if content != nil {
return content, nil
}
}

return nil, nil
}

// runAfterAgentCallbacks checks if any afterAgentCallback returns non-nil content or a state modification
// then it create a new event with the new content and state delta.
func runAfterAgentCallbacks(ctx InvocationContext) (*session.Event, error) {
Expand Down Expand Up @@ -516,6 +573,7 @@ func pluginManagerFromContext(ctx context.Context) pluginManager {
type pluginManager interface {
RunBeforeAgentCallback(cctx CallbackContext) (*genai.Content, error)
RunAfterAgentCallback(cctx CallbackContext) (*genai.Content, error)
RunOnAgentErrorCallback(cctx CallbackContext, err error) (*genai.Content, error)
}

var _ InvocationContext = (*invocationContext)(nil)
17 changes: 17 additions & 0 deletions internal/plugininternal/plugin_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,23 @@ func (pm *PluginManager) RunAfterAgentCallback(cctx agent.CallbackContext) (*gen
return nil, nil
}

// RunOnAgentErrorCallback runs the OnAgentErrorCallback for all plugins.
func (pm *PluginManager) RunOnAgentErrorCallback(cctx agent.CallbackContext, err error) (*genai.Content, error) {
for _, plugin := range pm.plugins {
callback := plugin.OnAgentErrorCallback()
if callback != nil {
newContent, callbackErr := callback(cctx, err)
if callbackErr != nil {
return nil, callbackErr
}
if newContent != nil {
return newContent, nil // Early exit
}
}
}
return nil, nil
}

// RunBeforeToolCallback runs the BeforeToolCallback for all plugins.
func (pm *PluginManager) RunBeforeToolCallback(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) {
for _, plugin := range pm.plugins {
Expand Down
11 changes: 11 additions & 0 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Config struct {
BeforeAgentCallback agent.BeforeAgentCallback
AfterAgentCallback agent.AfterAgentCallback

OnAgentErrorCallback OnAgentErrorCallback

BeforeModelCallback llmagent.BeforeModelCallback
AfterModelCallback llmagent.AfterModelCallback
OnModelErrorCallback llmagent.OnModelErrorCallback
Expand All @@ -56,6 +58,7 @@ func New(cfg Config) (*Plugin, error) {
afterRunCallback: cfg.AfterRunCallback,
beforeAgentCallback: cfg.BeforeAgentCallback,
afterAgentCallback: cfg.AfterAgentCallback,
onAgentErrorCallback: cfg.OnAgentErrorCallback,
beforeModelCallback: cfg.BeforeModelCallback,
afterModelCallback: cfg.AfterModelCallback,
onModelErrorCallback: cfg.OnModelErrorCallback,
Expand Down Expand Up @@ -87,6 +90,8 @@ type Plugin struct {
beforeAgentCallback agent.BeforeAgentCallback
afterAgentCallback agent.AfterAgentCallback

onAgentErrorCallback OnAgentErrorCallback

beforeModelCallback llmagent.BeforeModelCallback
afterModelCallback llmagent.AfterModelCallback
onModelErrorCallback llmagent.OnModelErrorCallback
Expand Down Expand Up @@ -134,6 +139,10 @@ func (p *Plugin) AfterAgentCallback() agent.AfterAgentCallback {
return p.afterAgentCallback
}

func (p *Plugin) OnAgentErrorCallback() OnAgentErrorCallback {
return p.onAgentErrorCallback
}

func (p *Plugin) BeforeModelCallback() llmagent.BeforeModelCallback {
return p.beforeModelCallback
}
Expand Down Expand Up @@ -165,3 +174,5 @@ type BeforeRunCallback func(agent.InvocationContext) (*genai.Content, error)
type AfterRunCallback func(agent.InvocationContext)

type OnEventCallback func(agent.InvocationContext, *session.Event) (*session.Event, error)

type OnAgentErrorCallback func(agent.CallbackContext, error) (*genai.Content, error)