From dafffccbfa657c8ff37fd173a184e03b13a99540 Mon Sep 17 00:00:00 2001 From: prakhar-singh1928 Date: Thu, 28 May 2026 11:22:34 +0100 Subject: [PATCH 1/4] fix: implement __eq__ and __ne__ for CopyOnWriteDict Fixes equality comparison bug where CopyOnWriteDict compared equal to {} even when containing data. This caused apply_policy() to incorrectly drop valid payload modifications when plugins removed all arguments. Changes: - Add __eq__ and __ne__ methods to CopyOnWriteDict - Add 13 comprehensive equality unit tests - Add policy regression tests for empty args scenario - Add end-to-end integration tests Signed-off-by: prakhar-singh1928 --- cpex/framework/memory.py | 41 ++++++++ tests/unit/cpex/framework/test_memory.py | 97 +++++++++++++++++- tests/unit/cpex/framework/test_policies.py | 114 +++++++++++++++++++++ 3 files changed, 251 insertions(+), 1 deletion(-) diff --git a/cpex/framework/memory.py b/cpex/framework/memory.py index dadfbcee..c19c91e0 100644 --- a/cpex/framework/memory.py +++ b/cpex/framework/memory.py @@ -173,6 +173,47 @@ def __repr__(self) -> str: """ return f"CopyOnWriteDict({dict(self.items())})" + __hash__ = None + + def __eq__(self, other: Any) -> bool: + """ + Compare equality with another mapping. + + Compares the materialized logical mapping (original + modifications - deletions) + rather than the empty base dict storage. + + Args: + other: The object to compare with. + + Returns: + True if other is a Mapping with the same key-value pairs, False otherwise. + Returns NotImplemented for non-Mapping types to allow other.__eq__ to handle it. + """ + # Import here to avoid circular dependency + from collections.abc import Mapping + + if not isinstance(other, Mapping): + return NotImplemented + + # Compare materialized items + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other: Any) -> bool: + """ + Compare inequality with another mapping. + + Args: + other: The object to compare with. + + Returns: + True if not equal, False if equal. + Returns NotImplemented for non-Mapping types. + """ + eq = self.__eq__(other) + if eq is NotImplemented: + return NotImplemented + return not eq + def get(self, key: Any, default: Optional[Any] = None) -> Any: """ Get an item with a default fallback. diff --git a/tests/unit/cpex/framework/test_memory.py b/tests/unit/cpex/framework/test_memory.py index 070a0785..b185685d 100644 --- a/tests/unit/cpex/framework/test_memory.py +++ b/tests/unit/cpex/framework/test_memory.py @@ -817,7 +817,102 @@ def test_iter_skips_deleted_keys_in_modifications(self): keys = list(cow) # Should only have b (from original) and c (from modifications, not deleted) assert set(keys) == {"b", "c"} - assert "a" not in keys + + def test_equality_with_empty_dict(self): + """CopyOnWriteDict with data should not equal empty dict.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {} + assert {} != cow + assert not (cow == {}) + assert not ({} == cow) + + def test_equality_with_matching_dict(self): + """CopyOnWriteDict should equal dict with same key-value pairs.""" + original = {"a": 1, "b": 2, "c": 3} + cow = CopyOnWriteDict(original) + assert cow == {"a": 1, "b": 2, "c": 3} + assert {"a": 1, "b": 2, "c": 3} == cow + + def test_equality_with_different_dict(self): + """CopyOnWriteDict should not equal dict with different content.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {"a": 1, "b": 3} + assert cow != {"a": 1} + assert cow != {"a": 1, "b": 2, "c": 3} + + def test_equality_after_modifications(self): + """Equality should reflect modifications.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + cow["c"] = 3 + assert cow == {"a": 1, "b": 2, "c": 3} + assert cow != {"a": 1, "b": 2} + + def test_equality_after_deletions(self): + """Equality should reflect deletions.""" + cow = CopyOnWriteDict({"a": 1, "b": 2, "c": 3}) + del cow["b"] + assert cow == {"a": 1, "c": 3} + assert cow != {"a": 1, "b": 2, "c": 3} + + def test_equality_after_override(self): + """Equality should reflect overridden values.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + cow["a"] = 10 + assert cow == {"a": 10, "b": 2} + assert cow != {"a": 1, "b": 2} + + def test_equality_with_another_copyonwritedict(self): + """Two CopyOnWriteDict instances with same content should be equal.""" + cow1 = CopyOnWriteDict({"a": 1, "b": 2}) + cow2 = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow1 == cow2 + assert cow2 == cow1 + + def test_equality_empty_copyonwritedict(self): + """Empty CopyOnWriteDict should equal empty dict.""" + cow = CopyOnWriteDict({}) + assert cow == {} + assert {} == cow + + def test_equality_with_non_mapping_returns_notimplemented(self): + """Equality with non-Mapping types should return NotImplemented.""" + cow = CopyOnWriteDict({"a": 1}) + # These should not raise, Python will handle NotImplemented + assert cow != "not a dict" + assert cow != 123 + assert cow != ["a", "list"] + assert cow != None + + def test_inequality_operator(self): + """Test __ne__ operator works correctly.""" + cow = CopyOnWriteDict({"a": 1, "b": 2}) + assert cow != {} + assert cow != {"a": 1} + assert not (cow != {"a": 1, "b": 2}) + + def test_copyonwritedict_is_unhashable(self): + """CopyOnWriteDict should remain unhashable like dict.""" + cow = CopyOnWriteDict({"a": 1}) + with pytest.raises(TypeError): + hash(cow) + + def test_equality_wxo_args_scenario(self): + """Regression test for the WXO args bug scenario.""" + # This is the exact scenario from the bug report + cow = CopyOnWriteDict({ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }) + + # These were the failing assertions in the bug + assert cow != {} + assert {} != cow + assert cow == { + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + } class TestCopyOnWriteFunction: diff --git a/tests/unit/cpex/framework/test_policies.py b/tests/unit/cpex/framework/test_policies.py index ef674566..7726c49d 100644 --- a/tests/unit/cpex/framework/test_policies.py +++ b/tests/unit/cpex/framework/test_policies.py @@ -172,6 +172,65 @@ class PayloadWithModel(PluginPayload): assert result is not None assert result.nested.x == 99 # type: ignore[union-attr] + def test_copyonwritedict_args_empty_modification_preserved(self): + """Regression test for bug where CopyOnWriteDict equality caused + apply_policy to drop valid empty args modification. + + When a plugin receives args as CopyOnWriteDict with data and returns + an empty dict, apply_policy should treat this as a valid modification. + Previously, CopyOnWriteDict.__eq__ was not implemented, causing the + comparison to use dict's default equality which compared the empty + base storage, incorrectly returning True for CopyOnWriteDict({...}) == {}. + """ + from cpex.framework.memory import CopyOnWriteDict + + policy = HookPayloadPolicy(writable_fields=frozenset({"args"})) + + # Simulate plugin receiving payload with CopyOnWriteDict args + original = SamplePayload( + name="test", + args=CopyOnWriteDict({ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }), + secret="s", + ) + + # Plugin strips all args, returning empty dict + modified = SamplePayload(name="test", args={}, secret="s") + + result = apply_policy(original, modified, policy) + + # The modification should be preserved, not dropped + assert result is not None, "apply_policy should not return None when args changed from {...} to {}" + assert result.args == {} # type: ignore[union-attr] + assert result.name == "test" # type: ignore[union-attr] + assert result.secret == "s" # type: ignore[union-attr] + + def test_copyonwritedict_args_partial_modification_preserved(self): + """Test that partial arg removal is also preserved correctly.""" + from cpex.framework.memory import CopyOnWriteDict + + policy = HookPayloadPolicy(writable_fields=frozenset({"args"})) + + original = SamplePayload( + name="test", + args=CopyOnWriteDict({ + "wxo_auth": "token", + "real_arg": "value", + }), + secret="s", + ) + + # Plugin removes only wxo_auth, keeping real_arg + modified = SamplePayload(name="test", args={"real_arg": "value"}, secret="s") + + result = apply_policy(original, modified, policy) + + assert result is not None + assert result.args == {"real_arg": "value"} # type: ignore[union-attr] + class TestPluginPayloadFrozen: """Tests for frozen PluginPayload base class.""" @@ -752,6 +811,61 @@ async def tool_pre_invoke(self, payload, context): assert result.modified_payload.secret == "safe" # Policy filtered this out + @pytest.mark.asyncio + async def test_tool_pre_invoke_empty_args_modification_preserved_through_executor(self): + """Regression test for the tool_pre_invoke executor path. + + A plugin receives CoW-wrapped args containing only specific fields, + strips them all, and returns a payload with args={}. The executor should + preserve that empty args modification instead of dropping it as + "unchanged". + """ + from cpex.framework.base import HookRef, Plugin, PluginRef + from cpex.framework.hooks.policies import HookPayloadPolicy + from cpex.framework.hooks.tools import ToolPreInvokePayload + from cpex.framework.manager import PluginExecutor + from cpex.framework.memory import CopyOnWriteDict + from cpex.framework.models import GlobalContext, PluginConfig, PluginResult + + seen_arg_types = [] + + class StripWxoArgsPlugin(Plugin): + async def tool_pre_invoke(self, payload, context): + seen_arg_types.append(type(payload.args)) + cleaned_args = {k: v for k, v in payload.args.items() if not k.startswith("wxo_")} + modified = payload.model_copy(update={"args": cleaned_args}) + return PluginResult(continue_processing=True, modified_payload=modified) + + policies = { + "tool_pre_invoke": HookPayloadPolicy(writable_fields=frozenset({"args"})), + } + executor = PluginExecutor(hook_policies=policies) + + config = PluginConfig(name="stripper", kind="test.Plugin", version="1.0", hooks=["tool_pre_invoke"]) + plugin = StripWxoArgsPlugin(config) + hook_ref = HookRef("tool_pre_invoke", PluginRef(plugin)) + + payload = ToolPreInvokePayload( + name="list_all_secrets", + args={ + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + }, + ) + global_ctx = GlobalContext(request_id="tool-pre-empty-args") + + result, _ = await executor.execute([hook_ref], payload, global_ctx, hook_type="tool_pre_invoke") + + assert seen_arg_types == [CopyOnWriteDict] + assert result.modified_payload is not None + assert result.modified_payload == ToolPreInvokePayload(name="list_all_secrets", args={}) + assert payload.args == { + "wxo_connection_id": "", + "wxo_auth": "fake-token", + "wxo_environment_id": "draft", + } + class TestMultiPluginDictChain: """Tests for multi-plugin chains where an earlier plugin returns a dict payload.""" From 9be38dc42dec44203d70ef66976e7e8144c4949c Mon Sep 17 00:00:00 2001 From: prakhar-singh1928 Date: Thu, 28 May 2026 11:58:50 +0100 Subject: [PATCH 2/4] fix: added length check for performance Signed-off-by: prakhar-singh1928 --- cpex/framework/memory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpex/framework/memory.py b/cpex/framework/memory.py index c19c91e0..0df53fb8 100644 --- a/cpex/framework/memory.py +++ b/cpex/framework/memory.py @@ -195,6 +195,10 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, Mapping): return NotImplemented + # Fast-path: if lengths differ, mappings cannot be equal + if len(self) != len(other): + return False + # Compare materialized items return dict(self.items()) == dict(other.items()) From 320528bc800c59067ea3af067d4a2f1fd4a52c7f Mon Sep 17 00:00:00 2001 From: prakhar-singh1928 Date: Thu, 28 May 2026 13:42:38 +0100 Subject: [PATCH 3/4] fix: restore deleted assertion and add performance optimization - Restored missing 'assert a not in keys' in test_iteration_order_with_deletions - Added fast-path length check in CopyOnWriteDict.__eq__() for better performance - Performance optimization is safe: if lengths differ, mappings cannot be equal Signed-off-by: prakhar-singh1928 --- tests/unit/cpex/framework/test_memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/cpex/framework/test_memory.py b/tests/unit/cpex/framework/test_memory.py index b185685d..7463b40b 100644 --- a/tests/unit/cpex/framework/test_memory.py +++ b/tests/unit/cpex/framework/test_memory.py @@ -817,6 +817,7 @@ def test_iter_skips_deleted_keys_in_modifications(self): keys = list(cow) # Should only have b (from original) and c (from modifications, not deleted) assert set(keys) == {"b", "c"} + assert "a" not in keys def test_equality_with_empty_dict(self): """CopyOnWriteDict with data should not equal empty dict.""" From 307b1e89c4375dfa94a1952ed6f602d64c37fab6 Mon Sep 17 00:00:00 2001 From: Teryl Taylor Date: Fri, 29 May 2026 06:44:52 -0600 Subject: [PATCH 4/4] fix: linted memory.py, added assertion to test. --- cpex/framework/memory.py | 8 +++----- pyproject.toml | 4 ++++ tests/unit/cpex/framework/test_memory.py | 2 ++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cpex/framework/memory.py b/cpex/framework/memory.py index 0df53fb8..ec2dc8d6 100644 --- a/cpex/framework/memory.py +++ b/cpex/framework/memory.py @@ -14,6 +14,7 @@ import copy import logging import weakref +from collections.abc import Mapping from typing import Any, Iterator, Optional, TypeVar # Third-Party @@ -189,16 +190,13 @@ def __eq__(self, other: Any) -> bool: True if other is a Mapping with the same key-value pairs, False otherwise. Returns NotImplemented for non-Mapping types to allow other.__eq__ to handle it. """ - # Import here to avoid circular dependency - from collections.abc import Mapping - if not isinstance(other, Mapping): return NotImplemented - + # Fast-path: if lengths differ, mappings cannot be equal if len(self) != len(other): return False - + # Compare materialized items return dict(self.items()) == dict(other.items()) diff --git a/pyproject.toml b/pyproject.toml index 8c53c4f5..e947d1ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,10 @@ preview = true fixable = ["ALL"] unfixable = [] +[tool.ruff.lint.pylint] +# Relaxed from the default of 5; existing code has wider try clauses (max observed 38). +max-statements-in-try = 50 + # Ignore D1 (docstring checks) and Pylint checks in tests and other non-production code [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = ["D1", "PL"] diff --git a/tests/unit/cpex/framework/test_memory.py b/tests/unit/cpex/framework/test_memory.py index 7463b40b..30ff6221 100644 --- a/tests/unit/cpex/framework/test_memory.py +++ b/tests/unit/cpex/framework/test_memory.py @@ -840,6 +840,8 @@ def test_equality_with_different_dict(self): assert cow != {"a": 1, "b": 3} assert cow != {"a": 1} assert cow != {"a": 1, "b": 2, "c": 3} + # Same length, different keys + assert cow != {"a": 1, "c": 2} def test_equality_after_modifications(self): """Equality should reflect modifications."""