# coding: utf-8
"""
Tasks that provide common and often used functionality.
"""
from __future__ import annotations

# names exported by this module
__all__ = ["RunOnceTask", "TransferLocalFile", "ForestMerge"]

# standard library
import os
import pathlib
import abc

# third-party
import luigi # type: ignore[import-untyped]

# law internals
from law.task.base import Task
from law.workflow.local import LocalWorkflow
from law.target.file import FileSystemTarget, FileSystemFileTarget
from law.target.local import LocalFileTarget
from law.target.collection import TargetCollection, SiblingFileCollection
from law.parameter import NO_STR
from law.decorator import factory
from law.util import iter_chunks, flatten, map_struct, range_expand, DotDict
from law.logger import get_logger
from law._types import Callable, Any, Sequence, Iterator

# module-level logger named after this module
logger = get_logger(__name__)
class RunOnceTask(Task):
    """
    Task whose completeness is tracked by an in-memory flag rather than by output targets.

    A new instance starts out incomplete. Once :py:meth:`mark_complete` was called,
    :py:meth:`complete` reports *True* for that instance.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # flipped to True by mark_complete()
        self._has_run = False

    @staticmethod
    @factory(accept_generator=True)
    def complete_on_success(
        fn: Callable,
        opts: dict[str, Any],
        task: Task,
        *args,
        **kwargs,
    ) -> tuple[Callable, Callable, Callable]:
        """
        Decorator hooks that call *fn* with *task* and mark *task* as complete afterwards.

        Returns a 3-tuple of callables in the order (setup, invocation, teardown), which is
        the structure handed back to the :py:func:`factory` decorator.
        NOTE(review): presumably the teardown hook is only triggered by ``factory`` when the
        invocation did not raise (as the name suggests) - confirm against ``law.decorator``.
        """
        def setup() -> None:
            # nothing to prepare, there is no state to pass on
            return None

        def invoke(state: None):
            # forward the call to the wrapped function, bound to the task
            return fn(task, *args, **kwargs)

        def finalize(state: None) -> None:
            # flag the task as done
            task.mark_complete()  # type: ignore[attr-defined]

        return setup, invoke, finalize

    def mark_complete(self) -> None:
        """
        Sets the internal flag that marks this task as having run.
        """
        self._has_run = True

    @property
    def has_run(self) -> bool:
        """
        Read-only view on the internal flag set by :py:meth:`mark_complete`.
        """
        return self._has_run

    def complete(self) -> bool:
        """
        Returns the value of :py:attr:`has_run`.
        """
        return self.has_run
class TransferLocalFile(Task):
    """
    Task that copies a local file to the target returned by :py:meth:`single_output`.

    When the *replicas* parameter is positive, the output is a
    :py:class:`SiblingFileCollection` holding one sibling of the single output per replica, and
    the source file is copied to each of them.
    """

    source_path = luigi.Parameter(
        default=NO_STR,
        description="path to the file to transfer; when empty, the task input is used; default: "
        "empty",
    )
    replicas = luigi.IntParameter(
        default=0,
        description="number of replicas to generate; when > 0 the output will be a file collection "
        "instead of a single file; default: 0",
    )

    exclude_index = True

    exclude_params_repr_empty = {"source_path"}

    def get_source_target(self) -> LocalFileTarget:
        """
        Returns the local target to transfer.

        When *source_path* is set, user and environment variables in it are expanded and a
        :py:class:`LocalFileTarget` around its absolute path is returned. Otherwise, the task
        input is returned as is.
        """
        # when self.source_path is set, return a target around it
        # otherwise assume self.requires() returns a task with a single local target
        if self.source_path not in (NO_STR, None):
            source_path = os.path.expandvars(os.path.expanduser(str(self.source_path)))
            return LocalFileTarget(os.path.abspath(source_path))

        return self.input()

    @abc.abstractmethod
    def single_output(self) -> FileSystemFileTarget:
        """
        Returns the target of the non-replicated output. Must be implemented by subclasses.
        """
        ...

    def get_replicated_path(self, basename: str, i: int | None = None) -> str:
        """
        Returns *basename* with the replica index *i* inserted in front of the file extension,
        e.g. ``"file.tgz"`` and ``3`` yield ``"file.3.tgz"``. When *i* is *None*, *basename* is
        returned unchanged.
        """
        if i is None:
            return basename

        name, ext = os.path.splitext(basename)
        return f"{name}.{i}{ext}"

    def output(self) -> SiblingFileCollection | FileSystemFileTarget:
        """
        Returns the single output when *replicas* is zero or negative, and a
        :py:class:`SiblingFileCollection` with one sibling file per replica otherwise.
        """
        replicas: int = self.replicas  # type: ignore[assignment]
        output = self.single_output()

        if replicas <= 0:
            return output

        # return the replicas in a SiblingFileCollection
        return SiblingFileCollection([
            output.sibling(self.get_replicated_path(output.basename, i), "f")
            for i in range(replicas)
        ])

    def run(self) -> None:
        """
        Transfers the source target to the output(s).
        """
        self.transfer(self.get_source_target())

    def trace_transfer_output(self, output: Any) -> SiblingFileCollection | FileSystemFileTarget:
        """
        Hook to extract the actual transfer target(s) from the task *output*. The default
        implementation returns *output* unchanged.
        """
        return output

    def transfer(
        self,
        src_path: str | pathlib.Path | LocalFileTarget,
        output: SiblingFileCollection | FileSystemFileTarget | None = None,
    ) -> None:
        """
        Copies *src_path* to *output*, bypassing the cache (``cache=False``).

        When *output* is *None*, the traced task output is used. A single target receives one
        copy. For a :py:class:`SiblingFileCollection`, every contained target receives a copy
        and progress plus a message is published after each upload.
        """
        # get the output target to transfer
        # NOTE(review): tracing is applied to the default output only - confirm that explicitly
        # passed outputs are meant to bypass trace_transfer_output
        if output is None:
            output = self.output()
            output = self.trace_transfer_output(output)

        # single output
        if not isinstance(output, SiblingFileCollection):
            output.copy_from_local(src_path, cache=False)
            return

        # upload all replicas, the progress total reflects the targets that are actually
        # uploaded rather than the replicas parameter, so that explicitly passed or traced
        # collections of a different size are reported correctly
        replica_targets = output.targets
        progress_callback = self.create_progress_callback(len(replica_targets))
        for i, replica in enumerate(replica_targets):
            replica.copy_from_local(src_path, cache=False)
            progress_callback(i)
            self.publish_message(f"uploaded {replica.basename}")
class ForestMerge(LocalWorkflow):
    """
    Workflow that merges a number of inputs ("leaves") through one or more merge trees.

    The leaves are distributed across the trees of a "forest", with one tree per output of
    :py:meth:`merge_output`. Within a tree, leaves are grouped into chunks of *merge_factor*
    and the resulting nodes are grouped again until a single root node remains. A task with
    ``tree_index < 0`` represents the forest itself and requires the trees. Any other task
    represents the nodes at depth *tree_depth* of tree *tree_index*, with one branch per node.

    Subclasses implement :py:meth:`merge_workflow_requires`, :py:meth:`merge_requires`,
    :py:meth:`merge_output` and :py:meth:`merge`.
    """

    tree_index = luigi.IntParameter(
        default=-1,
        description="the index of the merged tree in the forest; -1 denotes the forest itself "
        "which requires and outputs all trees; default: -1",
    )
    tree_depth = luigi.IntParameter(
        default=0,
        description="the depth of this workflow in the merge tree; 0 denotes the root; default: 0",
    )
    keep_nodes = luigi.BoolParameter(
        significant=False,
        description="keep merged results, i.e., task outputs from intermediate nodes in the merge "
        "tree; default: False",
    )

    # fix some workflow parameters
    acceptance = 1.0  # type: ignore[assignment]
    tolerance = 0.0  # type: ignore[assignment]
    pilot = False  # type: ignore[assignment]

    # format of basenames of intermediate (non-root) node outputs, see output()
    node_format = "{name}.t{tree}.d{depth}.b{branch}{ext}"

    # format of the postfix appended in control_output_postfix()
    postfix_format = "t{tree}_d{depth}"

    # chunk size used to group leaves and nodes when building a tree
    merge_factor = 2

    exclude_index = True

    exclude_params_forest_merge = {"tree_index", "tree_depth", "keep_nodes", "branch", "branches"}

    @classmethod
    def modify_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Forces ``branch`` to 0 in *params* when a negative ``tree_index`` (the forest) is given
        and returns the updated parameters.
        """
        params = super().modify_param_values(params)

        # when tree_index is negative, which refers to the merge forest, make sure this is branch 0
        if "tree_index" in params and "branch" in params and params["tree_index"] < 0:
            params["branch"] = 0

        return params

    @classmethod
    def _req_tree(cls, inst: ForestMerge, *args, **kwargs) -> ForestMerge:
        """
        Creates a new instance via ``req`` based on *inst* and forwards the internal merge
        forest state of *inst* to it, so that it does not have to be computed again.
        """
        # amend workflow branch parameters to exclude
        kwargs["_exclude"] = set(kwargs.pop("_exclude", set())) | {"branches"}

        # just as for all workflows that require branches of themselves (or vice versa),
        # skip task level excludes
        kwargs["_skip_task_excludes"] = True

        # create the required instance
        new_inst = super().req(inst, *args, **kwargs)

        # forward the _n_leaves attribute
        new_inst._n_leaves = inst._n_leaves  # type: ignore[attr-defined]

        # when set, also forward the tree itself and caching decisions
        if inst._merge_forest_built:
            new_inst._cache_forest = inst._cache_forest  # type: ignore[attr-defined]
            new_inst._merge_forest = inst._merge_forest  # type: ignore[attr-defined]
            new_inst._merge_forest_built = new_inst._merge_forest is not None  # type: ignore[attr-defined] # noqa
            new_inst._leaves_per_tree = inst._leaves_per_tree  # type: ignore[attr-defined]

        return new_inst  # type: ignore[return-value]

    @classmethod
    def _mark_merge_output_placeholder(cls, target: FileSystemTarget) -> FileSystemTarget:
        """
        Marks a *target*, such as the output of :py:meth:`merge_output` as temporary placeholder.
        When such a target is received while building the merge forest, no actual merging structure
        is constructed, but rather deferred to a future call.
        """
        target._is_merge_output_placeholder = True  # type: ignore[attr-defined]
        return target

    @classmethod
    def _check_merge_output_placeholder(cls, target: FileSystemTarget) -> bool:
        """
        Returns whether *target* was marked by :py:meth:`_mark_merge_output_placeholder`.
        """
        return bool(getattr(target, "_is_merge_output_placeholder", False))

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # set attributes
        # total number of leaves, determined lazily in _build_merge_forest when None
        self._n_leaves: int | None = None
        # whether a built forest is kept for subsequent accesses
        self._cache_forest = True
        # whether the stored forest can be reused
        self._merge_forest_built = False
        # one dict per tree, mapping depths to the nodes at that depth
        self._merge_forest: list[dict[int, list[tuple[int, ...]]]] | None = None
        # number of leaves handled by each tree
        self._leaves_per_tree: list[int] | None = None

        # modify_param_values prevents the forest from being a workflow, but still check
        if self.is_forest() and self.is_workflow():
            raise Exception(f"merge forest must not be a workflow, {self} misconfigured")

    def is_forest(self) -> bool:
        """
        Returns *True* when this task represents the forest, i.e., *tree_index* is negative.
        """
        return self.tree_index < 0  # type: ignore[operator]

    def is_root(self) -> bool:
        """
        Returns *True* when this task belongs to a tree and *tree_depth* is zero.
        """
        return not self.is_forest() and self.tree_depth == 0

    def is_leaf(self) -> bool:
        """
        Returns *True* when this task belongs to a tree and *tree_depth* equals the maximum
        depth of that tree.
        """
        return not self.is_forest() and self.tree_depth == self.max_tree_depth

    def req_workflow(self, **kwargs) -> ForestMerge:
        """
        Returns the workflow instance. For the forest, *tree_index* is set to 0 and task level
        excludes are not skipped.
        """
        # since the forest counts as a branch, as_workflow should point the tree_index 0
        # which is only used to compute the overall merge tree
        if self.is_forest():
            kwargs["tree_index"] = 0
            kwargs["_skip_task_excludes"] = False

        return super().req_workflow(**kwargs)  # type: ignore[return-value]

    @property
    def max_tree_depth(self) -> int:
        """
        The largest depth present in the tree of this task.
        """
        return max(self._get_tree())

    @property
    def merge_forest(self) -> list[dict[int, list[tuple[int, ...]]]]:
        """
        The merge forest, a list with one tree per entry. It is built on access when needed.
        """
        self._build_merge_forest()
        return self._merge_forest  # type: ignore[return-value]

    @property
    def leaves_per_tree(self) -> list[int]:
        """
        The number of leaves handled by each tree. The forest is built on access when needed.
        """
        self._build_merge_forest()
        return self._leaves_per_tree  # type: ignore[return-value]

    @property
    def leaf_range(self) -> tuple[int, int]:
        """
        A 2-tuple ``(start_leaf, end_leaf)`` of global leaf numbers covered by this leaf branch.
        *start_leaf* accounts for the leaves of all preceding trees. *end_leaf* is *start_leaf*
        plus the merge factor, capped at the offset plus the number of leaves of this tree.
        Raises an exception when this task is not a leaf.
        """
        if not self.is_leaf():
            raise Exception("leaf_range can only be accessed by leaves")

        # compute the range
        tree_index: int = self.tree_index  # type: ignore[assignment]
        leaves_per_tree = self.leaves_per_tree
        merge_factor = self.merge_factor
        n_leaves = leaves_per_tree[tree_index]
        offset = sum(leaves_per_tree[:tree_index])

        # a non-positive merge factor makes a single branch cover all leaves of the tree
        if merge_factor <= 0:
            merge_factor = n_leaves

        start_leaf = offset + self.branch * merge_factor  # type: ignore[operator]
        end_leaf = min(start_leaf + merge_factor, offset + n_leaves)

        return start_leaf, end_leaf

    def _get_tree(self) -> dict[int, list[tuple[int, ...]]]:
        """
        Returns the tree at *tree_index* of the merge forest. Raises an exception for the forest
        itself and when the forest has no tree at that index.
        """
        if self.is_forest():
            raise Exception(
                "merge tree cannot be determined for the merge forest, ForestMerge misconfigured",
            )

        try:
            return self.merge_forest[self.tree_index]  # type: ignore[call-overload]
        except IndexError as err:
            # chain explicitly so that the traceback shows the lookup as the direct cause
            raise Exception(
                f"merge tree {self.tree_index} not found, forest only contains "
                f"{len(self.merge_forest)} tree(s)",
            ) from err

    def _build_merge_forest(self) -> None:
        """
        Builds the merge forest and the leaves-per-tree list and stores them on this instance.
        Does nothing when a cached forest exists. Raises an exception when there are fewer
        leaves than trees, and a *ValueError* when *tree_depth* exceeds the depth of the tree.
        """
        # a node in the tree can be described by a tuple of integers, where each value denotes the
        # branch path to go down the tree to reach the node (e.g. (2, 0) -> 2nd branch, 0th branch),
        # so the length of the tuple defines the depth of the node via ``depth = len(node) - 1``
        # the tree itself is a dict that maps depths to lists of nodes with that depth
        # when multiple trees are used (a forest), each one handles ``n_leaves / n_trees`` leaves

        # when the forest was already built and saved by means of the _cache_forest flag, do nothing
        if self._merge_forest_built and self._merge_forest is not None:
            return

        # helper to convert nested lists of leaf number chunks into a list of nodes in the format
        # described above
        def nodify(obj, node=None, root_id=0):
            if not isinstance(obj, list):
                return []
            nodes = []
            if node is None:
                node = tuple()
            else:
                nodes.append(node)
            for i, _obj in enumerate(obj):
                # the first path element is the root id, deeper ones are child positions
                nodes += nodify(_obj, node + (i if node else root_id,))
            return nodes

        # infer the number of trees from the merge output
        output = self.merge_output()
        is_placeholder = self._check_merge_output_placeholder(output)
        n_trees = 1
        if not is_placeholder:
            n_trees = len(output) if isinstance(output, (list, tuple, TargetCollection)) else 1

        # first, determine the number of files to merge in total when not already set via params
        reset_n_leaves = False
        if self._n_leaves is None:
            # defer computation when the output is a placeholder
            if is_placeholder:
                self._n_leaves = 1
                reset_n_leaves = True
            else:
                # the following lines build the workflow requirements,
                # which strictly requires this task to be a workflow
                wf = self.as_workflow()

                # get inputs, i.e. outputs of workflow requirements and trace actual inputs to merge
                # an integer number representing the number of inputs is also valid
                inputs = luigi.task.getpaths(wf.merge_workflow_requires())
                inputs = wf.trace_merge_workflow_inputs(inputs)
                self._n_leaves = inputs if isinstance(inputs, int) else len(inputs)

        # complain when there are too few leaves for the configured number of trees to create
        if self._n_leaves < n_trees:
            raise Exception(
                f"insufficient number of leaves ({self._n_leaves}) for number of requested trees "
                f"({n_trees})",
            )

        # determine the number of leaves per tree
        # the first ``n_leaves % n_trees`` trees handle one additional leaf
        n_min = self._n_leaves // n_trees
        n_trees_overlap = self._n_leaves % n_trees
        leaves_per_tree = n_trees_overlap * [n_min + 1] + (n_trees - n_trees_overlap) * [n_min]

        merge_factor = self.merge_factor

        # when the output is a placeholder, define a one-element tree
        # otherwise, build the forest the normal way
        forest: list[dict[int, list[tuple[int, ...]]]] = []
        if is_placeholder:
            forest.append({0: [(0,)]})
        else:
            for i, n_leaves in enumerate(leaves_per_tree):
                # build a nested list of leaf numbers using the merge factor
                # e.g. 9 leaves with factor 3 -> [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
                nested_leaves: list[list[int]] = list(iter_chunks(n_leaves, merge_factor))
                while len(nested_leaves) > 1:
                    nested_leaves = list(iter_chunks(nested_leaves, merge_factor))

                # convert the list of nodes to the tree format described above
                tree: dict[int, list[tuple[int, ...]]] = {}
                for node in nodify(nested_leaves, root_id=i):
                    depth = len(node) - 1
                    tree.setdefault(depth, []).append(node)

                forest.append(tree)

        # store values, declare the forest as cached for now so that the check below works
        self._leaves_per_tree = leaves_per_tree
        self._merge_forest = forest
        self._merge_forest_built = True

        # complain when the depth is too large
        if not self.is_forest() and self.tree_depth > self.max_tree_depth:  # type: ignore[operator] # noqa
            raise ValueError(
                f"tree_depth {self.tree_depth} exceeds maximum depth {self.max_tree_depth} in task "
                f"{self}",
            )

        # set the final cache decisions
        self._merge_forest_built = self._cache_forest and not is_placeholder
        if reset_n_leaves:
            self._n_leaves = None

    def create_branch_map(self) -> dict[int, tuple[int, ...]]:
        """
        Returns the branch map, mapping branch numbers to the nodes at *tree_depth* of the tree
        of this task.
        """
        tree = self._get_tree()
        nodes = tree[self.tree_depth]  # type: ignore[index]
        return dict(enumerate(nodes))

    def trace_merge_workflow_inputs(self, inputs: Any) -> Sequence[Any] | TargetCollection:
        """
        Hook to convert the *inputs* of the merge workflow into an object with a length. The
        default implementation extracts a :py:class:`TargetCollection` stored in the
        ``"collection"`` field of a dict and returns all other *inputs* unchanged.
        """
        # should convert inputs to an object with a length (e.g. list, tuple, TargetCollection, ...)

        # for convenience, check if inputs results from the default workflow output, i.e. a dict
        # which stores a TargetCollection in the "collection" field
        if isinstance(inputs, dict) and "collection" in inputs:
            collection = inputs["collection"]
            if isinstance(collection, TargetCollection):
                return collection

        return inputs

    def trace_merge_inputs(self, inputs: Any) -> Sequence[Any]:
        """
        Hook to convert the *inputs* of a leaf into a sequence of items to merge. The default
        implementation returns *inputs* unchanged.
        """
        # should convert inputs into an iterable sequence (list, tuple, ...), no TargetCollection!
        return inputs

    @abc.abstractmethod
    def merge_workflow_requires(self) -> Any:
        """
        Should return the requirements of the merge workflow.
        """
        ...

    @abc.abstractmethod
    def merge_requires(self, start_leaf: int, end_leaf: int) -> Any:
        """
        Should return the requirements of a merge task, depending on the leaf range given by
        *start_leaf* and *end_leaf*.
        """
        ...

    @abc.abstractmethod
    def merge_output(self) -> FileSystemFileTarget:
        """
        Should return a single target when the output should be a single tree, or a target
        collection, list or tuple with item access through tree indices.
        """
        ...

    @abc.abstractmethod
    def merge(self, inputs: list[FileSystemFileTarget], output: FileSystemTarget) -> None:
        """
        Should merge *inputs* into *output*. Called by :py:meth:`run` for every node.
        """
        ...

    def workflow_requires(self) -> Any:
        """
        Returns the workflow requirements with an additional ``"forest_merge"`` entry: the merge
        workflow requirements for leaves, and the workflow at the next tree depth for all other
        nodes. Raises an exception for the forest.
        """
        self._build_merge_forest()

        reqs = super().workflow_requires()

        if self.is_forest():
            raise Exception(
                "workflow requirements cannot be determined for the merge forest, ForestMerge "
                "misconfigured",
            )
        elif self.is_leaf():
            # this is simply the merge workflow requirement
            reqs["forest_merge"] = self.merge_workflow_requires()
        else:
            # intermediate node, just require the next tree depth
            reqs["forest_merge"] = self._req_tree(self, tree_depth=self.tree_depth + 1)  # type: ignore[operator] # noqa

        return reqs

    def _forest_requires(self) -> dict[int, ForestMerge]:
        """
        Returns the requirements of the forest, a dict mapping tree indices to the corresponding
        tree tasks. When *branches* is set, it selects the tree indices to require. Raises an
        exception when this task is not the forest.
        """
        if not self.is_forest():
            raise Exception(
                "_forest_requires can only be determined for the forest, ForestMerge misconfigured",
            )

        n_trees = len(self.merge_forest)
        indices: range | list[int] = range(n_trees)

        # interpret branches as tree indices when given
        if self.branches:
            indices = [
                i
                for i in range_expand(list(self.branches), min_value=0, max_value=n_trees)  # type: ignore[call-overload] # noqa
                if 0 <= i < n_trees
            ]

        return {
            i: self._req_tree(
                self,
                branch=-1,
                tree_index=i,
                _exclude=self.exclude_params_workflow,
            )
            for i in indices
        }

    def requires(self) -> DotDict:
        """
        Returns a :py:class:`DotDict` with a ``"forest_merge"`` entry: the tree tasks for the
        forest, the result of :py:meth:`merge_requires` for leaves, and the child node branches
        at the next depth for all other nodes.
        """
        reqs = DotDict()

        if self.is_forest():
            reqs["forest_merge"] = self._forest_requires()
        elif self.is_leaf():
            # this is simply the merge requirement
            reqs["forest_merge"] = self.merge_requires(*self.leaf_range)
        else:
            # get all child nodes in the next layer at depth = depth + 1 and store their branches
            # note: child node tuples contain the exact same values plus an additional one
            tree_depth: int = self.tree_depth  # type: ignore[assignment]
            tree = self._get_tree()
            node = self.branch_data
            branches = [i for i, n in enumerate(tree[tree_depth + 1]) if n[:-1] == node]

            # add to requirements
            reqs["forest_merge"] = {
                b: self._req_tree(self, branch=b, tree_depth=tree_depth + 1)
                for b in branches
            }

        return reqs

    def output(self) -> Any:
        """
        Returns the output of this task: the full merge output for the forest, the tree output
        for root nodes, and for all other nodes intermediate targets of the same structure whose
        basenames follow *node_format*. Raises an exception when no directory for intermediate
        outputs can be determined.
        """
        output = self.merge_output()

        if self.is_forest():
            return output

        if isinstance(output, (list, tuple, TargetCollection)):
            output = output[self.tree_index]  # type: ignore[call-overload]

        if self.is_root():
            return output

        # get the directory in which intermediate outputs are stored
        if isinstance(output, SiblingFileCollection):
            intermediate_dir = output.dir
        else:
            # an empty structure is reported with the same descriptive error as a non-target
            flat_outputs = flatten(output)
            first_output = flat_outputs[0] if flat_outputs else None
            if not isinstance(first_output, FileSystemTarget):
                raise Exception(
                    f"cannot determine directory for intermediate merged outputs from '{output}'",
                )
            intermediate_dir = first_output.parent

        # helper to create an intermediate output
        def get_intermediate_output(leaf_output):
            name, ext = os.path.splitext(leaf_output.basename)
            basename = self.node_format.format(
                name=name,
                ext=ext,
                tree=self.tree_index,
                branch=self.branch,
                depth=self.tree_depth,
            )
            return intermediate_dir.child(basename, type="f")  # type: ignore[call-arg]

        # return intermediate outputs in the same structure
        if isinstance(output, TargetCollection):
            return output.map(get_intermediate_output)
        return map_struct(get_intermediate_output, output)

    def run(self) -> None | Iterator[Any]:
        """
        Generator-based run method. The forest only yields its tree requirements. Any other node
        passes its inputs to :py:meth:`merge` and, unless it is a leaf or *keep_nodes* is set,
        removes the merged inputs afterwards.
        """
        # nothing to do for the forest
        if self.is_forest():
            # yield the forest dependencies again
            yield self._forest_requires()
            return None

        # trace actual inputs to merge
        inputs = self.input()["forest_merge"]
        inputs = list(self.trace_merge_inputs(inputs) if self.is_leaf() else inputs.values())

        # merge
        node_position = (self.tree_index,) + self.branch_data
        self.publish_message(f"start merging {len(inputs)} inputs of node {node_position}")
        self.merge(inputs, self.output())

        # remove intermediate nodes
        if not self.is_leaf() and not self.keep_nodes:
            msg = f"removing intermediate results of node {node_position}"
            with self.publish_step(msg):
                for inp in flatten(inputs):
                    inp.remove()

        return None

    def control_output_postfix(self) -> str:
        """
        Returns the postfix of the parent class, extended by an underscore and *postfix_format*
        filled with the tree index and depth of this task.
        """
        postfix = super().control_output_postfix()
        return ("{pf}_" + self.postfix_format).format(
            pf=postfix,
            tree=self.tree_index,
            depth=self.tree_depth,
        )