Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions effectful/handlers/llm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ def vacation() -> str:

"""

def __init__(
self, signature: inspect.Signature, name: str, default: Callable[P, T]
):
def __init__(self, default: Callable[P, T], name: str | None = None):
if not default.__doc__:
raise ValueError("Tools must have docstrings.")
signature = IsRecursive.infer_annotations(signature)
super().__init__(signature, name, default)
super().__init__(default, name=name)

@property
def __signature__(self):
return IsRecursive.infer_annotations(super().__signature__)

@classmethod
def define(cls, *args, **kwargs) -> "Tool[P, T]":
Expand Down
28 changes: 19 additions & 9 deletions effectful/ops/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,30 @@ class Operation[**Q, V]:

"""

__signature__: inspect.Signature
__name__: str
__default__: Callable[Q, V]
__apply__: typing.ClassVar["Operation"]

def __init__(
self, signature: inspect.Signature, name: str, default: Callable[Q, V]
):
def __init__(self, default: Callable[Q, V], name: str | None = None):
functools.update_wrapper(self, default)

self.__signature__ = signature
self.__name__ = name
self.__default__ = default
self.__name__ = name or default.__name__

@property
def __signature__(self):
# Resolve forward references (e.g. -> "MyClass") using the
# default function's __globals__. This handles module-level
# forward refs; local forward refs will raise NameError.
# Python 3.14's annotationlib.get_annotations(format=FORWARDREF)
# could resolve local refs too via PEP 649 __annotate__ functions.
annots = typing.get_type_hints(self.__default__, include_extras=True)
sig = inspect.signature(self.__default__)
updated_params = [
p.replace(annotation=annots[p.name]) if p.name in annots else p
for p in sig.parameters.values()
]
updated_ret = annots.get("return", sig.return_annotation)
return sig.replace(parameters=updated_params, return_annotation=updated_ret)

def __eq__(self, other):
if not isinstance(other, Operation):
Expand Down Expand Up @@ -267,8 +278,7 @@ def func(*args, **kwargs):

op = cls.define(func, name=name)
else:
name = name or t.__name__
op = cls(inspect.signature(t), name, t) # type: ignore[arg-type]
op = cls(t, name=name) # type: ignore[arg-type]

return op # type: ignore[return-value]

Expand Down
15 changes: 15 additions & 0 deletions tests/test_handlers_llm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,3 +1518,18 @@ def test_validate_format_spec_on_undefined_var():
def bad(x: int) -> str:
"""Value: {x} and {missing:.2f}."""
raise NotHandled


# Forward ref through Tool subclass of Operation.
# Use types Pydantic can serialize (not arbitrary classes) to avoid
# PydanticSchemaGenerationError when other tests build tool schemas.
@Tool.define
def _tool_forward_ref(x: "int") -> "str":
"""A tool with forward-referenced parameter and return types."""
raise NotHandled


def test_tool_forward_ref():
sig = inspect.signature(_tool_forward_ref)
assert sig.parameters["x"].annotation is int
assert sig.return_annotation is str
137 changes: 137 additions & 0 deletions tests/test_ops_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,3 +1126,140 @@ def id[T](base: T) -> T:
raise NotHandled

assert isinstance(id(A(0)).x, Term)


# Forward references in types only work on module-level definitions.
@defop
def forward_ref_op() -> "A":
raise NotHandled


class A: ...


def test_defop_forward_ref():
term = forward_ref_op()
assert term.op == forward_ref_op
assert typeof(term) is A

@defop
def local_forward_ref_op() -> "B":
raise NotHandled

class B: ...

with pytest.raises(NameError):
local_forward_ref_op()


# Forward ref in a parameter annotation.
@defop
def _forward_ref_param_op(x: "_ForwardRefParam") -> int:
raise NotHandled


class _ForwardRefParam:
pass


def test_defop_forward_ref_param():
sig = inspect.signature(_forward_ref_param_op)
assert sig.parameters["x"].annotation is _ForwardRefParam
assert sig.return_annotation is int


# Forward ref through Operation.define on a type.
class _ForwardRefType:
pass


_forward_ref_type_op = Operation.define(_ForwardRefType)


def test_define_type_forward_ref():
term = _forward_ref_type_op()
assert term.op == _forward_ref_type_op
assert typeof(term) is _ForwardRefType


# Forward ref on an instance method.
class _ForwardRefMethodHost:
@defop
def my_method(self, x: int) -> "_ForwardRefMethodResult":
raise NotHandled


class _ForwardRefMethodResult:
pass


def test_defop_forward_ref_method():
instance = _ForwardRefMethodHost()
term = instance.my_method(5)
assert isinstance(term, Term)
sig = inspect.signature(_ForwardRefMethodHost.my_method)
assert sig.return_annotation is _ForwardRefMethodResult


# Forward ref on a staticmethod.
class _ForwardRefStaticHost:
@defop
@staticmethod
def my_static(x: int) -> "_ForwardRefStaticResult":
raise NotHandled


class _ForwardRefStaticResult:
pass


def test_defop_forward_ref_staticmethod():
term = _ForwardRefStaticHost.my_static(5)
assert isinstance(term, Term)
sig = inspect.signature(_ForwardRefStaticHost.my_static)
assert sig.return_annotation is _ForwardRefStaticResult


# Forward ref on a classmethod.
class _ForwardRefClassmethodHost:
@defop
@classmethod
def my_classmethod(cls, x: int) -> "_ForwardRefClassmethodResult":
raise NotHandled


class _ForwardRefClassmethodResult:
pass


def test_defop_forward_ref_classmethod():
term = _ForwardRefClassmethodHost.my_classmethod(5)
assert isinstance(term, Term)
sig = inspect.signature(_ForwardRefClassmethodHost.my_classmethod)
assert sig.return_annotation is _ForwardRefClassmethodResult


# Mutual recursion: two classes with forward refs to each other.
class _Coordinate:
@defop
def log(self) -> "_CoordinateTangent":
raise NotHandled


class _CoordinateTangent:
@defop
def exp(self) -> "_Coordinate":
raise NotHandled


def test_defop_forward_ref_mutual_recursion():
coord = _Coordinate()
tangent = _CoordinateTangent()

log_term = coord.log()
assert isinstance(log_term, Term)
assert typeof(log_term) is _CoordinateTangent

exp_term = tangent.exp()
assert isinstance(exp_term, Term)
assert typeof(exp_term) is _Coordinate
41 changes: 40 additions & 1 deletion tests/test_ops_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect

from effectful.ops.syntax import defop
from effectful.ops.types import Interpretation
from effectful.ops.types import Interpretation, NotHandled


def test_interpretation_isinstance():
Expand All @@ -10,3 +12,40 @@ def test_interpretation_isinstance():
assert not isinstance({a: 0, b: "hello"}, Interpretation)
assert not isinstance([a, b], Interpretation)
assert not isinstance({"a": lambda: 0, "b": lambda: "hello"}, Interpretation)


def test_instance_method_signature_excludes_self():
"""Instance-bound operations should not have 'self' in their signature.

When an Operation is used as a method and accessed on an instance,
__get__ creates a new Operation from a bound method. The signature
should reflect the bound method (without 'self'), not the original
unbound function.

This failed with cached_property because functools.update_wrapper
copied a stale __signature__ (with 'self') into __dict__, shadowing
the descriptor.
"""

class MyClass:
@defop
def my_method(self, x: int) -> str:
raise NotHandled

# Access the class-level signature first, which with cached_property
# stores (self, x: int) -> str in MyClass.my_method.__dict__['__signature__'].
# This is the key trigger: __get__ later copies __dict__ via functools.wraps
# to the instance operation, shadowing a cached_property but not a property.
cls_sig = MyClass.my_method.__signature__
assert "self" in cls_sig.parameters # class-level should have self

instance = MyClass()
instance_op = instance.my_method

# The instance operation should have a signature without 'self'
sig = inspect.signature(instance_op)
assert "self" not in sig.parameters
assert "x" in sig.parameters

# Binding should work with just the real args (no 'self')
sig.bind(42)
Loading