Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
360 changes: 360 additions & 0 deletions examples/session_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
#!/usr/bin/env python
"""session_monitor.py — behavioral consistency monitoring for long SDK sessions.

This example stays on the public SDK surface:

- `HookMatcher`-based `PreToolUse` / `PostToolUse` callbacks
- `ClaudeSDKClient.query()` + `receive_response()` for turns
- `ClaudeSDKClient.get_context_usage()` for context-window telemetry

Together, those are enough to build a lightweight monitor for long-running
sessions where context compaction or summarization may silently change the
agent's behavior.

Because a short fresh session will not reliably trigger compaction on demand,
the default runnable demo below uses a simulated token-usage boundary while the
live integration helpers keep the exact public SDK wiring you would use in a
real session.

Reference: https://github.com/anthropics/claude-agent-sdk-python/issues/772
"""

import asyncio
import json
import math
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient, HookMatcher
from claude_agent_sdk.types import (
AssistantMessage,
HookContext,
HookJSONOutput,
PostToolUseHookInput,
PreToolUseHookInput,
ResultMessage,
TextBlock,
)


@dataclass
class BehavioralSnapshot:
"""What the agent looks like at one point in the session."""

turn: int
tokens: int
timestamp: float
tool_counts: Counter = field(default_factory=Counter)
vocabulary: set[str] = field(default_factory=set)


class SessionMonitor:
"""Track vocabulary and tool-use drift across a Claude SDK session."""

def __init__(
self,
compaction_drop_ratio: float = 0.20,
drift_threshold: float = 0.30,
log_path: Optional[Path] = None,
) -> None:
self.compaction_drop_ratio = compaction_drop_ratio
self.drift_threshold = drift_threshold
self.log_path = log_path

self._baseline: Optional[BehavioralSnapshot] = None
self._current: Optional[BehavioralSnapshot] = None
self._turn = 0
self._compaction_events: list[dict[str, Any]] = []
self._drift_scores: list[float] = []
self._pending_tool_counts: Counter = Counter()
self._pending_vocabulary: set[str] = set()

async def on_pre_tool_use(
self,
input_data: PreToolUseHookInput,
tool_use_id: Optional[str],
context: HookContext,
) -> HookJSONOutput:
"""Record each tool call before execution."""

del tool_use_id, context
self._pending_tool_counts[input_data["tool_name"]] += 1
return {}

async def on_post_tool_use(
self,
input_data: PostToolUseHookInput,
tool_use_id: Optional[str],
context: HookContext,
) -> HookJSONOutput:
"""Capture vocabulary emitted by tool results."""

del tool_use_id, context
tool_response = str(input_data.get("tool_response", ""))
words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", tool_response.lower()))
self._pending_vocabulary.update(words)
return {}

def record_turn(self, message_text: str, total_tokens: int) -> Optional[dict[str, Any]]:
"""Record a completed turn and return any detected event."""

self._turn += 1
words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", message_text.lower()))
prev_tokens = self._current.tokens if self._current else 0

self._current = BehavioralSnapshot(
turn=self._turn,
tokens=total_tokens,
timestamp=time.time(),
tool_counts=Counter(self._pending_tool_counts),
vocabulary=words | self._pending_vocabulary,
)

self._pending_tool_counts.clear()
self._pending_vocabulary.clear()

if self._baseline is None and total_tokens > 0:
self._baseline = BehavioralSnapshot(
turn=self._turn,
tokens=total_tokens,
timestamp=self._current.timestamp,
tool_counts=Counter(self._current.tool_counts),
vocabulary=set(self._current.vocabulary),
)
return None

if self._baseline is None:
return None

compaction_detected = False
if prev_tokens > 0 and total_tokens < prev_tokens * (1 - self.compaction_drop_ratio):
compaction_detected = True
event = {
"event": "compaction_suspected",
"turn": self._turn,
"tokens_before": prev_tokens,
"tokens_after": total_tokens,
"drop_ratio": round(1.0 - total_tokens / prev_tokens, 3),
"timestamp": self._current.timestamp,
}
self._compaction_events.append(event)
self._log(event)

ccs = self._compute_ccs()
self._drift_scores.append(ccs)

if ccs < (1.0 - self.drift_threshold) or compaction_detected:
event = {
"event": "post_compaction_drift" if compaction_detected else "behavioral_drift",
"turn": self._turn,
"ccs": round(ccs, 3),
"compaction_at_this_turn": compaction_detected,
"ghost_terms": self._ghost_terms(),
"tool_shift": self._tool_shift_summary(),
}
self._log(event)
return event

return None

def _compute_ccs(self) -> float:
"""Context Consistency Score: 1.0 means no behavioral change."""

return 0.6 * self._vocab_overlap() + 0.4 * self._tool_consistency()

def _vocab_overlap(self) -> float:
if not self._baseline or not self._baseline.vocabulary or not self._current:
return 1.0
if not self._current.vocabulary:
return 1.0
intersection = self._baseline.vocabulary & self._current.vocabulary
union = self._baseline.vocabulary | self._current.vocabulary
return len(intersection) / len(union) if union else 1.0

def _ghost_terms(self) -> list[str]:
if not self._baseline or not self._current:
return []
return sorted(self._baseline.vocabulary - self._current.vocabulary)[:20]

def _tool_consistency(self) -> float:
if not self._baseline or not self._current:
return 1.0
if not self._baseline.tool_counts or not self._current.tool_counts:
return 1.0

all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
baseline_total = sum(self._baseline.tool_counts.values()) or 1
current_total = sum(self._current.tool_counts.values()) or 1
baseline_distribution = {
tool: self._baseline.tool_counts.get(tool, 0) / baseline_total
for tool in all_tools
}
Comment on lines +192 to +196
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Nit: The epsilon in the KL divergence helper is misplaced — math.log(lhs[tool] / rhs[tool] + 1e-10) adds epsilon to the ratio (post-division), not to the denominator. If rhs[tool] were 0, ZeroDivisionError would fire before the epsilon takes effect. The fix is math.log(lhs[tool] / (rhs[tool] + 1e-10)). This is latent in the current JSD usage (the midpoint is always positive when lhs[tool] > 0), but the misplaced epsilon gives a false sense of numerical safety in example code others may adapt.

Extended reasoning...

What the bug is

The kl_divergence inner function computes:

lhs[tool] * math.log(lhs[tool] / rhs[tool] + 1e-10)

Due to Python operator precedence, this parses as math.log((lhs[tool] / rhs[tool]) + 1e-10). The epsilon is added to the ratio p/q, not to the denominator q. The standard numerical-safety pattern for KL divergence is math.log(lhs[tool] / (rhs[tool] + 1e-10)), which prevents division by zero when rhs[tool] == 0.

How it would manifest

If rhs[tool] is ever 0 while lhs[tool] > 0, the expression lhs[tool] / rhs[tool] evaluates first (before the + 1e-10), raising a ZeroDivisionError. The epsilon never gets a chance to help.

Why it does not trigger today

In the JSD context, rhs is always the midpoint distribution: midpoint[tool] = 0.5 * (baseline_distribution[tool] + current_distribution[tool]). The generator expression filters on if lhs[tool] > 0, meaning lhs[tool] (which is either baseline_distribution[tool] or current_distribution[tool]) is positive. Since the midpoint averages two non-negative values and at least one of them (lhs[tool]) is positive, midpoint[tool] >= 0.5 * lhs[tool] > 0. Division by zero is structurally impossible.

Step-by-step proof with a concrete example

Consider baseline_distribution = {"Bash": 1.0} and current_distribution = {"Bash": 0.0, "Read": 1.0}. Then midpoint = {"Bash": 0.5, "Read": 0.5}. For kl_divergence(baseline_distribution, midpoint), tool="Bash": lhs["Bash"] = 1.0 > 0, rhs["Bash"] = 0.5, so we compute 1.0 * math.log(1.0 / 0.5 + 1e-10) = math.log(2.0000000001) — works fine. Now imagine someone calls kl_divergence(baseline_distribution, current_distribution) directly (not via JSD): tool="Bash": lhs["Bash"] = 1.0 > 0, rhs["Bash"] = 0.0, so 1.0 / 0.0ZeroDivisionError before + 1e-10 is reached.

Impact and fix

Since this is example code that users may copy and adapt for general KL divergence computation, the misleading epsilon placement could lead to bugs in derivative code. The fix is simply adding parentheses: math.log(lhs[tool] / (rhs[tool] + 1e-10)). This is a cosmetic/correctness nit — the current code works correctly for its specific JSD use case.

current_distribution = {
tool: self._current.tool_counts.get(tool, 0) / current_total
for tool in all_tools
}
midpoint = {
tool: 0.5 * (baseline_distribution[tool] + current_distribution[tool])
for tool in all_tools
}

def kl_divergence(lhs: dict[str, float], rhs: dict[str, float]) -> float:
return sum(
lhs[tool] * math.log(lhs[tool] / rhs[tool] + 1e-10)
for tool in all_tools
if lhs[tool] > 0
)

jsd = 0.5 * kl_divergence(baseline_distribution, midpoint) + 0.5 * kl_divergence(
current_distribution, midpoint
)
return max(0.0, 1.0 - jsd)

def _tool_shift_summary(self) -> dict[str, dict[str, int]]:
if not self._baseline or not self._current:
return {}
all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
return {
tool: {
"baseline": self._baseline.tool_counts.get(tool, 0),
"current": self._current.tool_counts.get(tool, 0),
}
for tool in all_tools
}

def summary(self) -> dict[str, Any]:
return {
"turns": self._turn,
"compaction_events": len(self._compaction_events),
"avg_ccs": round(sum(self._drift_scores) / len(self._drift_scores), 3)
if self._drift_scores
else None,
"min_ccs": round(min(self._drift_scores), 3) if self._drift_scores else None,
"compaction_detail": self._compaction_events,
}

def _log(self, event: dict[str, Any]) -> None:
if self.log_path:
with self.log_path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(event) + "\n")
else:
print(f"[session_monitor] {json.dumps(event)}")


async def run_monitored_turn(
client: ClaudeSDKClient,
monitor: SessionMonitor,
prompt: str,
) -> Optional[dict[str, Any]]:
"""Run one SDK turn, then score it using public message + usage APIs."""

await client.query(prompt)

text_parts: list[str] = []
async for message in client.receive_response():
if isinstance(message, AssistantMessage):
for block in message.content:
if isinstance(block, TextBlock):
text_parts.append(block.text)
elif isinstance(message, ResultMessage) and message.is_error:
raise RuntimeError(message.result or "Claude SDK turn failed")

usage = await client.get_context_usage()
total_tokens = int(usage.get("totalTokens", 0))
return monitor.record_turn(" ".join(text_parts), total_tokens)


def run_simulated_boundary_demo() -> None:
"""Run a deterministic boundary demo using the same scoring logic."""
monitor = SessionMonitor(
compaction_drop_ratio=0.20,
drift_threshold=0.30,
log_path=None,
)

synthetic_turns = [
(
"Use jwt validation with bcrypt hashes, redis-backed sessions, and "
"foreign_key-safe migrations for the auth schema.",
1200,
),
(
"Keep jwt auth, bcrypt password storage, and redis session checks "
"intact while you add the profile endpoint.",
1480,
),
(
"Add PATCH /profile rate limiting with concise validation and 429 "
"responses. Focus on the endpoint only.",
860,
),
]

print("=== Deterministic session boundary demo ===")
print("This uses simulated token snapshots so the monitor always shows a boundary event.\n")
for text, total_tokens in synthetic_turns:
event = monitor.record_turn(text, total_tokens)
if event:
print(json.dumps(event, indent=2))

print("\n=== Session summary ===")
print(json.dumps(monitor.summary(), indent=2))


async def run_live_demo() -> None:
"""Optional live SDK demo using the same monitor."""
monitor = SessionMonitor(
compaction_drop_ratio=0.20,
drift_threshold=0.30,
log_path=None,
)

options = ClaudeAgentOptions(
allowed_tools=["Bash"],
hooks={
"PreToolUse": [
HookMatcher(matcher="Bash", hooks=[monitor.on_pre_tool_use]),
],
"PostToolUse": [
HookMatcher(matcher="Bash", hooks=[monitor.on_post_tool_use]),
],
},
)

prompts = [
"Use Bash to print 'jwt bcrypt redis', then explain how those terms fit together in a web auth stack.",
"Use Bash to print 'id,name\\n1,Ada', then explain how pandas would load this CSV.",
"Use Bash to print '[0 1 2]', then explain numpy arrays in one short paragraph.",
]

async with ClaudeSDKClient(options=options) as client:
for prompt in prompts:
event = await run_monitored_turn(client, monitor, prompt)
if event:
print(f"\n[session_monitor] Behavioral event: {json.dumps(event, indent=2)}")

print("\n=== Session summary ===")
print(json.dumps(monitor.summary(), indent=2))
print()
print("Note: native OnCompaction / OnContextThreshold hooks would still be better.")
print("This sample shows the closest monitor you can build today with public hooks")
print("plus get_context_usage() as the compaction-boundary heuristic.")


async def main() -> None:
run_simulated_boundary_demo()

if os.getenv("CLAUDE_SESSION_MONITOR_LIVE") == "1":
print("\n=== Live SDK session demo ===")
await run_live_demo()
else:
print("\nSet CLAUDE_SESSION_MONITOR_LIVE=1 to also run the live SDK session demo.")


if __name__ == "__main__":
asyncio.run(main())