Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ incremental in minor, bugfixes only are patches.
See [0Ver](https://0ver.org/).


## Unreleased

### Bugfixes

- Fixes the `curry.partial` compatibility with mypy 1.6.1+


## 0.26.0

### Features
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ lint.per-file-ignores."tests/test_examples/test_result/test_result_pattern_match
"D103",
]
lint.per-file-ignores."tests/test_pattern_matching.py" = [ "S101" ]
lint.per-file-ignores."typesafety/test_curry/test_partial/test_partial.py" = [ "S101" ]
lint.external = [ "WPS" ]
lint.flake8-quotes.inline-quotes = "single"
lint.mccabe.max-complexity = 6
Expand Down
17 changes: 16 additions & 1 deletion returns/contrib/hypothesis/laws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import inspect
import sys
from collections.abc import Callable, Iterator
from contextlib import ExitStack, contextmanager
from typing import Any, TypeVar, final, overload
Expand Down Expand Up @@ -242,7 +243,21 @@ def _create_law_test_case(
)

called_from = inspect.stack()[2]
module = inspect.getmodule(called_from[0])
# `inspect.getmodule(frame)` is surprisingly fragile under some import
# modes (notably `pytest` collection with assertion rewriting) and can
# return `None`. Use the module name from the caller's globals instead.
module_name = called_from.frame.f_globals.get('__name__')
if module_name is None:
module = None
else:
module = sys.modules.get(module_name)
if module is None:
module = inspect.getmodule(called_from.frame)
if module is None:
raise RuntimeError(
'Cannot determine a module to attach generated law tests to. '
'Please call `check_all_laws` from an imported module scope.',
)

template = 'test_{container}_{interface}_{name}'
test_function.__name__ = template.format( # noqa: WPS125
Expand Down
83 changes: 58 additions & 25 deletions returns/contrib/mypy/_features/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.plugin import FunctionContext
from mypy.types import (
AnyType,
CallableType,
FunctionLike,
Instance,
Overloaded,
ProperType,
TypeOfAny,
TypeType,
get_proper_type,
)
Expand Down Expand Up @@ -51,30 +53,55 @@ def analyze(ctx: FunctionContext) -> ProperType:
default_return = get_proper_type(ctx.default_return_type)
if not isinstance(default_return, CallableType):
return default_return
return _analyze_partial(ctx, default_return)


def _analyze_partial(
ctx: FunctionContext,
default_return: CallableType,
) -> ProperType:
if not ctx.arg_types or not ctx.arg_types[0]:
# No function passed: treat as decorator factory and fallback to Any.
return AnyType(TypeOfAny.implementation_artifact)

function_def = get_proper_type(ctx.arg_types[0][0])
func_args = _AppliedArgs(ctx)

if len(list(filter(len, ctx.arg_types))) == 1:
return function_def # this means, that `partial(func)` is called
if not isinstance(function_def, _SUPPORTED_TYPES):
is_valid, applied_args = func_args.build_from_context()
if not is_valid:
return default_return
if isinstance(function_def, Instance | TypeType):
# We force `Instance` and similar types to coercse to callable:
function_def = func_args.get_callable_from_context()
if not applied_args:
return function_def # this means, that `partial(func)` is called

is_valid, applied_args = func_args.build_from_context()
if not isinstance(function_def, CallableType | Overloaded) or not is_valid:
callable_def = _coerce_to_callable(function_def, func_args)
if callable_def is None:
return default_return

return _PartialFunctionReducer(
default_return,
function_def,
callable_def,
applied_args,
ctx,
).new_partial()


def _coerce_to_callable(
function_def: ProperType,
func_args: '_AppliedArgs',
) -> CallableType | Overloaded | None:
if not isinstance(function_def, _SUPPORTED_TYPES):
return None
if isinstance(function_def, Instance | TypeType):
# We force `Instance` and similar types to coerce to callable:
from_context = func_args.get_callable_from_context()
return (
from_context
if isinstance(from_context, CallableType | Overloaded)
else None
)
return function_def


@final
class _PartialFunctionReducer:
"""
Expand Down Expand Up @@ -219,16 +246,10 @@ def __init__(self, function_ctx: FunctionContext) -> None:
"""
We need the function default context.

The first arguments of ``partial`` is skipped:
The first argument of ``partial`` is skipped:
it is the applied function itself.
"""
self._function_ctx = function_ctx
self._parts = zip(
self._function_ctx.arg_names[1:],
self._function_ctx.arg_types[1:],
self._function_ctx.arg_kinds[1:],
strict=False,
)

def get_callable_from_context(self) -> ProperType:
"""Returns callable type from the context."""
Expand All @@ -254,17 +275,29 @@ def build_from_context(self) -> tuple[bool, list[FuncArg]]:
Here ``*args`` and ``**kwargs`` can be literally anything!
In these cases we fallback to the default return type.
"""
applied_args = []
for names, types, kinds in self._parts:
applied_args: list[FuncArg] = []
for arg in self._iter_applied_args():
if arg.kind in {ARG_STAR, ARG_STAR2}:
# We cannot really work with `*args`, `**kwargs`.
return False, []
applied_args.append(arg)
return True, applied_args

def _iter_applied_args(self) -> Iterator[FuncArg]:
skipped_applied_function = False
for names, types, kinds in zip(
self._function_ctx.arg_names,
self._function_ctx.arg_types,
self._function_ctx.arg_kinds,
strict=False,
):
for arg in self._generate_applied_args(
zip(names, types, kinds, strict=False)
zip(names, types, kinds, strict=False),
):
if arg.kind in {ARG_STAR, ARG_STAR2}:
# We cannot really work with `*args`, `**kwargs`.
return False, []

applied_args.append(arg)
return True, applied_args
if not skipped_applied_function:
skipped_applied_function = True
continue
yield arg

def _generate_applied_args(self, arg_parts) -> Iterator[FuncArg]:
yield from starmap(FuncArg, arg_parts)
25 changes: 13 additions & 12 deletions returns/contrib/mypy/_typeops/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,26 @@ def _infer_constraints(
"""Creates mapping of ``typevar`` to real type that we already know."""
checker = self._ctx.api.expr_checker # type: ignore
kinds = [arg.kind for arg in applied_args]
exprs = [arg.expression(self._ctx.context) for arg in applied_args]

formal_to_actual = map_actuals_to_formals(
kinds,
[arg.name for arg in applied_args],
self._fallback.arg_kinds,
self._fallback.arg_names,
lambda index: checker.accept(exprs[index]),
)
constraints = infer_constraints_for_callable(
self._fallback,
arg_types=[arg.type for arg in applied_args],
arg_kinds=kinds,
arg_names=[arg.name for arg in applied_args],
formal_to_actual=formal_to_actual,
context=checker.argument_infer_context(),
lambda index: checker.accept(
applied_args[index].expression(self._ctx.context),
),
)

return {
constraint.type_var: constraint.target for constraint in constraints
constraint.type_var: constraint.target
for constraint in infer_constraints_for_callable(
self._fallback,
arg_types=[arg.type for arg in applied_args],
arg_kinds=kinds,
arg_names=[arg.name for arg in applied_args],
formal_to_actual=formal_to_actual,
context=checker.argument_infer_context(),
)
}


Expand Down
35 changes: 32 additions & 3 deletions returns/curry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,41 @@
from functools import partial as _partial
from functools import wraps
from inspect import BoundArguments, Signature
from typing import Any, TypeAlias, TypeVar
from typing import Any, Generic, TypeAlias, TypeVar, overload

_ReturnType = TypeVar('_ReturnType')
_Decorator: TypeAlias = Callable[
[Callable[..., _ReturnType]],
Callable[..., _ReturnType],
]


class _PartialDecorator(Generic[_ReturnType]):
"""Wraps ``functools.partial`` into a decorator without nesting."""
__slots__ = ('_args', '_kwargs')

def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
self._args = args
self._kwargs = kwargs

def __call__(self, inner: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
return _partial(inner, *self._args, **self._kwargs)


@overload
def partial(
func: Callable[..., _ReturnType],
/,
*args: Any,
**kwargs: Any,
) -> Callable[..., _ReturnType]:
) -> Callable[..., _ReturnType]: ...


@overload
def partial(*args: Any, **kwargs: Any) -> _Decorator: ...


def partial(*args: Any, **kwargs: Any) -> Any:
"""
Typed partial application.

Expand All @@ -35,7 +60,11 @@ def partial(
- https://docs.python.org/3/library/functools.html#functools.partial

"""
return _partial(func, *args, **kwargs)
if args and callable(args[0]):
return _partial(args[0], *args[1:], **kwargs)
if args and args[0] is None:
args = args[1:]
return _PartialDecorator(args, kwargs)


def curry(function: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ select = WPS, E999

extend-exclude =
.venv
.cache
build
# Bad code that I write to test things:
ex.py
Expand Down
30 changes: 30 additions & 0 deletions tests/test_curry/test_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable, TypeAlias, TypeVar, cast

from returns.curry import partial

_ReturnType = TypeVar('_ReturnType')
_Decorator: TypeAlias = Callable[
[Callable[..., _ReturnType]],
Callable[..., _ReturnType],
]


def add(first: int, second: int) -> int:
return first + second


def test_partial_direct_call() -> None:
add_one = partial(add, 1)
assert add_one(2) == 3


def test_partial_as_decorator_factory() -> None:
decorator = cast(_Decorator[int], partial())
add_with_decorator = decorator(add)
assert add_with_decorator(1, 2) == 3


def test_partial_with_none_placeholder() -> None:
decorator = cast(_Decorator[int], partial(None, 1))
add_with_none_decorator = decorator(add)
assert add_with_none_decorator(2) == 3
3 changes: 3 additions & 0 deletions typesafety/test_curry/test_partial/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
python_version = 3.11
plugins = returns.contrib.mypy.returns_plugin
Loading
Loading