diff --git a/cpex/framework/memory.py b/cpex/framework/memory.py index dadfbcee..ec2dc8d6 100644 --- a/cpex/framework/memory.py +++ b/cpex/framework/memory.py @@ -14,6 +14,7 @@ import copy import logging import weakref +from collections.abc import Mapping from typing import Any, Iterator, Optional, TypeVar # Third-Party @@ -173,6 +174,48 @@ def __repr__(self) -> str: """ return f"CopyOnWriteDict({dict(self.items())})" + __hash__ = None + + def __eq__(self, other: Any) -> bool: + """ + Compare equality with another mapping. + + Compares the materialized logical mapping (original + modifications - deletions) + rather than the empty base dict storage. + + Args: + other: The object to compare with. + + Returns: + True if other is a Mapping with the same key-value pairs, False otherwise. + Returns NotImplemented for non-Mapping types to allow other.__eq__ to handle it. + """ + if not isinstance(other, Mapping): + return NotImplemented + + # Fast-path: if lengths differ, mappings cannot be equal + if len(self) != len(other): + return False + + # Compare materialized items + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other: Any) -> bool: + """ + Compare inequality with another mapping. + + Args: + other: The object to compare with. + + Returns: + True if not equal, False if equal. + Returns NotImplemented for non-Mapping types. + """ + eq = self.__eq__(other) + if eq is NotImplemented: + return NotImplemented + return not eq + def get(self, key: Any, default: Optional[Any] = None) -> Any: """ Get an item with a default fallback. diff --git a/pyproject.toml b/pyproject.toml index 8c53c4f5..e947d1ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,10 @@ preview = true fixable = ["ALL"] unfixable = [] +[tool.ruff.lint.pylint] +# Relaxed from the default of 5; existing code has wider try clauses (max observed 38). +max-statements-in-try = 50 + # Ignore D1 (docstring checks) and Pylint checks in tests and other non-production code [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = ["D1", "PL"] diff --git a/tests/unit/cpex/framework/test_memory.py b/tests/unit/cpex/framework/test_memory.py index 070a0785..30ff6221 100644 --- a/tests/unit/cpex/framework/test_memory.py +++ b/tests/unit/cpex/framework/test_memory.py @@ -819,6 +819,104 @@ def test_iter_skips_deleted_keys_in_modifications(self): assert set(keys) == {"b", "c"} assert "a" not in keys + def test_equality_with_empty_dict(self): + """CopyOnWriteDict with data should not equal empty dict.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {} + assert {} != cow + assert not (cow == {}) + assert not ({} == cow) + + def test_equality_with_matching_dict(self): + """CopyOnWriteDict should equal dict with same key-value pairs.""" + original = {"a": 1, "b": 2, "c": 3} + cow = CopyOnWriteDict(original) + assert cow == {"a": 1, "b": 2, "c": 3} + assert {"a": 1, "b": 2, "c": 3} == cow + + def test_equality_with_different_dict(self): + """CopyOnWriteDict should not equal dict with different content.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {"a": 1, "b": 3} + assert cow != {"a": 1} + assert cow != {"a": 1, "b": 2, "c": 3} + # Same length, different keys + assert cow != {"a": 1, "c": 2} + + def test_equality_after_modifications(self): + """Equality should reflect modifications.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + cow["c"] = 3 + assert cow == {"a": 1, "b": 2, "c": 3} + assert cow != {"a": 1, "b": 2} + + def test_equality_after_deletions(self): + """Equality should reflect deletions.""" + cow = CopyOnWriteDict({"a": 1, "b": 2, "c": 3}) + del cow["b"] + assert cow == {"a": 1, "c": 3} + assert cow != {"a": 1, "b": 2, "c": 3} + + def test_equality_after_override(self): + """Equality should reflect overridden values.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + cow["a"] = 10 + assert cow == {"a": 10, "b": 2} + assert cow != {"a": 1, "b": 2} + + def test_equality_with_another_copyonwritedict(self): + """Two CopyOnWriteDict instances with same content should be equal.""" + cow1 = CopyOnWriteDict({"a": 1, "b": 2}) + cow2 = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow1 == cow2 + assert cow2 == cow1 + + def test_equality_empty_copyonwritedict(self): + """Empty CopyOnWriteDict should equal empty dict.""" + cow = CopyOnWriteDict({}) + assert cow == {} + assert {} == cow + + def test_equality_with_non_mapping_returns_notimplemented(self): + """Equality with non-Mapping types should return NotImplemented.""" + cow = CopyOnWriteDict({"a": 1}) + # These should not raise, Python will handle NotImplemented + assert cow != "not a dict" + assert cow != 123 + assert cow != ["a", "list"] + assert cow != None + + def test_inequality_operator(self): + """Test __ne__ operator works correctly.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {} + assert cow != {"a": 1} + assert not (cow != {"a": 1, "b": 2}) + + def test_copyonwritedict_is_unhashable(self): + """CopyOnWriteDict should remain unhashable like dict.""" + cow = CopyOnWriteDict({"a": 1}) + with pytest.raises(TypeError): + hash(cow) + + def test_equality_wxo_args_scenario(self): + """Regression test for the WXO args bug scenario.""" + # This is the exact scenario from the bug report + cow = CopyOnWriteDict({ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }) + + # These were the failing assertions in the bug + assert cow != {} + assert {} != cow + assert cow == { + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + } + class TestCopyOnWriteFunction: """Test suite for copyonwrite() factory function.""" diff --git a/tests/unit/cpex/framework/test_policies.py b/tests/unit/cpex/framework/test_policies.py index ef674566..7726c49d 100644 --- a/tests/unit/cpex/framework/test_policies.py +++ b/tests/unit/cpex/framework/test_policies.py @@ -172,6 +172,65 @@ class PayloadWithModel(PluginPayload): assert result is not None assert result.nested.x == 99 # type: ignore[union-attr] + def test_copyonwritedict_args_empty_modification_preserved(self): + """Regression test for bug where CopyOnWriteDict equality caused + apply_policy to drop valid empty args modification. + + When a plugin receives args as CopyOnWriteDict with data and returns + an empty dict, apply_policy should treat this as a valid modification. + Previously, CopyOnWriteDict.__eq__ was not implemented, causing the + comparison to use dict's default equality which compared the empty + base storage, incorrectly returning True for CopyOnWriteDict({...}) == {}. + """ + from cpex.framework.memory import CopyOnWriteDict + + policy = HookPayloadPolicy(writable_fields=frozenset({"args"})) + + # Simulate plugin receiving payload with CopyOnWriteDict args + original = SamplePayload( + name="test", + args=CopyOnWriteDict({ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }), + secret="s", + ) + + # Plugin strips all args, returning empty dict + modified = SamplePayload(name="test", args={}, secret="s") + + result = apply_policy(original, modified, policy) + + # The modification should be preserved, not dropped + assert result is not None, "apply_policy should not return None when args changed from {...} to {}" + assert result.args == {} # type: ignore[union-attr] + assert result.name == "test" # type: ignore[union-attr] + assert result.secret == "s" # type: ignore[union-attr] + + def test_copyonwritedict_args_partial_modification_preserved(self): + """Test that partial arg removal is also preserved correctly.""" + from cpex.framework.memory import CopyOnWriteDict + + policy = HookPayloadPolicy(writable_fields=frozenset({"args"})) + + original = SamplePayload( + name="test", + args=CopyOnWriteDict({ + "wxo_auth": "token", + "real_arg": "value", + }), + secret="s", + ) + + # Plugin removes only wxo_auth, keeping real_arg + modified = SamplePayload(name="test", args={"real_arg": "value"}, secret="s") + + result = apply_policy(original, modified, policy) + + assert result is not None + assert result.args == {"real_arg": "value"} # type: ignore[union-attr] + class TestPluginPayloadFrozen: """Tests for frozen PluginPayload base class.""" @@ -752,6 +811,61 @@ async def tool_pre_invoke(self, payload, context): assert result.modified_payload.secret == "safe" # Policy filtered this out + @pytest.mark.asyncio + async def test_tool_pre_invoke_empty_args_modification_preserved_through_executor(self): + """Regression test for the tool_pre_invoke executor path. + + A plugin receives CoW-wrapped args containing only specific fields, + strips them all, and returns a payload with args={}. The executor should + preserve that empty args modification instead of dropping it as + "unchanged". + """ + from cpex.framework.base import HookRef, Plugin, PluginRef + from cpex.framework.hooks.policies import HookPayloadPolicy + from cpex.framework.hooks.tools import ToolPreInvokePayload + from cpex.framework.manager import PluginExecutor + from cpex.framework.memory import CopyOnWriteDict + from cpex.framework.models import GlobalContext, PluginConfig, PluginResult + + seen_arg_types = [] + + class StripWxoArgsPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + seen_arg_types.append(type(payload.args)) + cleaned_args = {k: v for k, v in payload.args.items() if not k.startswith("wxo_")} + modified = payload.model_copy(update={"args": cleaned_args}) + return PluginResult(continue_processing=True, modified_payload=modified) + + policies = { + "tool_pre_invoke": HookPayloadPolicy(writable_fields=frozenset({"args"})), + } + executor = PluginExecutor(hook_policies=policies) + + config = PluginConfig(name="stripper", kind="test.Plugin", version="1.0", hooks=["tool_pre_invoke"]) + plugin = StripWxoArgsPlugin(config) + hook_ref = HookRef("tool_pre_invoke", PluginRef(plugin)) + + payload = ToolPreInvokePayload( + name="list_all_secrets", + args={ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }, + ) + global_ctx = GlobalContext(request_id="tool-pre-empty-args") + + result, _ = await executor.execute([hook_ref], payload, global_ctx, hook_type="tool_pre_invoke") + + assert seen_arg_types == [CopyOnWriteDict] + assert result.modified_payload is not None + assert result.modified_payload == ToolPreInvokePayload(name="list_all_secrets", args={}) + assert payload.args == { + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + } + class TestMultiPluginDictChain: """Tests for multi-plugin chains where an earlier plugin returns a dict payload."""