diff --git a/agent/agent.go b/agent/agent.go index 450389743..73beceb30 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -67,6 +67,7 @@ func New(cfg Config) (Agent, error) { beforeAgentCallbacks: cfg.BeforeAgentCallbacks, run: cfg.Run, afterAgentCallbacks: cfg.AfterAgentCallbacks, + onAgentErrorCallbacks: cfg.OnAgentErrorCallbacks, State: agentinternal.State{ AgentType: agentinternal.TypeCustomAgent, }, @@ -104,6 +105,10 @@ type Config struct { // created from the content or error of that callback and the remaining // callbacks will be skipped. AfterAgentCallbacks []AfterAgentCallback + + // OnAgentErrorCallbacks is a list of callbacks that are called sequentially + // when the agent encounters an error during its run. + OnAgentErrorCallbacks []OnAgentErrorCallback } // Artifacts interface provides methods to work with artifacts of the current @@ -136,6 +141,9 @@ type BeforeAgentCallback func(CallbackContext) (*genai.Content, error) // BeforeAgentCallbacks returned non-nil results. type AfterAgentCallback func(CallbackContext) (*genai.Content, error) +// OnAgentErrorCallback is a function that is called when the agent encounters an error. +type OnAgentErrorCallback func(CallbackContext, error) (*genai.Content, error) + type agent struct { agentinternal.State @@ -145,6 +153,7 @@ type agent struct { beforeAgentCallbacks []BeforeAgentCallback run func(InvocationContext) iter.Seq2[*session.Event, error] afterAgentCallbacks []AfterAgentCallback + onAgentErrorCallbacks []OnAgentErrorCallback } func (a *agent) Name() string { @@ -195,6 +204,20 @@ func (a *agent) Run(ctx InvocationContext) iter.Seq2[*session.Event, error] { } for event, err := range a.run(ctx) { + if err != nil { + content, callbackErr := runOnAgentErrorCallbacks(ctx, err) + if callbackErr != nil { + err = callbackErr + } else if content != nil { + event = session.NewEvent(ctx.InvocationID()) + event.LLMResponse = model.LLMResponse{ + Content: content, + } + event.Author = a.Name() + event.Branch = ctx.Branch() + err = nil + } + } if event != nil && event.Author == "" { event.Author = getAuthorForEvent(ctx, event) } @@ -304,6 +327,40 @@ func runBeforeAgentCallbacks(ctx InvocationContext) (*session.Event, error) { return nil, nil } +// runOnAgentErrorCallbacks calls onAgentErrorCallbacks when agent encounters an error. +func runOnAgentErrorCallbacks(ctx InvocationContext, err error) (*genai.Content, error) { + agent := ctx.Agent() + pluginManager := pluginManagerFromContext(ctx) + + callbackCtx := &callbackContext{ + Context: ctx, + invocationContext: ctx, + actions: &session.EventActions{StateDelta: make(map[string]any), ArtifactDelta: make(map[string]int64)}, + } + + if pluginManager != nil { + content, callbackErr := pluginManager.RunOnAgentErrorCallback(callbackCtx, err) + if callbackErr != nil { + return nil, fmt.Errorf("failed to run plugin on agent error callback: %w", callbackErr) + } + if content != nil { + return content, nil + } + } + + for _, callback := range agent.internal().onAgentErrorCallbacks { + content, callbackErr := callback(callbackCtx, err) + if callbackErr != nil { + return nil, fmt.Errorf("failed to run on agent error callback: %w", callbackErr) + } + if content != nil { + return content, nil + } + } + + return nil, nil +} + // runAfterAgentCallbacks checks if any afterAgentCallback returns non-nil content or a state modification // then it create a new event with the new content and state delta. func runAfterAgentCallbacks(ctx InvocationContext) (*session.Event, error) { @@ -516,6 +573,7 @@ func pluginManagerFromContext(ctx context.Context) pluginManager { type pluginManager interface { RunBeforeAgentCallback(cctx CallbackContext) (*genai.Content, error) RunAfterAgentCallback(cctx CallbackContext) (*genai.Content, error) + RunOnAgentErrorCallback(cctx CallbackContext, err error) (*genai.Content, error) } var _ InvocationContext = (*invocationContext)(nil) diff --git a/internal/plugininternal/plugin_manager.go b/internal/plugininternal/plugin_manager.go index 97e855647..7826436c9 100644 --- a/internal/plugininternal/plugin_manager.go +++ b/internal/plugininternal/plugin_manager.go @@ -167,6 +167,23 @@ func (pm *PluginManager) RunAfterAgentCallback(cctx agent.CallbackContext) (*gen return nil, nil } +// RunOnAgentErrorCallback runs the OnAgentErrorCallback for all plugins. +func (pm *PluginManager) RunOnAgentErrorCallback(cctx agent.CallbackContext, err error) (*genai.Content, error) { + for _, plugin := range pm.plugins { + callback := plugin.OnAgentErrorCallback() + if callback != nil { + newContent, callbackErr := callback(cctx, err) + if callbackErr != nil { + return nil, callbackErr + } + if newContent != nil { + return newContent, nil // Early exit + } + } + } + return nil, nil +} + // RunBeforeToolCallback runs the BeforeToolCallback for all plugins. func (pm *PluginManager) RunBeforeToolCallback(ctx tool.Context, tool tool.Tool, args map[string]any) (map[string]any, error) { for _, plugin := range pm.plugins { diff --git a/plugin/plugin.go b/plugin/plugin.go index 162e73e68..3143731a8 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -36,6 +36,8 @@ type Config struct { BeforeAgentCallback agent.BeforeAgentCallback AfterAgentCallback agent.AfterAgentCallback + OnAgentErrorCallback OnAgentErrorCallback + BeforeModelCallback llmagent.BeforeModelCallback AfterModelCallback llmagent.AfterModelCallback OnModelErrorCallback llmagent.OnModelErrorCallback @@ -56,6 +58,7 @@ func New(cfg Config) (*Plugin, error) { afterRunCallback: cfg.AfterRunCallback, beforeAgentCallback: cfg.BeforeAgentCallback, afterAgentCallback: cfg.AfterAgentCallback, + onAgentErrorCallback: cfg.OnAgentErrorCallback, beforeModelCallback: cfg.BeforeModelCallback, afterModelCallback: cfg.AfterModelCallback, onModelErrorCallback: cfg.OnModelErrorCallback, @@ -87,6 +90,8 @@ type Plugin struct { beforeAgentCallback agent.BeforeAgentCallback afterAgentCallback agent.AfterAgentCallback + onAgentErrorCallback OnAgentErrorCallback + beforeModelCallback llmagent.BeforeModelCallback afterModelCallback llmagent.AfterModelCallback onModelErrorCallback llmagent.OnModelErrorCallback @@ -134,6 +139,10 @@ func (p *Plugin) AfterAgentCallback() agent.AfterAgentCallback { return p.afterAgentCallback } +func (p *Plugin) OnAgentErrorCallback() OnAgentErrorCallback { + return p.onAgentErrorCallback +} + func (p *Plugin) BeforeModelCallback() llmagent.BeforeModelCallback { return p.beforeModelCallback } @@ -165,3 +174,5 @@ type BeforeRunCallback func(agent.InvocationContext) (*genai.Content, error) type AfterRunCallback func(agent.InvocationContext) type OnEventCallback func(agent.InvocationContext, *session.Event) (*session.Event, error) + +type OnAgentErrorCallback func(agent.CallbackContext, error) (*genai.Content, error)