from __future__ import annotations
import functools
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast
import networkx as nx
from dags.annotations import (
get_annotations,
get_free_arguments,
verify_annotations_are_strings,
)
from dags.exceptions import (
AnnotationMismatchError,
CyclicDependencyError,
DagsError,
MissingFunctionsError,
)
from dags.output import aggregated_output, dict_output, list_output, single_output
from dags.signature import with_signature
from dags.utils import format_list_linewise
if TYPE_CHECKING:
from collections.abc import Callable
from dags.typing import T
class DagsWarning(UserWarning):
    """Root of the warning hierarchy used throughout the dags package."""
@dataclass(frozen=True)
class FunctionExecutionInfo:
    """Everything needed to run one node of the DAG.

    Attributes
    ----------
    name: Name under which the function appears in the DAG.
    func: The callable that is executed for this node.
    verify_annotations: When True, construction checks that every annotation of
        `func` is a string.

    Properties
    ----------
    annotations: Annotations of `func` as reported by `get_annotations`, computed
        once and then cached on the instance. For plain functions this matches
        `__annotations__`; for partialled functions it maps the remaining free
        arguments to their types and stores the return type under "return".
        Annotations must be strings, otherwise a NonStringAnnotationError is
        raised.
    arguments: Argument names, i.e. every annotation key other than "return".
    argument_annotations: Mapping from each argument name to its annotation.
    return_annotation: The annotation stored under "return".

    Raises
    ------
    NonStringAnnotationError: If `verify_annotations` is `True` and some type
        annotation is not a string.
    """

    name: str
    func: Callable[..., Any]
    verify_annotations: bool = False

    def __post_init__(self) -> None:
        """Optionally check that all annotations are strings."""
        if not self.verify_annotations:
            return
        verify_annotations_are_strings(self.annotations, self.name)

    @functools.cached_property
    def annotations(self) -> dict[str, str]:
        """Annotations of `func`, computed lazily and cached."""
        return get_annotations(self.func)

    @property
    def arguments(self) -> list[str]:
        """Names of the arguments of `func` (all annotation keys but "return")."""
        return list(filter(lambda key: key != "return", self.annotations))

    @property
    def argument_annotations(self) -> dict[str, str]:
        """Mapping from each argument name of `func` to its annotation."""
        all_annotations = self.annotations
        return {key: all_annotations[key] for key in self.arguments}

    @property
    def return_annotation(self) -> str:
        """The annotation stored under the "return" key."""
        all_annotations = self.annotations
        return all_annotations["return"]
def concatenate_functions(
    functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
    targets: str | list[str] | None = None,
    *,
    dag: nx.DiGraph[str] | None = None,
    return_type: Literal["tuple", "list", "dict"] = "tuple",
    aggregator: Callable[[T, T], T] | None = None,
    aggregator_return_type: str | None = None,
    enforce_signature: bool = True,
    set_annotations: bool = False,
    lexsort_key: Callable[[str], Any] | None = None,
) -> Callable[..., Any]:
    """Merge several functions into a single function that computes the targets.

    A function may consume the outputs of other functions as inputs, provided the
    dependencies form a directed acyclic graph (DAG). Functions that do not
    contribute to the targets are skipped. The combined function takes as
    arguments every argument of the relevant functions that is not itself the
    name of a function, sorted alphabetically.

    Args:
        functions (dict or list): Functions to combine. For a list, each
            function's name is taken from its __name__ attribute. For a dict, the
            key is used as the function's name.
        targets (str or list or None): Name of the function producing the target,
            or a list of such names. `None` means every variable is returned.
        dag (networkx.DiGraph or None): A precomputed DAG of the functions. If
            None, one is built from `functions` and `targets`.
        return_type (str): One of "tuple", "list", "dict". Ignored when `targets`
            is a single string or when an aggregator is given.
        aggregator (callable or None): Binary reduction function used to combine
            the targets into one result.
        aggregator_return_type (str or None): Explicit return annotation for the
            aggregated result. If None and set_annotations is True, the type is
            taken from the aggregator's annotations or, if all targets share one
            type, from the targets. Only relevant when an aggregator is given.
        enforce_signature (bool): If True, the combined function enforces its
            signature; otherwise the signature is informational only. Enforcement
            adds a small runtime overhead.
        set_annotations (bool): If True, annotate the combined function using the
            annotations of the functions that produce the targets. The return
            annotation mirrors the requested return type and number of targets
            (e.g. two targets returned as a list give a list of their type
            hints); this is not a valid type annotation and must not be used for
            type checking. All annotations must be strings, otherwise a
            NonStringAnnotationError is raised; quote them or use "from
            __future__ import annotations". An AnnotationMismatchError is raised
            when functions disagree on an annotation.
        lexsort_key (callable or None): Maps a node name to a sort key used to
            break ties in the topological sort. If None, nodes are ordered
            alphabetically.

    Returns
    -------
    function: A function producing the targets when called with suitable arguments.

    Raises
    ------
    - NonStringAnnotationError: If `set_annotations` is `True` and some type
      annotation is not a string.
    - AnnotationMismatchError: If `set_annotations` is `True` and the DAG's
      components carry incompatible annotations.
    """
    # Only construct a DAG when the caller has not supplied one.
    graph = dag if dag is not None else create_dag(functions=functions, targets=targets)

    return _create_combined_function_from_dag(
        dag=graph,
        functions=functions,
        targets=targets,
        return_type=return_type,
        aggregator=aggregator,
        aggregator_return_type=aggregator_return_type,
        enforce_signature=enforce_signature,
        set_annotations=set_annotations,
        lexsort_key=lexsort_key,
    )
def create_dag(
    functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
    targets: str | list[str] | None,
) -> nx.DiGraph[str]:
    """Construct the directed acyclic graph (DAG) linking the given functions.

    A function may consume the outputs of other functions as inputs, as long as
    the dependencies form a DAG. Functions that do not contribute to the targets
    are dropped from the graph.

    Args:
        functions (dict or list): Functions to include. For a list, each
            function's name comes from its __name__ attribute. For a dict, the
            key is used as the name.
        targets (str or list or None): Name of the function producing the target,
            or a list of such names. `None` means every variable is a target.

    Returns
    -------
    dag: the DAG (as networkx.DiGraph object)
    """
    named_functions, target_names = harmonize_and_check_functions_and_targets(
        functions,
        targets,
    )

    # Build the graph over all functions, then keep only what the targets need.
    full_graph = _create_complete_dag(named_functions)
    pruned_graph = _limit_dag_to_targets_and_their_ancestors(full_graph, target_names)

    # A cycle in the retained part makes execution order undefined.
    _fail_if_dag_contains_cycle(pruned_graph)
    return pruned_graph
def _create_combined_function_from_dag(
    dag: nx.DiGraph[str],
    functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
    targets: str | list[str] | None,
    return_type: Literal["tuple", "list", "dict"] = "tuple",
    aggregator: Callable[[T, T], T] | None = None,
    aggregator_return_type: str | None = None,
    enforce_signature: bool = True,
    set_annotations: bool = False,
    lexsort_key: Callable[[str], Any] | None = None,
) -> Callable[..., Any]:
    """Create combined function which allows executing a DAG in one function call.

    The arguments of the combined function are all arguments of relevant functions
    that are not themselves function names, in alphabetical order.

    Args:
        dag (networkx.DiGraph): a DAG of functions
        functions (dict or list): Dict or list of functions. If a list, the function
            name is inferred from the __name__ attribute of the entries. If a dict,
            the name of the function is set to the dictionary key.
        targets (str or list or None): Name of the function that produces the
            target or list of such function names. If the value is `None`, all
            variables are returned.
        return_type (str): One of "tuple", "list", "dict". This is ignored if the
            targets are a single string or if an aggregator is provided.
        aggregator (callable or None): Binary reduction function that is used to
            aggregate the targets into a single target.
        aggregator_return_type (str or None): Explicit return type annotation for
            the aggregated result. If None and set_annotations is True, the return
            type is inferred from the aggregator's annotations or from the target
            types (if all targets have the same type). Only used when an
            aggregator is provided and there is more than one target.
        enforce_signature (bool): If True, the signature of the concatenated
            function is enforced. Otherwise it is only provided for introspection
            purposes. Enforcing the signature has a small runtime overhead.
        set_annotations (bool): If True, sets the annotations of the concatenated
            function based on those of the functions used to generate the targets.
            The return annotation of the concatenated function reflects the
            requested return type and number of targets (e.g., for two targets
            returned as a list, the return annotation is a list of their
            respective type hints). Note that this is not a valid type annotation
            and should not be used for type checking. All annotations must be
            strings; otherwise, a NonStringAnnotationError is raised. To ensure
            string annotations, enclose them in quotes or use "from __future__
            import annotations" at the top of your file. An
            AnnotationMismatchError is raised if annotations differ between
            functions.
        lexsort_key (callable or None): A function that takes a string and returns
            a value that can be used to sort the nodes. This is used to sort the
            nodes in the topological sort. If None, the nodes are sorted
            alphabetically.

    Returns
    -------
    function: A function that produces targets when called with suitable arguments.

    Raises
    ------
    - NonStringAnnotationError: If `set_annotations` is `True` and the type
      annotations are not strings.
    - AnnotationMismatchError: If `set_annotations` is `True` and there are
      incompatible annotations in the DAG's components.
    - DagsError: If `targets` is not a single string, no aggregator is given,
      and `return_type` is not one of "tuple", "list", or "dict".

    Warns
    -----
    - DagsWarning: If `set_annotations` is `True`, an aggregator is used on
      multiple targets, and no return annotation can be determined.
    """
    # Harmonize and check arguments.
    _functions, _targets = harmonize_and_check_functions_and_targets(
        functions,
        targets,
    )

    _arglist = create_arguments_of_concatenated_function(_functions, dag)
    _exec_info = create_execution_info(
        _functions, dag, verify_annotations=set_annotations, lexsort_key=lexsort_key
    )

    # Create the concatenated function that returns all requested targets as a
    # tuple. If set_annotations is True, the return annotation is a tuple of
    # strings, corresponding to the return types of the targets.
    _concatenated = _create_concatenated_function(
        _exec_info,
        _arglist,
        _targets,
        enforce_signature,
        set_annotations,
    )

    # Update the actual return type, as well as the return annotation of the
    # concatenated function.
    out: Callable[..., Any]
    if isinstance(targets, str) or (aggregator is not None and len(_targets) == 1):
        out = single_output(func=_concatenated, set_annotations=set_annotations)
    elif aggregator is not None:
        inferred_return_type: str | None = None
        if set_annotations:
            target_types = tuple(_exec_info[t].return_annotation for t in _targets)
            inferred_return_type = _infer_aggregator_return_type(
                aggregator=aggregator,
                explicit_type=aggregator_return_type,
                target_types=target_types,
            )
            if inferred_return_type is None:
                # This helper is called from `concatenate_functions`, so
                # stacklevel=3 attributes the warning to that function's caller
                # rather than to a line inside this module.
                warnings.warn(
                    message=(
                        "Cannot infer return annotation when using an aggregator on "
                        "multiple targets. Consider providing aggregator_return_type."
                    ),
                    category=DagsWarning,
                    stacklevel=3,
                )
        out = aggregated_output(
            func=_concatenated,
            aggregator=aggregator,
            set_annotations=set_annotations,
            return_annotation=inferred_return_type,
        )
    elif return_type == "list":
        out = cast(
            "Callable[..., Any]",
            list_output(func=_concatenated, set_annotations=set_annotations),
        )
    elif return_type == "tuple":
        # The concatenated function already returns a tuple.
        out = _concatenated
    elif return_type == "dict":
        out = cast(
            "Callable[..., Any]",
            dict_output(
                func=_concatenated, keys=_targets, set_annotations=set_annotations
            ),
        )
    else:
        msg = (
            f"Invalid return type {return_type}. Must be 'list', 'tuple', or 'dict'. "
            f"You provided {return_type}."
        )
        raise DagsError(msg)

    return out
def get_ancestors(
    functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
    targets: str | list[str] | None,
    include_targets: bool = False,
) -> set[str]:
    """Build a DAG and extract all ancestors of the targets.

    Args:
        functions (dict or list): Dict or list of functions. If a list, the function
            name is inferred from the __name__ attribute of the entries. If a dict,
            the name of the function is set to the dictionary key.
        targets (str or list or None): Name of the function that produces the
            target or list of such function names. If `None`, all functions are
            treated as targets.
        include_targets (bool): Whether to include the targets themselves in the
            returned set.

    Returns
    -------
    set: The names of all nodes (functions and plain arguments) that the targets
        depend on, plus the targets themselves if `include_targets` is True.
    """
    # Harmonize and check arguments.
    _functions, _targets = harmonize_and_check_functions_and_targets(
        functions,
        targets,
    )

    # Create the DAG. The inputs are already harmonized, and harmonizing them
    # again inside `create_dag` yields the same functions and targets.
    dag = create_dag(_functions, _targets)

    ancestors: set[str] = set()
    for target in _targets:
        ancestors.update(nx.ancestors(dag, target))
    if include_targets:
        ancestors.update(_targets)
    return ancestors
def harmonize_and_check_functions_and_targets(
    functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
    targets: str | list[str] | None,
) -> tuple[dict[str, Callable[..., Any]], list[str]]:
    """Normalize functions and targets to canonical types and validate them.

    Args:
        functions (dict or list): Dict or list of functions. For a list, names are
            taken from each function's __name__ attribute. For a dict, the key is
            the function's name.
        targets (str or list or None): A target name, a list of target names, or
            None to use every function name as a target.

    Returns
    -------
    functions_harmonized: harmonized functions
    targets_harmonized: harmonized targets
    """
    by_name = _harmonize_functions(functions)
    target_list = _harmonize_targets(targets, list(by_name))

    # Validate type first, then that every target has a function.
    _fail_if_targets_have_wrong_types(target_list)
    _fail_if_functions_are_missing(by_name, target_list)

    return by_name, target_list
def _harmonize_functions(
functions: dict[str, Callable[..., Any]] | list[Callable[..., Any]],
) -> dict[str, Callable[..., Any]]:
if not isinstance(functions, dict):
functions_dict = {func.__name__: func for func in functions} # ty: ignore[unresolved-attribute]
else:
functions_dict = functions
return functions_dict
def _harmonize_targets(
targets: str | list[str] | None,
function_names: list[str],
) -> list[str]:
if targets is None:
targets = function_names
elif isinstance(targets, str):
targets = [targets]
return targets
def _fail_if_targets_have_wrong_types(
targets: list[str],
) -> None:
not_strings = [target for target in targets if not isinstance(target, str)]
if not_strings:
msg = f"Targets must be strings. The following targets are not: {not_strings}"
raise DagsError(msg)
def _fail_if_functions_are_missing(
functions: dict[str, Callable[..., Any]],
targets: list[str],
) -> None:
targets_not_in_functions = set(targets) - set(functions)
if targets_not_in_functions:
formatted = format_list_linewise(list(targets_not_in_functions))
msg = f"The following targets have no corresponding function:\n{formatted}"
raise MissingFunctionsError(msg)
def _fail_if_dag_contains_cycle(dag: nx.DiGraph[str]) -> None:
    """Raise CyclicDependencyError listing every simple cycle found in `dag`."""
    found_cycles = list(nx.simple_cycles(dag))
    if not found_cycles:
        return
    raise CyclicDependencyError(
        f"The DAG contains one or more cycles:\n{format_list_linewise(found_cycles)}"
    )
def _create_complete_dag(
    functions: dict[str, Callable[..., Any]],
) -> nx.DiGraph[str]:
    """Build the unpruned dependency graph of all `functions`.

    No pruning by targets happens here; every function and every one of its free
    arguments becomes a node.

    Args:
        functions (dict): Mapping from function name to function.

    Returns
    -------
    networkx.DiGraph: Graph with an edge from each free argument to every
        function that consumes it.
    """
    inputs_of: dict[str, Any] = {}
    for func_name, func in functions.items():
        inputs_of[func_name] = get_free_arguments(func)

    # The adjacency dict points function -> argument; reversing gives the
    # dependency direction argument -> function.
    graph = nx.DiGraph(inputs_of)
    return graph.reverse()
def _limit_dag_to_targets_and_their_ancestors(
    dag: nx.DiGraph[str],
    targets: list[str],
) -> nx.DiGraph[str]:
    """Drop every node that is neither a target nor an ancestor of a target.

    The graph is modified in place and the same object is returned.

    Args:
        dag (networkx.DiGraph): The complete DAG.
        targets (list): Names of the target nodes.

    Returns
    -------
    networkx.DiGraph: The pruned DAG.
    """
    needed = set(targets)
    for target in targets:
        needed.update(nx.ancestors(dag, target))

    # Materialize the list before removal so the graph is not mutated mid-iteration.
    dag.remove_nodes_from([node for node in dag.nodes if node not in needed])
    return dag
def create_arguments_of_concatenated_function(
    functions: dict[str, Callable[..., Any]],
    dag: nx.DiGraph[str],
) -> list[str]:
    """Determine the arguments of the concatenated function.

    These are the DAG nodes that are not produced by any function, returned in
    sorted order.

    Args:
        functions (dict): Mapping from function name to function.
        dag (networkx.DiGraph): The DAG whose nodes are inspected.

    Returns
    -------
    list: The arguments of the concatenated function.
    """
    return sorted(node for node in dag.nodes if node not in functions)
def create_execution_info(
    functions: dict[str, Callable[..., Any]],
    dag: nx.DiGraph[str],
    verify_annotations: bool = False,
    lexsort_key: Callable[[str], Any] | None = None,
) -> dict[str, FunctionExecutionInfo]:
    """Collect execution information for every function node of the DAG.

    Args:
        functions (dict): Mapping from function name to function.
        dag (networkx.DiGraph): The DAG to execute.
        verify_annotations (bool): If True, each FunctionExecutionInfo checks that
            its annotations are strings when it is constructed.
        lexsort_key (callable or None): Maps a node name to a sort key used to
            break ties in the topological sort. If None, nodes are ordered
            alphabetically.

    Returns
    -------
    dict: Mapping from function name to its FunctionExecutionInfo, with insertion
        order following the lexicographical topological sort of `dag`.
        Non-function nodes (plain arguments) are skipped.

    Raises
    ------
    NonStringAnnotationError: If `verify_annotations` is `True` and the type
        annotations are not strings.
    """
    ordered_nodes = nx.lexicographical_topological_sort(dag, key=lexsort_key)
    return {
        node: FunctionExecutionInfo(
            name=node,
            func=functions[node],
            verify_annotations=verify_annotations,
        )
        for node in ordered_nodes
        if node in functions
    }
def _create_concatenated_function(
    execution_info: dict[str, FunctionExecutionInfo],
    arglist: list[str],
    targets: list[str],
    enforce_signature: bool,
    set_annotations: bool,
) -> Callable[..., tuple[Any, ...]]:
    """Create a concatenated function object with correct signature.

    Args:
        execution_info: Mapping from function name to its FunctionExecutionInfo.
            The functions are already in topological_sort order.
        arglist: The list of arguments of the concatenated function.
        targets: List that is used to determine what is returned and the
            order of the outputs.
        enforce_signature: If True, the signature of the concatenated function
            is enforced. Otherwise it is only provided for introspection purposes.
            Enforcing the signature has a small runtime overhead.
        set_annotations (bool): If True, sets the annotations of the concatenated
            function based on those of the functions used to generate the targets.
            The return annotation is a tuple of the targets' return annotations.
            Note that this is not a valid type annotation and should not be used
            for type checking. All annotations must be strings; otherwise, a
            NonStringAnnotationError is raised. An AnnotationMismatchError is
            raised if annotations differ between functions.

    Returns
    -------
    The concatenated function. Calling it runs every function in execution_info
    in order and returns the values of `targets` as a tuple.
    """
    args: list[str] | dict[str, str]
    return_annotation: type[inspect._empty] | tuple[str, ...]
    if set_annotations:
        args, return_annotation = get_annotations_from_execution_info(
            execution_info,
            arglist=arglist,
            targets=targets,
        )
    else:
        args = arglist
        return_annotation = inspect.Parameter.empty

    # Resolve each node's callable and argument names once. `info.arguments` is a
    # property that rebuilds a list on every access, and the result is the same
    # for every call of the concatenated function.
    steps = [
        (name, info.func, tuple(info.arguments))
        for name, info in execution_info.items()
    ]

    @with_signature(
        args=args,
        enforce=enforce_signature,
        return_annotation=return_annotation,
    )
    def concatenated(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
        # Positional arguments are matched to `arglist` in order; keyword
        # arguments override them.
        results = {**dict(zip(arglist, args, strict=False)), **kwargs}
        for name, func, func_arguments in steps:
            results[name] = func(**{arg: results[arg] for arg in func_arguments})
        return tuple(results[target] for target in targets)

    return concatenated
def _infer_aggregator_return_type(
aggregator: Callable[[T, T], T],
explicit_type: str | None,
target_types: tuple[str, ...],
) -> str | None:
"""Infer the return type annotation for an aggregated function.
Uses a three-tier approach:
1. If explicit_type is provided, use it directly.
2. Try to get the return annotation from the aggregator function.
3. If all targets have the same type, assume the aggregator preserves it.
Note: Tier 3 is a heuristic that works for aggregators like `logical_and` or
`max`, but may be wrong for type-promoting aggregators (e.g., summing bools
returns int, not bool). Use explicit_type or a typed aggregator in such cases.
Args:
aggregator: The binary reduction function.
explicit_type: Explicitly provided return type, if any.
target_types: The return types of the target functions.
Returns
-------
The inferred return type as a string, or None if inference failed.
"""
# 1. Explicit type wins
if explicit_type is not None:
return explicit_type
# 2. Try aggregator's annotations
try:
agg_annot = get_annotations(aggregator)
ret = agg_annot.get("return", "no_annotation_found")
if ret != "no_annotation_found":
return ret
except Exception: # noqa: BLE001, S110
pass
# 3. If all targets have the same type, assume the aggregator preserves it
non_missing = [t for t in target_types if t != "no_annotation_found"]
if non_missing and len(set(non_missing)) == 1:
return non_missing[0]
return None
def get_annotations_from_execution_info(
    execution_info: dict[str, FunctionExecutionInfo],
    arglist: list[str],
    targets: list[str],
) -> tuple[dict[str, str], tuple[str, ...]]:
    """Get the (argument and return) annotations of the concatenated function.

    Args:
        execution_info: Mapping from function name to its FunctionExecutionInfo.
            The functions are already in topological_sort order.
        arglist: The list of arguments of the concatenated function.
        targets: The list of targets of the concatenated function.

    Returns
    -------
    - Dictionary with argument names as keys and their expected types in string
      format as values.
    - Tuple with the expected return type (as a string) of each target, in the
      order of `targets`.

    Raises
    ------
    AnnotationMismatchError: If there are incompatible annotations in the DAG's
        components. All mismatches are reported, in a deterministic order that
        follows the topological order of the functions and the order of their
        arguments.
    """
    types: dict[str, str] = {}
    errors: list[str] = []
    for name, info in execution_info.items():
        # We do not need to check whether name is already in `types`, because the
        # functions in execution_info are topologically sorted, and hence, it is
        # impossible for a function to appear as a dependency of another function
        # before appearing as a function itself.
        types[name] = info.return_annotation

        # `argument_annotations` is a property that builds a new dict on each
        # access, so compute it once per function.
        argument_annotations = info.argument_annotations

        # Iterate in argument order rather than over a set intersection, so the
        # error list (and therefore the error message) is reproducible.
        for arg, current_type in argument_annotations.items():
            if arg not in types:
                continue
            # Verify that the type information on arg that was retrieved up to
            # this point (earlier_type) is consistent with the type information
            # on arg from the current function info (current_type).
            earlier_type = types[arg]
            # The following condition is a hack to deal with overloaded type
            # annotations. E.g., we may have a function that takes an int and
            # returns an int, or takes a float and returns a float. We can achieve
            # that with @overload, but the type hints will be "int | float". If we
            # just checked for equality, we would get an error if a downstream or
            # upstream function required an int or a float. We will not be able to
            # do much better unless we switch away from string-type annotations or
            # replicate the entire logic of a static type checker, both of which
            # are infeasible at the moment.
            if earlier_type not in current_type and current_type not in earlier_type:
                if arg in execution_info:
                    explanation = f"function {arg} has return type: {earlier_type}."
                else:
                    explanation = (
                        f"type annotation '{arg}: {earlier_type}' is used elsewhere."
                    )
                errors.append(
                    f"function {name} has the argument type annotation '{arg}: "
                    f"{current_type}', but {explanation}"
                )

        # Later annotations for the same argument overwrite earlier ones.
        types.update(argument_annotations)

    if errors:
        raise AnnotationMismatchError(
            "The following type annotations are inconsistent:\n" + "\n".join(errors)
        )

    args_annotations = {arg: types[arg] for arg in arglist}
    return_annotation = tuple(types[target] for target in targets)
    return args_annotations, return_annotation