Source code for dags.annotations

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, overload

from dags.exceptions import NonStringAnnotationError

if TYPE_CHECKING:
    from collections.abc import Callable

import functools
import inspect


def get_free_arguments(
    func: Callable[..., Any],
) -> list[str]:
    """Return the names of the parameters of *func* that are still unbound.

    Arguments that a ``functools.partial`` bound by position are already absent
    from the signature. Arguments bound by keyword still show up in the signature,
    so they are filtered out explicitly.

    Args:
        func: A callable, possibly a ``functools.partial``.

    Returns
    -------
        The parameter names of ``inspect.signature(func)``, in signature order,
        minus any names bound by keyword in a partial.

    """
    parameter_names = list(inspect.signature(func).parameters)
    if not isinstance(func, functools.partial):
        return parameter_names

    bound_by_keyword = set(func.keywords)
    return [name for name in parameter_names if name not in bound_by_keyword]
@overload def get_annotations( func: Callable[..., Any], eval_str: Literal[False] = False, default: str | None = None, ) -> dict[str, str]: ... @overload def get_annotations( func: Callable[..., Any], eval_str: Literal[True] = True, default: type | None = None, ) -> dict[str, type]: ...
[docs] def get_annotations( func: Callable[..., Any], eval_str: bool = False, default: str | type | None = None, ) -> dict[str, str] | dict[str, type]: """Thin wrapper around inspect.get_annotations. Compared to inspect.get_annotations, this function also handles partialled funcs, and it returns annotations for all arguments, not just the ones with annotations. Args: func: The function to get annotations from. eval_str: If True, the string type annotations are evaluated. default: The default value to use if an annotation is missing. If None, the default value is inspect.Parameter.empty if eval_str is True, otherwise "no_annotation_found". Returns ------- A dictionary with the argument names as keys and the type annotations as values. The type annotations are strings if eval_str is False, otherwise they are types. """ if default is None: default = inspect.Parameter.empty if eval_str else "no_annotation_found" if isinstance(func, functools.partial): annotations = inspect.get_annotations(func.func, eval_str=eval_str) else: annotations = inspect.get_annotations(func, eval_str=eval_str) free_arguments = get_free_arguments(func) annotation_keys = {k for k in annotations if k != "return"} signature_params = set(free_arguments) if _has_args_kwargs_annotation_mismatch(annotation_keys, signature_params): annotations = _get_annotations_from_signature(func, eval_str) return {arg: annotations.get(arg, default) for arg in ["return", *free_arguments]}
def verify_annotations_are_strings( annotations: dict[str, str], function_name: str ) -> None: # If all annotations are strings, we are done. if all(isinstance(v, str) for v in annotations.values()): return non_string_annotations = [ k for k, v in annotations.items() if not isinstance(v, str) ] arg_annotations = {k: v for k, v in annotations.items() if k != "return"} return_annotation = annotations["return"] # Create a representation of the signature with string annotations # ---------------------------------------------------------------------------------- stringified_arg_annotations = [] for k, v in arg_annotations.items(): if k in non_string_annotations: stringified_arg_annotations.append(f"{k}: '{_get_str_repr(v)}'") else: annot = f"{k}: '{v}'" stringified_arg_annotations.append(annot) if "return" in non_string_annotations: stringified_return_annotation = f"'{_get_str_repr(return_annotation)}'" else: stringified_return_annotation = f"'{return_annotation}'" stringified_signature = ( f"{function_name}({', '.join(stringified_arg_annotations)}) -> " f"{stringified_return_annotation}" ) # Create message on which argument and/or return annotation is invalid # ---------------------------------------------------------------------------------- invalid_arg_annotations = [k for k in non_string_annotations if k != "return"] if invalid_arg_annotations: s = "s" if len(invalid_arg_annotations) > 1 else "" invalid_arg_msg = f"argument{s} ({', '.join(invalid_arg_annotations)})" else: invalid_arg_msg = "" invalid_annotations_msg = "" if invalid_arg_msg and "return" in non_string_annotations: invalid_annotations_msg = f"{invalid_arg_msg} and the return value" elif invalid_arg_msg: invalid_annotations_msg = invalid_arg_msg elif "return" in non_string_annotations: invalid_annotations_msg = "return value" raise NonStringAnnotationError( f"All function annotations must be strings. 
The annotations for the " f"{invalid_annotations_msg} are not strings.\nA simple way for Python to treat " "type annotations as strings is to add\n\n\tfrom __future__ import annotations" "\n\nat the top of your file. Alternatively, you can do it manually by " f"enclosing the annotations in quotes:\n\n\t{stringified_signature}." ) def _get_str_repr(obj: object) -> str: return getattr(obj, "__name__", str(obj)) def _has_args_kwargs_annotation_mismatch( annotation_keys: set[str], signature_params: set[str] ) -> bool: """Check if annotations have the args/kwargs mismatch from Python 3.14. In Python 3.14, when functools.wraps wraps a non-function object (like PolicyFunction in ttsim) and the wrapper defines annotations with ParamSpec (*args: P.args, **kwargs: P.kwargs), functools.wraps no longer copies __annotations__ from the wrapped object; it uses the wrapper's annotations ({'args': 'P.args', 'kwargs': 'P.kwargs'}), which don't match the signature parameters that functools.wraps still copies correctly. """ return annotation_keys != signature_params and annotation_keys == {"args", "kwargs"} def _get_annotations_from_signature( func: Callable[..., Any], eval_str: bool ) -> dict[str, Any]: """Extract annotations from the function signature. This is a fallback for when inspect.get_annotations returns incorrect results, such as in Python 3.14's args/kwargs annotation mismatch case. """ sig = inspect.signature(func) annotations: dict[str, Any] = {} for param_name, param in sig.parameters.items(): if param.annotation != inspect.Parameter.empty: annotations[param_name] = ( param.annotation if eval_str or isinstance(param.annotation, str) else _get_str_repr(param.annotation) ) if sig.return_annotation != inspect.Signature.empty: annotations["return"] = ( sig.return_annotation if eval_str or isinstance(sig.return_annotation, str) else _get_str_repr(sig.return_annotation) ) return annotations