# Source code for dags.signature

from __future__ import annotations

import functools
import inspect
from typing import TYPE_CHECKING, Any, cast, overload

from dags.annotations import get_annotations
from dags.exceptions import DagsError, InvalidFunctionArgumentsError

if TYPE_CHECKING:
    from collections.abc import Callable

    from dags.typing import P, R


def _create_signature(
    args_types: dict[str, str] | dict[str, type[inspect._empty]],
    kwargs_types: dict[str, str] | dict[str, type[inspect._empty]],
    return_annotation: Any = inspect.Parameter.empty,
) -> inspect.Signature:
    """Build an `inspect.Signature` from positional and keyword-only argument maps.

    Args:
        args_types: Names of the positional-or-keyword parameters, each mapped to its
            annotation as a string, or to `inspect.Parameter.empty` when the
            parameter is unannotated.
        kwargs_types: Names of the keyword-only parameters, mapped to annotations in
            the same way as `args_types`.
        return_annotation: Annotation for the return value. Defaults to
            `inspect.Parameter.empty`, i.e. no return annotation.

    Returns
    -------
        A signature listing every entry of `args_types` (in insertion order)
        followed by every entry of `kwargs_types` (in insertion order). No
        parameter carries a default value.

    """
    param_cls = inspect.Parameter
    positional = [
        param_cls(name, param_cls.POSITIONAL_OR_KEYWORD, annotation=annotation)
        for name, annotation in args_types.items()
    ]
    keyword_only = [
        param_cls(name, param_cls.KEYWORD_ONLY, annotation=annotation)
        for name, annotation in kwargs_types.items()
    ]
    return inspect.Signature(
        parameters=positional + keyword_only,
        return_annotation=return_annotation,
    )


def _create_annotations(
    args_types: dict[str, str] | dict[str, type[inspect._empty]],
    kwargs_types: dict[str, str] | dict[str, type[inspect._empty]],
    return_annotation: Any,
) -> (
    dict[str, str]
    | dict[str, str | type[inspect._empty]]
    | dict[str, type[inspect._empty]]
):
    """Assemble an ``__annotations__``-style dict for a generated signature.

    Args:
        args_types: Positional-or-keyword parameter names mapped to their annotation
            (a string) or to `inspect.Parameter.empty`.
        kwargs_types: Keyword-only parameter names mapped the same way.
        return_annotation: Stored under the ``"return"`` key unless it is
            `inspect.Parameter.empty`.

    Returns
    -------
        A new dict with the entries of `args_types` followed by those of
        `kwargs_types`, plus ``"return"`` when a return annotation is given. The
        input dicts are left untouched because ``|`` builds a fresh dict.

    """
    combined = args_types | kwargs_types
    if return_annotation is inspect.Parameter.empty:
        return combined
    combined["return"] = return_annotation
    return combined


@overload
def with_signature(
    func: Callable[P, R],
    *,
    args: dict[str, str] | list[str] | None = None,
    kwargs: dict[str, str] | list[str] | None = None,
    enforce: bool = True,
    return_annotation: Any = inspect.Parameter.empty,
) -> Callable[P, R]: ...


@overload
def with_signature(
    *,
    args: dict[str, str] | list[str] | None = None,
    kwargs: dict[str, str] | list[str] | None = None,
    enforce: bool = True,
    return_annotation: Any = inspect.Parameter.empty,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


def with_signature(
    func: Callable[P, R] | None = None,
    *,
    args: dict[str, str] | list[str] | None = None,
    kwargs: dict[str, str] | list[str] | None = None,
    enforce: bool = True,
    return_annotation: Any = inspect.Parameter.empty,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """Add a signature to a function of type `f(*args, **kwargs)` (decorator).

    Can be called directly (``with_signature(func, args=...)``) or used as a
    decorator factory (``@with_signature(args=...)``).

    Caveats: The created signature contains the argument names, whether they are
    keyword-only, and (when `args` / `kwargs` are dicts) their annotations as
    strings. There is no way of setting default values, so every argument in the
    signature is required. The enforcement does not check for missing arguments.

    Args:
        func: The function to be decorated. Should take `*args`
            and `**kwargs` as only arguments.
        args: If a list, the names of positional or keyword arguments. If a dict,
            the names of positional or keyword arguments and their types as strings.
        kwargs: If a list, the names of keyword only arguments. If a dict,
            the names of keyword only arguments and their types as strings.
        enforce: Whether the signature should be enforced or just
            added to the function for introspection. This creates runtime
            overhead.
        return_annotation: The return annotation. By default, the return annotation is
            `inspect.Parameter.empty`.

    Returns
    -------
        function: The function with signature if `func` is given, otherwise a
        decorator that adds the signature.

    Raises
    ------
        InvalidFunctionArgumentsError: Raised by the returned function (only when
        `enforce` is True) if it receives too many positional arguments, the same
        argument positionally and by keyword, or an unknown keyword argument.

    """

    def decorator_with_signature(func: Callable[P, R]) -> Callable[P, R]:
        _args = _map_names_to_types(args)
        _kwargs = _map_names_to_types(kwargs)
        signature = _create_signature(_args, _kwargs, return_annotation)
        annotations = _create_annotations(_args, _kwargs, return_annotation)
        # Freeze the positional argument names once at decoration time. Rebuilding
        # the list on every call is wasted work, and freezing keeps the enforced
        # names in sync with `signature` and `valid_kwargs`, even if the caller
        # later mutates the dict passed as `args`.
        arg_names: list[str] = list(_args)
        valid_kwargs: set[str] = set(_kwargs) | set(arg_names)
        funcname: str = getattr(func, "__name__", "function")

        @functools.wraps(func)
        def wrapper_with_signature(*args: P.args, **kwargs: P.kwargs) -> R:
            if enforce:
                _fail_if_too_many_positional_arguments(args, arg_names, funcname)
                # Positional values bind to the leading names of `args`, in order.
                present_args: set[str] = set(arg_names[: len(args)])
                present_kwargs: set[str] = set(kwargs)
                _fail_if_duplicated_arguments(present_args, present_kwargs, funcname)
                _fail_if_invalid_keyword_arguments(
                    present_kwargs, valid_kwargs, funcname
                )
            return func(*args, **kwargs)

        wrapper_with_signature.__signature__ = signature  # ty: ignore[unresolved-attribute]
        wrapper_with_signature.__annotations__ = annotations
        return wrapper_with_signature

    if func is not None:
        return decorator_with_signature(func)
    return decorator_with_signature


def _fail_if_too_many_positional_arguments(
    present_args: tuple[Any, ...], argnames: list[str], funcname: str
) -> None:
    """Raise if more positional values were passed than there are positional names.

    Args:
        present_args: The positional values received by the call.
        argnames: Names of the parameters that may be passed positionally.
        funcname: Function name used in the error message.

    Raises
    ------
        InvalidFunctionArgumentsError: If ``len(present_args) > len(argnames)``.

    """
    n_expected = len(argnames)
    n_given = len(present_args)
    if n_given > n_expected:
        # Mirror CPython's wording, including singular/plural agreement, e.g.
        # "f() takes 1 positional argument but 2 were given" and
        # "f() takes 0 positional arguments but 1 was given".
        s = "" if n_expected == 1 else "s"
        verb = "was" if n_given == 1 else "were"
        msg = (
            f"{funcname}() takes {n_expected} positional argument{s} "
            f"but {n_given} {verb} given"
        )
        raise InvalidFunctionArgumentsError(msg)


def _fail_if_duplicated_arguments(
    present_args: set[str], present_kwargs: set[str], funcname: str
) -> None:
    """Raise if an argument was supplied both positionally and by keyword.

    Args:
        present_args: Names of the arguments bound by the positional values.
        present_kwargs: Names of the keyword arguments received by the call.
        funcname: Function name used in the error message.

    Raises
    ------
        InvalidFunctionArgumentsError: If the two name sets overlap.

    """
    problematic = present_args & present_kwargs
    if problematic:
        s = "s" if len(problematic) >= 2 else ""
        # Sort so the message is deterministic: iteration order of a set of
        # strings can differ between interpreter runs (hash randomization).
        problem_str = ", ".join(sorted(problematic))
        msg = f"{funcname}() got multiple values for argument{s} {problem_str}"
        raise InvalidFunctionArgumentsError(msg)


def _fail_if_invalid_keyword_arguments(
    present_kwargs: set[str], valid_kwargs: set[str], funcname: str
) -> None:
    """Raise if the call used keyword names that are not in the signature.

    Args:
        present_kwargs: Names of the keyword arguments received by the call.
        valid_kwargs: Names that may be passed by keyword.
        funcname: Function name used in the error message.

    Raises
    ------
        InvalidFunctionArgumentsError: If `present_kwargs` contains a name that is
        not in `valid_kwargs`.

    """
    problematic = present_kwargs - valid_kwargs
    if problematic:
        s = "s" if len(problematic) >= 2 else ""
        # Sort so the message is deterministic: iteration order of a set of
        # strings can differ between interpreter runs (hash randomization).
        problem_str = ", ".join(sorted(problematic))
        msg = f"{funcname}() got unexpected keyword argument{s} {problem_str}"
        raise InvalidFunctionArgumentsError(msg)


@overload
def rename_arguments(
    func: Callable[P, R],
    *,
    mapper: dict[str, str],
) -> Callable[..., R]: ...


@overload
def rename_arguments(
    *, mapper: dict[str, str]
) -> Callable[[Callable[P, R]], Callable[..., R]]: ...


def rename_arguments(  # noqa: C901
    func: Callable[P, R] | None = None, *, mapper: dict[str, str] | None = None
) -> Callable[..., R] | Callable[[Callable[P, R]], Callable[..., R]]:
    """Rename positional and keyword arguments of func.

    Can be called directly (``rename_arguments(func, mapper=...)``) or used as a
    decorator factory (``@rename_arguments(mapper=...)``).

    Args:
        func (callable): The function of which the arguments are renamed.
        mapper (dict): Dict of strings where keys are old names and values are new
            names of arguments.

    Returns
    -------
        function: The function with renamed arguments. If `func` is a
        ``functools.partial``, the result is a ``functools.partial`` as well.

    """

    def decorator_rename_arguments(func: Callable[P, R]) -> Callable[..., R]:
        old_signature = inspect.signature(func)
        old_parameters: dict[str, inspect.Parameter] = dict(old_signature.parameters)
        old_annotations = get_annotations(func)

        parameters: list[inspect.Parameter] = []
        annotations: dict[str, str] = {}

        # mapper is assumed not to be None when renaming is desired.
        for name, param in old_parameters.items():
            if mapper is not None and name in mapper:
                parameters.append(param.replace(name=mapper[name]))
            else:
                parameters.append(param)

        # annotations do not contain information on partialled arguments, and
        # therefore do not exactly align with the parameters.
        for name, annotation in old_annotations.items():
            if mapper is not None and name in mapper:
                annotations[mapper[name]] = annotation
            else:
                annotations[name] = annotation

        signature = inspect.Signature(
            parameters=parameters, return_annotation=old_signature.return_annotation
        )

        # Maps new names back to the names `func` actually accepts.
        reverse_mapper: dict[str, str] = (
            {v: k for k, v in mapper.items()} if mapper is not None else {}
        )

        @functools.wraps(func)
        def wrapper_rename_arguments(*args: P.args, **kwargs: P.kwargs) -> R:
            internal_kwargs: dict[str, Any] = {}
            for name, value in kwargs.items():
                if name in reverse_mapper:
                    internal_kwargs[reverse_mapper[name]] = value
                elif mapper is None or name not in mapper:
                    internal_kwargs[name] = value
                # Otherwise `name` is an old (renamed-away) name and is dropped.
                # NOTE(review): presumably this exists for the keywords that the
                # outer partial below re-supplies from `func.keywords` (which
                # `func` binds itself anyway); it also silently discards a
                # caller-supplied old name — confirm this is intended.
            return func(*args, **internal_kwargs)

        wrapper_rename_arguments.__signature__ = signature  # ty: ignore[unresolved-attribute]
        wrapper_rename_arguments.__annotations__ = annotations

        # Preserve function type
        if isinstance(func, functools.partial):
            # Re-supply only the keywords, never `func.args`: `func` already
            # prepends its bound positional arguments when called, so forwarding
            # them through the outer partial as well would pass them twice (and
            # make inspect.signature drop those parameters a second time).
            partial_wrapper = functools.partial(
                wrapper_rename_arguments, **func.keywords
            )
            out = cast("Callable[P, R]", partial_wrapper)
        else:
            out = wrapper_rename_arguments
        return out

    if func is not None:
        return decorator_rename_arguments(func)
    return decorator_rename_arguments
def _map_names_to_types(
    arg: dict[str, str] | list[str] | None,
) -> dict[str, str] | dict[str, type[inspect._empty]]:
    """Normalize an argument specification to a ``name -> annotation`` dict.

    Args:
        arg: ``None`` (no arguments), a list of argument names (no annotations), or a
            dict mapping argument names to their annotations as strings.

    Returns
    -------
        A new dict. List entries map to `inspect.Parameter.empty`. A dict input is
        copied, so later mutation by the caller does not leak into the result.

    Raises
    ------
        DagsError: If `arg` is not ``None``, a list, or a dict.

    """
    if arg is None:
        return {}
    if isinstance(arg, list):
        return dict.fromkeys(arg, inspect.Parameter.empty)
    if isinstance(arg, dict):
        # Copy: the caller owns `arg` and may mutate it after decoration.
        return dict(arg)
    msg = f"Invalid type for arg: {type(arg)}. Expected dict, list, or None."
    raise DagsError(msg)