diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index edd5b07d..11f209dd 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -1,8 +1,10 @@ import ast import base64 +import collections import functools import inspect import io +import sys import textwrap import types import typing @@ -630,6 +632,123 @@ def deserialize(self, serialized_value: str) -> SynthesizedFunction: return SynthesizedFunction.model_validate_json(serialized_value) +class SynthesizedType(pydantic.BaseModel): + """Structured output for type/class synthesis. + + Pydantic model representing synthesized class code with type name and module code. + """ + + type_name: str = pydantic.Field( + ..., + description="The name of the class that satisfies the specification", + ) + module_code: str = pydantic.Field( + ..., + description="Complete Python module code with the class definition (no imports needed)", + ) + + +@dataclass +class TypeEncodable(Encodable[type, SynthesizedType]): + base: type[type] + enc: type[SynthesizedType] + ctx: Mapping[str, Any] + + _decode_counter: typing.ClassVar[int] = 0 + + def encode(self, value: type) -> SynthesizedType: + type_name = value.__name__ + try: + source = inspect.getsource(value) + except (OSError, TypeError): + source = f"class {type_name}: pass # Source unavailable" + + return SynthesizedType( + type_name=type_name, module_code=textwrap.dedent(source).strip() + ) + + def decode(self, encoded_value: SynthesizedType) -> type: + """Decode a SynthesizedType to a type. + + Executes the module code and returns the named class. + """ + type_name = encoded_value.type_name + module_code = textwrap.dedent(encoded_value.module_code).strip() + "\n" + + TypeEncodable._decode_counter += 1 + module_name = ( + f"_llm_effectful_synthesized_types.{type_name}" + f".{TypeEncodable._decode_counter}" + ) + filename = f"" + + # Create a real module and put it to sys.modules + mod = types.ModuleType(module_name) + mod.__file__ = filename + sys.modules[module_name] = mod + + # globals = module.__dict__ + context + g = mod.__dict__ + g.update({"collections": collections}) + if self.ctx: + g.update(self.ctx) + g.update({"__name__": module_name, "__file__": filename}) + g.setdefault("__package__", module_name.rpartition(".")[0]) + + try: + # Parse via evaluation effect (also registers source in linecache) + tree = evaluation.parse(module_code, filename) + + # Type-check the synthesized module + evaluation.type_check(tree, self.ctx, None, type) + + # Compile and execute via evaluation effects + code_obj = evaluation.compile(tree, filename) + evaluation.exec(code_obj, g) + except SyntaxError as exc: + raise ValueError(f"Syntax error in generated code: {exc}") from exc + + if type_name not in g: + raise ValueError( + f"Type '{type_name}' not found after execution. " + f"Available names: {[k for k in g.keys() if not k.startswith('_')]}" + ) + + synthesized_type = g[type_name] + + if not isinstance(synthesized_type, type): + raise ValueError( + f"'{type_name}' is not a type, got {type(synthesized_type).__name__}" + ) + + # Attach source code and module name + synthesized_type.__source__ = module_code # type: ignore[attr-defined] + synthesized_type.__synthesized__ = encoded_value # type: ignore[attr-defined] + synthesized_type.__module__ = module_name + + # Set __firstlineno__ for Python 3.13+ (inspect.getsource requires it). + # Must be set AFTER __module__ since __module__ assignment can clear it. + firstlineno = next( + ( + n.lineno + for n in ast.walk(ast.parse(module_code)) + if isinstance(n, ast.ClassDef) and n.name == type_name + ), + 1, + ) + synthesized_type.__firstlineno__ = firstlineno # type: ignore[attr-defined] + + return synthesized_type + + def serialize( + self, encoded_value: SynthesizedType + ) -> Sequence[OpenAIMessageContentListBlock]: + return [{"type": "text", "text": encoded_value.model_dump_json()}] + + def deserialize(self, serialized_value: str) -> SynthesizedType: + return SynthesizedType.model_validate_json(serialized_value) + + def _param_model(sig: inspect.Signature) -> type[pydantic.BaseModel]: return pydantic.create_model( "Params", @@ -1048,6 +1167,14 @@ def _encodable_callable( return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return) +@Encodable.define.register(type) +def _encodable_type( + ty: type, ctx: Mapping[str, Any] | None +) -> Encodable[type, SynthesizedType]: + ctx = ctx or {} + return TypeEncodable(ty, SynthesizedType, ctx) + + @Encodable.define.register(Tool) def _encodable_tool[**P, T]( ty: type[Tool[P, T]], ctx: Mapping[str, Any] | None diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py index 07348cc9..e7042486 100644 --- a/effectful/handlers/llm/evaluation.py +++ b/effectful/handlers/llm/evaluation.py @@ -564,12 +564,11 @@ def mypy_type_check( if not module.body: raise TypeError("mypy_type_check: module.body is empty") last = module.body[-1] - if not isinstance(last, ast.FunctionDef): + if not isinstance(last, ast.FunctionDef | ast.ClassDef): raise TypeError( - f"mypy_type_check: last statement must be a function definition, " + f"mypy_type_check: last statement must be a function or class definition, " f"got {type(last).__name__}" ) - func_name = last.name imports = collect_imports(ctx) # Ensure annotations in the postlude can be resolved (e.g. collections.abc.Callable, typing) @@ -614,33 +613,40 @@ def mypy_type_check( stub_module_body = ast.Module(body=module_body, type_ignores=[]) _RenameTransformer(rename_map).visit(stub_module_body) module_body = stub_module_body.body - tc_func_name = rename_map.get(func_name, func_name) else: module_body = list(module.body) - tc_func_name = func_name - - param_types = expected_params - expected_callable_type: type = typing.cast( - type, - collections.abc.Callable[param_types, expected_return] - if expected_params is not None - else collections.abc.Callable[..., expected_return], - ) - expected_callable_ast = type_to_ast(expected_callable_type) - postlude = ast.AnnAssign( - target=ast.Name(id="_synthesized_check", ctx=ast.Store()), - annotation=expected_callable_ast, - value=ast.Name(id=tc_func_name, ctx=ast.Load()), - simple=1, - ) + postlude: list[ast.stmt] = [] + if isinstance(last, ast.FunctionDef): + func_name = last.name + tc_func_name = ( + rename_map.get(func_name, func_name) if colliding_names else func_name + ) + param_types = expected_params + expected_callable_type: type = typing.cast( + type, + collections.abc.Callable[param_types, expected_return] + if expected_params is not None + else collections.abc.Callable[..., expected_return], + ) + expected_callable_ast = type_to_ast(expected_callable_type) + postlude = [ + ast.AnnAssign( + target=ast.Name(id="_synthesized_check", ctx=ast.Store()), + annotation=expected_callable_ast, + value=ast.Name(id=tc_func_name, ctx=ast.Load()), + simple=1, + ) + ] + # For ClassDef: no postlude needed, mypy checks the class body directly. + full_body = ( baseline_imports + list(imports) + list(stubs) + list(variables) + module_body - + [postlude] + + postlude ) stub_module = ast.Module(body=full_body, type_ignores=[]) source = ast.unparse(ast.fix_missing_locations(stub_module)) diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py index 9cccf93d..cda66508 100644 --- a/tests/test_handlers_llm_encoding.py +++ b/tests/test_handlers_llm_encoding.py @@ -22,6 +22,7 @@ DecodedToolCall, Encodable, SynthesizedFunction, + SynthesizedType, ) from effectful.handlers.llm.evaluation import RestrictedEvalProvider, UnsafeEvalProvider from effectful.handlers.llm.template import Tool @@ -699,6 +700,191 @@ def __call__(self): enc.encode(_NoDocCallable()) +# ============================================================================ +# Type: roundtrip, type_check pass/fail, serialize/deserialize +# ============================================================================ + + +class SimplePoint: + def make(self, x: int, y: int) -> "SimplePoint": + self.x = x + self.y = y + return self + + +class Greeter: + def hello(self) -> str: + return "world" + + +# --- pass cases: type_check should succeed --- + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_encode_decode_simple_class(eval_provider): + """Roundtrip encode/decode of a simple class.""" + enc = Encodable.define(type) + with handler(eval_provider): + decoded = enc.decode(enc.encode(SimplePoint)) + assert isinstance(decoded, type) + assert decoded.__name__ == "SimplePoint" + obj = decoded().make(1, 2) + assert obj.x == 1 + assert obj.y == 2 + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_valid_class_code(eval_provider): + """Decode hand-crafted SynthesizedType with valid class code.""" + code = SynthesizedType( + type_name="Adder", + module_code="class Adder:\n def add(self, a: int, b: int) -> int:\n return a + b\n", + ) + enc = Encodable.define(type) + with handler(eval_provider): + decoded = enc.decode(code) + assert decoded.__name__ == "Adder" + assert decoded().add(3, 4) == 7 + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_class_with_context(eval_provider): + """Decode a class that references a type from lexical context.""" + code = SynthesizedType( + type_name="ChildGreeter", + module_code=( + "class ChildGreeter(BaseGreeter):\n" + " def greet(self) -> str:\n" + " return 'child'\n" + ), + ) + + class BaseGreeter: + def greet(self) -> str: + return "base" + + ctx = {"BaseGreeter": BaseGreeter} + enc = Encodable.define(type, ctx) + with handler(eval_provider): + decoded = enc.decode(code) + obj = decoded() + assert obj.greet() == "child" + assert isinstance(obj, BaseGreeter) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_full_pipeline(eval_provider): + """Full encode->serialize->deserialize->decode pipeline.""" + enc = Encodable.define(type) + encoded = enc.encode(Greeter) + serialized = enc.serialize(encoded) + deserialized = enc.deserialize(serialized[0]["text"]) + with handler(eval_provider): + decoded = enc.decode(deserialized) + assert isinstance(decoded, type) + assert decoded().hello() == "world" + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_inspect_getsource_works(eval_provider): + """inspect.getsource() works on decoded synthesized types.""" + code = SynthesizedType( + type_name="Greeter", + module_code="class Greeter:\n def hello(self) -> str:\n return 'world'\n", + ) + enc = Encodable.define(type) + with handler(eval_provider): + decoded = enc.decode(code) + source = inspect.getsource(decoded) + assert "class Greeter" in source + assert "hello" in source + + +# --- fail cases: type_check should reject / decode should raise --- + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_syntax_error(eval_provider): + """Syntax error in module_code raises ValueError.""" + code = SynthesizedType( + type_name="Bad", + module_code="class Bad:\n def __init__(self) # missing colon\n pass\n", + ) + enc = Encodable.define(type) + with pytest.raises(ValueError): + with handler(eval_provider): + enc.decode(code) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_missing_type_name(eval_provider): + """Code executes but doesn't define the expected type name.""" + code = SynthesizedType( + type_name="Expected", + module_code="class Actual:\n pass\n", + ) + enc = Encodable.define(type) + with pytest.raises(ValueError, match="Expected"): + with handler(eval_provider): + enc.decode(code) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_not_a_type(eval_provider): + """Code that doesn't define a class is rejected by type_check.""" + code = SynthesizedType( + type_name="MyType", + module_code="MyType = 42\n", + ) + enc = Encodable.define(type) + with pytest.raises(TypeError): + with handler(eval_provider): + enc.decode(code) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_undefined_base_class(eval_provider): + """Code references an undefined base class not in context.""" + code = SynthesizedType( + type_name="Child", + module_code="class Child(UndefinedBase):\n pass\n", + ) + enc = Encodable.define(type, {}) + with pytest.raises((ValueError, TypeError)): + with handler(eval_provider): + enc.decode(code) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_runtime_error_in_class_body(eval_provider): + """Class body raises an error during execution.""" + code = SynthesizedType( + type_name="Broken", + module_code="class Broken:\n x = 1 / 0\n", + ) + enc = Encodable.define(type) + with pytest.raises((ValueError, ZeroDivisionError)): + with handler(eval_provider): + enc.decode(code) + + +@pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) +def test_type_decode_type_check_catches_bad_method_types(eval_provider): + """type_check rejects a class with mistyped method (returns str, annotation says int).""" + code = SynthesizedType( + type_name="BadTypes", + module_code=( + "class BadTypes:\n" + " def compute(self) -> int:\n" + ' return "not an int"\n' + ), + ) + enc = Encodable.define(type) + with pytest.raises(TypeError): + with handler(eval_provider): + enc.decode(code) + + # --------------------------------------------------------------------------- # Provider integration tests # ---------------------------------------------------------------------------