"""
Functions that are invoked by interactive task methods.
"""
from __future__ import annotations
__all__ = [
"fetch_task_output",
"print_task_deps",
"print_task_output",
"print_task_status",
"remove_task_output",
]
import os
import pathlib
import re
from law._types import Any, Iterator
from law.config import Config
from law.logger import get_logger
from law.target.base import Target
from law.target.collection import FileCollection, TargetCollection
from law.target.file import FileSystemTarget
from law.task.base import ExternalTask, Task
from law.util import (
colored,
flag_to_bool,
flatten,
get_terminal_width,
human_bytes,
is_lazy_iterable,
make_list,
makedirs,
merge_dicts,
multi_match,
query_choice,
uncolor_cre,
uncolored,
)
logger = get_logger(__name__)
# formatting characters used to draw the dependency tree, keyed by the format name that the
# functions below read from the "interactive_format" option of the "task" config section
# (unknown names fall back to "fancy"); the entries of each format are used as follows:
#   - "ind": number of characters per indentation level
#   - "free": number of spacer lines printed before each dependency with a depth > 0
#   - "-": horizontal connector, repeated "ind" times after the branch character
#   - "t": branch character of a dependency that is not the last one among its siblings
#   - "l": branch character of the last dependency among its siblings
#   - "|": vertical connector
#   - ">": separator printed between the depth number and the task representation
fmt_chars: dict[str, dict[str, str | int]] = {
    "plain": {
        "ind": 2,
        "free": 1,
        "-": "-",
        "t": "+",
        "l": "+",
        "|": "|",
        ">": ">",
    },
    "fancy": {
        "ind": 2,
        "free": 1,
        "-": "─",
        "t": "├",
        "l": "└",
        "|": "│",
        ">": ">",
    },
}
# compact variants reuse the characters above with "free" set to 0, i.e., without spacer lines
# (presumably merge_dicts lets values of later dicts take precedence - TODO confirm)
fmt_chars["compact"] = merge_dicts(fmt_chars["plain"], {"free": 0})
fmt_chars["fancy_compact"] = merge_dicts(fmt_chars["fancy"], {"free": 0})
def _flatten_output(output: Any, depth: int) -> list[tuple[Target, int, str]]:
    """
    Expands one level of an arbitrarily structured *output* and returns a list of 3-tuples
    (item, *depth*, prefix). Items of sequences, sets and lazy iterables are prefixed with their
    position (``"0: "``, ``"1: "``, ...), values of dictionaries with their key, and anything else
    is passed through :py:func:`flatten` and gets an empty prefix.
    """
    if isinstance(output, (list, tuple, set)) or is_lazy_iterable(output):
        # positional containers, labelled by index
        labelled = ((f"{index}: ", item) for index, item in enumerate(output))
    elif isinstance(output, dict):
        # mappings, labelled by key
        labelled = ((f"{key}: ", item) for key, item in output.items())
    else:
        # everything else carries no label
        labelled = (("", item) for item in flatten(output))

    return [(item, depth, prefix) for prefix, item in labelled]
def _iter_output(
    output: Any,
    offset: str,
    ind: str = " ",
) -> Iterator[tuple[Target, int, str, str, list]]:
    """
    Generator that walks through the structure of *output* and yields a 5-tuple
    (target, depth, prefix, line offset, pending list) for each :py:class:`Target` found. The line
    offset is *offset* extended by *ind* once per depth level.

    Items that are not targets are expanded by one more level. The key of such an intermediate
    structure is printed directly, and items that cannot be expanded any further are reported as
    "not a target". The yielded pending list is the very list this generator consumes, so callers
    can prepend entries to it in order to have them visited next.
    """
    pending = _flatten_output(output, 0)

    while pending:
        item, item_depth, item_prefix = pending.pop(0)
        item_offset = offset + item_depth * ind

        # targets are handed to the caller, together with the mutable pending list
        if isinstance(item, Target):
            yield item, item_depth, item_prefix, item_offset, pending
            continue

        # expand by one level; when the expansion just returns the item itself, it is neither a
        # target nor a structure that could be traversed any further
        children = _flatten_output(item, item_depth + 1)
        if children and children[0][0] == item:
            print(item_offset + item_prefix + colored("not a target", color="red"))
            continue

        # print the key of the current structure and visit its children next
        print(item_offset + item_prefix)
        pending[:0] = children
def _print_wrapped(line: str, width: int | None, offset: str = "") -> None:
# when the width is not set or the line is empty, just print the line
if not line or width is None or width <= 0:
print(line)
return
# split into actual strings to print (even parts) and color/style modifiers (odd parts) for
# proper width computation
parts = [(part, i % 2 == 1) for i, part in enumerate(uncolor_cre.split(line))]
# build lines with odd parts until the line is filled
line, length, last_style = "", 0, ""
while parts:
part, is_style = parts.pop(0)
if is_style:
# style modifier
line += part
last_style = part
elif length + len(part) <= width:
# actual string that still fits
line += part
length += len(part)
else:
# actual string that would overflow the line, so add the characters that would still fit
# and then print the line
n = width - length
line += part[:n]
print(line)
# add the remaining characters with an uncolored offset and reset the state
parts[:0] = [
("\033[0m", True),
(uncolored(offset), False),
(last_style, True),
(part[n:], False),
]
line, length, last_style = "", 0, ""
# print any leftover line
if line:
print(line)
def _parse_stopping_condition(condition: int | str) -> tuple[int, list[str]]:
# returns a maximum depth value and a list of task family patterns:
# - a depth of -1 is returned in case no depth could be extracted
# - an empty list for the patterns is returned in case no patterns could be extracted
if isinstance(condition, int):
return condition, []
max_depth = -1
family_patterns = []
for part in str(condition).strip().split("|"):
part = part.strip()
if part.lstrip("-").isdigit():
max_depth = int(part)
else:
family_patterns.append(part)
return max_depth, family_patterns
[docs]
def print_task_deps(task: Task, stopping_condition: int | str = 1) -> None:
    """
    Prints the dependency tree of *task*. *stopping_condition* is parsed into a maximum depth and
    task family patterns; the walk stops descending at dependencies whose task family matches one
    of the patterns. Formatting characters and line wrapping are read from the "task" section of
    the law config.
    """
    max_depth, family_patterns = _parse_stopping_condition(stopping_condition)

    # verbose header
    details = []
    if max_depth >= 0 or not family_patterns:
        details.append(f"with max_depth {max_depth}")
    if family_patterns:
        details.append(f"up to task families '{','.join(family_patterns)}'")
    print(f"print task dependencies {' and '.join(details)}\n")

    # formatting characters, with a few frequently used pieces extracted once
    cfg = Config.instance()
    fmt: dict[str, str | int] = fmt_chars.get(
        cfg.get_expanded("task", "interactive_format"),
        fmt_chars["fancy"],
    )
    ind = int(fmt["ind"])
    pad = ind * " "
    bar = fmt["|"]
    free_count = int(fmt["free"])

    # line break setting, a width of None disables wrapping
    break_lines = cfg.get_expanded_bool("task", "interactive_line_breaks")
    out_width = cfg.get_expanded_int("task", "interactive_line_width")
    if not break_lines:
        print_width = None
    elif out_width > 0:
        print_width = out_width
    else:
        print_width = get_terminal_width()

    def _emit(line: str, cont_offset: str) -> None:
        _print_wrapped(line, print_width, cont_offset)

    # "is last sibling" flags of the current dependency and all of its parents
    last_flags: list[bool] = []

    walker = task.walk_deps(max_depth=max_depth, order="pre", yield_last_flag=True)
    for dep, next_deps, depth, is_last in walker:  # type: ignore[misc]
        # stop descending at matching task families
        if family_patterns and multi_match(dep.task_family, family_patterns, mode=any):
            next_deps.clear()

        # forget flags of deeper levels that were visited before
        del last_flags[depth:]
        has_visible_children = bool(next_deps) and (max_depth < 0 or depth < max_depth)

        # common offset built from the parents, where finished branches leave a gap
        offset = "".join(f"{' ' if flag else bar}{pad}" for flag in last_flags[1:])
        last_flags.append(is_last)

        # spacer lines in front of nested dependencies
        nested = depth > 0
        free_block = "\n".join(free_count * [f"{offset}{bar}"])
        if nested and free_block:
            print(free_block)

        # offsets of the task line and of wrapped text below it
        task_prefix = f"{depth} {fmt['>']} "
        if nested:
            branch = fmt["l"] if is_last else fmt["t"]
            task_offset = f"{offset}{branch}{ind * fmt['-']}"
            text_offset = f"{offset}{' ' if is_last else bar}{pad}"
        else:
            task_offset = text_offset = offset
        text_offset += f"{bar if has_visible_children else ' '}{(len(task_prefix) - 1) * ' '}"

        # the task line itself
        _emit(task_offset + task_prefix + dep.repr(color=True), text_offset)  # type: ignore[union-attr]
[docs]
def print_task_status(
    task: Task,
    stopping_condition: int | str = 0,
    target_depth: int = 0,
    flags: str | None = None,
) -> None:
    """
    Prints the dependency tree of *task* together with the status text of all output targets.
    *stopping_condition* is parsed into a maximum depth and task family patterns at which the walk
    stops descending. *target_depth* and the lower-cased, "-"-separated *flags* are forwarded to the
    ``status_text`` method of each target. Formatting characters, line wrapping and whether
    dependencies that were seen before are skipped are read from the "task" section of the law
    config.
    """
    from law.workflow.base import BaseWorkflow

    max_depth, family_patterns = _parse_stopping_condition(stopping_condition)
    target_depth = int(target_depth)
    flags_tuple: tuple[str, ...] = ()
    if flags:
        flags_tuple = tuple(flags.lower().split("-"))

    # verbose header
    details = []
    if max_depth >= 0 or not family_patterns:
        details.append(f"with max_depth {max_depth}")
    if family_patterns:
        details.append(f"up to task families '{','.join(family_patterns)}'")
    print(f"print task status {' and '.join(details)} and target_depth {target_depth}\n")

    # formatting characters, with a few frequently used pieces extracted once
    cfg = Config.instance()
    fmt = fmt_chars.get(cfg.get_expanded("task", "interactive_format"), fmt_chars["fancy"])
    ind = int(fmt["ind"])
    pad = ind * " "
    bar = fmt["|"]
    free_count = int(fmt["free"])

    # line break setting, a width of None disables wrapping
    break_lines = cfg.get_expanded_bool("task", "interactive_line_breaks")
    out_width = cfg.get_expanded_int("task", "interactive_line_width")
    if not break_lines:
        print_width = None
    elif out_width > 0:
        print_width = out_width
    else:
        print_width = get_terminal_width()

    def _emit(line: str, cont_offset: str) -> None:
        _print_wrapped(line, print_width, cont_offset)

    # other settings
    skip_seen = cfg.get_expanded_bool("task", "interactive_status_skip_seen")

    # pattern separating the leading whitespace of a status line from its content
    leading_ws_cre = re.compile(r"^(\s*)(.*)$")

    seen = set()
    # "is last sibling" flags of the current dependency and all of its parents
    last_flags: list[bool] = []

    walker = task.walk_deps(max_depth=max_depth, order="pre", yield_last_flag=True)
    for dep, next_deps, depth, is_last in walker:  # type: ignore[misc]
        # stop descending at matching task families
        if family_patterns and multi_match(dep.task_family, family_patterns, mode=any):
            next_deps.clear()

        # forget flags of deeper levels that were visited before
        del last_flags[depth:]
        has_visible_children = bool(next_deps) and (max_depth < 0 or depth < max_depth)

        # common offset built from the parents, where finished branches leave a gap
        offset = "".join(f"{' ' if flag else bar}{pad}" for flag in last_flags[1:])
        last_flags.append(is_last)

        # spacer lines in front of nested dependencies
        nested = depth > 0
        free_block = "\n".join(free_count * [f"{offset}{bar}"])
        if nested and free_block:
            print(free_block)

        # when the dep is a workflow, independent of its create_branch_map_before_repr setting,
        # preload its branch map which updates branch parameters
        if isinstance(dep, BaseWorkflow):
            dep.get_branch_map()

        # offsets of the task line and of text below it
        task_prefix = f"{depth} {fmt['>']} "
        if nested:
            branch = fmt["l"] if is_last else fmt["t"]
            task_offset = f"{offset}{branch}{ind * fmt['-']}"
            text_offset = f"{offset}{' ' if is_last else bar}{pad}"
        else:
            task_offset = text_offset = offset
        text_offset += f"{bar if has_visible_children else ' '}{(len(task_prefix) - 1) * ' '}"
        text_offset_ind = text_offset + pad

        # the task line itself
        _emit(task_offset + task_prefix + dep.repr(color=True), text_offset)  # type: ignore[union-attr]

        # outputs of dependencies that were seen before are optionally not checked again
        if skip_seen and dep in seen:
            _emit(text_offset_ind + colored("outputs already checked", "yellow"), text_offset_ind)
            continue
        seen.add(dep)

        # traverse the output structure
        for output, _, oprefix, ooffset, _ in _iter_output(dep.output(), text_offset_ind, pad):
            _emit(ooffset + oprefix + output.repr(color=True), ooffset + len(oprefix) * " ")

            # the status is printed one level below the target, line by line
            ooffset += pad
            first_line, *more_lines = output.status_text(
                max_depth=target_depth,
                flags=flags_tuple,
                color=True,
            ).split("\n")
            _emit(ooffset + first_line, ooffset)
            for status_line in more_lines:
                # wrapped parts keep the own indentation of the line plus one extra level
                m = leading_ws_cre.match(status_line)
                line_indent = m.group(1) if m is not None else ""
                _emit(ooffset + status_line, ooffset + line_indent + pad)
[docs]
def print_task_output(task: Task, stopping_condition: int | str = 0, scheme: bool = True) -> None:
    """
    Prints the uris of all outputs of *task* and its dependencies, one per line and each uri only
    once. *stopping_condition* is parsed into a maximum depth and task family patterns at which the
    walk stops descending. *scheme* is converted with :py:func:`flag_to_bool` and forwarded to the
    ``uri`` method of file system targets and file collections.
    """
    max_depth, family_patterns = _parse_stopping_condition(stopping_condition)
    scheme = flag_to_bool(scheme)  # type: ignore[assignment]

    # verbose header
    details = []
    if max_depth >= 0 or not family_patterns:
        details.append(f"with max_depth {max_depth}")
    if family_patterns:
        details.append(f"up to task families '{','.join(family_patterns)}'")
    scheme_str = "showing" if scheme else "hiding"
    print(f"print task output {' and '.join(details)}, {scheme_str} schemes\n")

    seen_deps = set()
    seen_uris = set()

    for dep, next_deps, _ in task.walk_deps(max_depth=max_depth, order="pre"):  # type: ignore[misc]
        # handle every dependency only once
        if dep in seen_deps:
            continue
        seen_deps.add(dep)

        # stop descending at matching task families
        if family_patterns and multi_match(dep.task_family, family_patterns, mode=any):
            next_deps.clear()

        for outp in flatten(dep.output()):
            # only file based targets get the scheme setting passed
            takes_scheme = isinstance(outp, (FileSystemTarget, FileCollection))
            uri_kwargs = {"scheme": scheme} if takes_scheme else {}

            for uri in make_list(outp.uri(**uri_kwargs)):
                if uri not in seen_uris:
                    seen_uris.add(uri)
                    print(uri)
[docs]
def remove_task_output(
    task: Task,
    stopping_condition: int | str = 0,
    mode: str | None = None,
    run_task: bool = False,
) -> bool:
    """
    Walks through the dependency tree of *task*, prints it, and removes output targets depending on
    *mode*. *stopping_condition* is parsed into a maximum depth and task family patterns at which
    the walk stops descending.

    *mode* can be ``"i"`` (interactive, query per task and per output), ``"d"`` (dry, only print
    what would be removed) or ``"a"`` (all, remove without queries but skip tasks whose
    ``skip_output_removal`` attribute is set). When empty, the mode is queried, and an exception is
    raised for any other value. External tasks and targets with a true ``external`` attribute are
    never touched.

    *run_task* is converted with :py:func:`flag_to_bool` and returned.
    """
    from law.workflow.base import BaseWorkflow

    # parse the stopping condition
    max_depth, family_patterns = _parse_stopping_condition(stopping_condition)

    # show a verbose message
    msg = []
    if max_depth >= 0 or not family_patterns:
        msg.append(f"with max_depth {max_depth}")
    if family_patterns:
        msg.append(f"up to task families '{','.join(family_patterns)}'")
    print(f"remove task output {' and '.join(msg)}")

    run_task = flag_to_bool(run_task)  # type: ignore[assignment]
    if run_task:
        print("task will run after output removal")
    print("")

    # get the format chars
    cfg = Config.instance()
    fmt_name = cfg.get_expanded("task", "interactive_format")
    fmt = fmt_chars.get(fmt_name, fmt_chars["fancy"])
    local_sync = cfg.get_expanded_bool("target", "interactive_removal_local_sync")

    # get the line break setting; the width is kept in a one-element list so that the nested
    # functions below can update it in place
    break_lines = cfg.get_expanded_bool("task", "interactive_line_breaks")
    out_width = cfg.get_expanded_int("task", "interactive_line_width")
    print_width = [(out_width if out_width > 0 else get_terminal_width()) if break_lines else None]
    _print = lambda line, offset: _print_wrapped(line, print_width[0], offset)

    # custom query_choice function that updates the terminal_width before each query, but only
    # when line wrapping is enabled
    def _query_choice(*args, **kwargs) -> str:
        if print_width[0]:
            print_width[0] = out_width if out_width > 0 else get_terminal_width()
        return query_choice(*args, **kwargs)

    # determine the mode, i.e., interactive, dry, all
    modes = ["i", "d", "a"]
    mode_names = ["interactive", "dry", "all"]
    if mode and mode not in modes:
        raise Exception(f"unknown removal mode '{mode}'")
    if not mode:
        mode = _query_choice("removal mode?", modes, default="i", descriptions=mode_names)
    mode_name = mode_names[modes.index(mode)]
    print(f"selected {colored(mode_name, 'blue', style='bright')} mode")
    print("")

    # tasks that were handled already
    done = set()
    # "is last sibling" flags of the current dependency and all of its parents
    parents_last_flags: list[bool] = []
    for dep, next_deps, depth, is_last in task.walk_deps(  # type: ignore[misc]
        max_depth=max_depth,
        order="pre",
        yield_last_flag=True,
    ):
        # stop descending at matching task families
        if family_patterns and multi_match(dep.task_family, family_patterns, mode=any):
            next_deps.clear()

        # forget flags of deeper levels that were visited before
        del parents_last_flags[depth:]
        next_deps_shown = bool(next_deps) and (max_depth < 0 or depth < max_depth)

        # determine the print common offset
        offset = "".join([
            f"{' ' if f else fmt['|']}{' ' * int(fmt['ind'])}"
            for f in parents_last_flags[1:]
        ])
        parents_last_flags.append(is_last)

        # print free space
        free_offset = f"{offset}{fmt['|']}"
        free_lines = "\n".join(int(fmt["free"]) * [free_offset])
        if depth > 0 and free_lines:
            print(free_lines)

        # when the dep is a workflow, independent of its create_branch_map_before_repr setting,
        # preload its branch map which updates branch parameters
        if isinstance(dep, BaseWorkflow):
            dep.get_branch_map()

        # determine task offset and prefix
        task_offset = offset
        if depth > 0:
            task_offset += f"{fmt['l' if is_last else 't']}{int(fmt['ind']) * fmt['-']}"
        task_prefix = f"{depth} {fmt['>']} "

        # determine text offset and prefix
        text_offset = offset
        if depth > 0:
            text_offset += f"{' ' if is_last else fmt['|']}{int(fmt['ind']) * ' '}"
        text_prefix = (len(task_prefix) - 1) * " "
        text_offset += f"{fmt['|'] if next_deps_shown else ' '}{text_prefix}"
        text_offset_ind = text_offset + int(fmt["ind"]) * " "

        # print the task line
        _print(task_offset + task_prefix + dep.repr(color=True), text_offset)  # type: ignore[union-attr]

        # always skip external tasks
        if isinstance(dep, ExternalTask):
            _print(text_offset_ind + colored("task is external", "yellow"), text_offset_ind)
            continue

        # skip when this task was already handled
        if dep in done:
            _print(text_offset_ind + colored("already handled", "yellow"), text_offset_ind)
            continue
        done.add(dep)

        # skip when mode is "all" and task is configured to skip
        if mode == "a" and getattr(dep, "skip_output_removal", False):
            _print(text_offset_ind + colored("configured to skip", "yellow"), text_offset_ind)
            continue

        # query for a decision per task when mode is "interactive"; "n" skips the task without
        # touching its outputs, "a" removes all of its outputs without further queries
        task_mode = None
        if mode == "i":
            task_mode = _query_choice(
                text_offset_ind + "remove outputs?",
                ["y", "n", "a"],
                default="y",
                descriptions=["yes", "no", "all"],
            )
            if task_mode == "n":
                continue

        # start the traversing through output structure
        for output, odepth, oprefix, ooffset, lookup in _iter_output(
            dep.output(),
            text_offset_ind,
            int(fmt["ind"]) * " ",
        ):
            _print(ooffset + oprefix + output.repr(color=True), ooffset + len(oprefix) * " ")
            ooffset += int(fmt["ind"]) * " "

            # skip external targets
            if getattr(output, "external", False):
                _print(ooffset + colored("external output", "yellow"), ooffset)
                continue

            # stop here when in dry mode
            if mode == "d":
                _print(ooffset + colored("dry removed", "yellow"), ooffset)
                continue

            # when the mode is "interactive" and the task decision is not "all", query per output
            if mode == "i" and task_mode != "a":
                if isinstance(output, TargetCollection):
                    coll_choice = _query_choice(
                        ooffset + "remove?",
                        ["y", "n", "i"],
                        default="n",
                        descriptions=["yes", "no", "interactive"],
                    )
                    if coll_choice == "i":
                        # prepend the targets of the collection to the lookup list of the running
                        # iteration so that they are visited, and queried, one by one next
                        lookup[:0] = _flatten_output(output.targets, odepth + 1)
                        continue
                    target_choice = coll_choice
                else:
                    target_choice = _query_choice(
                        ooffset + "remove?",
                        ["y", "n"],
                        default="n",
                        descriptions=["yes", "no"],
                    )
                if target_choice == "n":
                    _print(ooffset + colored("skipped", "yellow"), ooffset)
                    continue

            # finally remove
            output.remove(local_sync=local_sync)  # type: ignore[call-arg]
            _print(ooffset + colored("removed", "red", style="bright"), ooffset)

    return run_task
[docs]
def fetch_task_output(
    task: Task,
    stopping_condition: int | str = 0,
    mode: str | int | None = None,
    target_dir: str | pathlib.Path = ".",
    unique_names: bool = True,
    include_external: bool = False,
) -> None:
    """
    Walks through the dependency tree of *task*, prints it, and copies existing output targets
    into *target_dir*, which is created first. *stopping_condition* is parsed into a maximum depth
    and task family patterns at which the walk stops descending.

    *mode* can be ``"i"`` (interactive, query per task and per output), ``"a"`` (all, fetch without
    queries) or ``"d"`` (dry, only print what would be fetched). Only the first, lower-cased
    character of a string is considered, an integer is interpreted as an index into this list of
    modes, and *None* or an empty string triggers a query. An exception is raised for any other
    value.

    When *unique_names* is set, the basename of each fetched file is prefixed by the
    ``live_task_id`` of its task. Unless *include_external* is set, external tasks and targets with
    a true ``external`` attribute are skipped. Both flags are converted with
    :py:func:`flag_to_bool`. Targets that are missing or that have no ``copy_to_local`` method are
    skipped as well.
    """
    from law.workflow.base import BaseWorkflow

    # parse the stopping condition
    max_depth, family_patterns = _parse_stopping_condition(stopping_condition)

    # show a verbose message
    msg = []
    if max_depth >= 0 or not family_patterns:
        msg.append(f"with max_depth {max_depth}")
    if family_patterns:
        msg.append(f"up to task families '{','.join(family_patterns)}'")
    print(f"fetch task output {' and '.join(msg)}")

    target_dir = os.path.normpath(os.path.abspath(str(target_dir)))
    print(f"target directory is {target_dir}")
    makedirs(target_dir)

    include_external = flag_to_bool(include_external)  # type: ignore[assignment]
    if include_external:
        print("include external tasks")
    # convert like the other boolean flags so that string values such as "False" are respected
    unique_names = flag_to_bool(unique_names)  # type: ignore[assignment]
    print("")

    # get the format chars
    cfg = Config.instance()
    fmt_name = cfg.get_expanded("task", "interactive_format")
    fmt = fmt_chars.get(fmt_name, fmt_chars["fancy"])

    # get the line break setting; the width is kept in a one-element list so that the nested
    # functions below can update it in place
    break_lines = cfg.get_expanded_bool("task", "interactive_line_breaks")
    out_width = cfg.get_expanded_int("task", "interactive_line_width")
    print_width = [(out_width if out_width > 0 else get_terminal_width()) if break_lines else None]
    _print = lambda line, offset: _print_wrapped(line, print_width[0], offset)

    # custom query_choice function that updates the terminal_width before each query, but only
    # when line wrapping is enabled
    def _query_choice(*args, **kwargs) -> str:
        if print_width[0]:
            print_width[0] = out_width if out_width > 0 else get_terminal_width()
        return query_choice(*args, **kwargs)

    # determine the mode, i.e., all, dry, interactive
    modes = ["i", "a", "d"]
    mode_names = ["interactive", "all", "dry"]
    if mode is None or mode == "":
        # nothing selected yet (an empty string cannot be indexed below), so ask
        mode = _query_choice("fetch mode?", modes, default="i", descriptions=mode_names)
    elif isinstance(mode, int):
        # integers select by position, report invalid positions like any other unknown mode
        try:
            mode = modes[mode]
        except IndexError:
            raise Exception(f"unknown fetch mode '{mode}'") from None
    else:
        mode = mode[0].lower()
    if mode not in modes:
        raise Exception(f"unknown fetch mode '{mode}'")
    mode_name = mode_names[modes.index(mode)]
    print(f"selected {colored(mode_name, 'blue', style='bright')} mode")
    print("")

    # tasks whose outputs were handled already
    done = set()
    # "is last sibling" flags of the current dependency and all of its parents
    parents_last_flags: list[bool] = []
    for dep, next_deps, depth, is_last in task.walk_deps(  # type: ignore[misc]
        max_depth=max_depth,
        order="pre",
        yield_last_flag=True,
    ):
        # stop descending at matching task families
        if family_patterns and multi_match(dep.task_family, family_patterns, mode=any):
            next_deps.clear()

        # forget flags of deeper levels that were visited before
        del parents_last_flags[depth:]
        next_deps_shown = bool(next_deps) and (max_depth < 0 or depth < max_depth)

        # determine the print common offset
        offset = "".join([
            f"{' ' if f else fmt['|']}{' ' * int(fmt['ind'])}"
            for f in parents_last_flags[1:]
        ])
        parents_last_flags.append(is_last)

        # print free space
        free_offset = f"{offset}{fmt['|']}"
        free_lines = "\n".join(int(fmt["free"]) * [free_offset])
        if depth > 0 and free_lines:
            print(free_lines)

        # when the dep is a workflow, independent of its create_branch_map_before_repr setting,
        # preload its branch map which updates branch parameters
        if isinstance(dep, BaseWorkflow):
            dep.get_branch_map()

        # determine task offset and prefix
        task_offset = offset
        if depth > 0:
            task_offset += f"{fmt['l' if is_last else 't']}{int(fmt['ind']) * fmt['-']}"
        task_prefix = f"{depth} {fmt['>']} "

        # determine text offset and prefix
        text_offset = offset
        if depth > 0:
            text_offset += f"{' ' if is_last else fmt['|']}{int(fmt['ind']) * ' '}"
        text_prefix = (len(task_prefix) - 1) * " "
        text_offset += f"{fmt['|'] if next_deps_shown else ' '}{text_prefix}"
        text_offset_ind = text_offset + int(fmt["ind"]) * " "

        # print the task line
        _print(task_offset + task_prefix + dep.repr(color=True), text_offset)  # type: ignore[union-attr]

        # skip external tasks unless requested otherwise
        if not include_external and isinstance(dep, ExternalTask):
            _print(text_offset_ind + colored("task is external", "yellow"), text_offset_ind)
            continue

        # skip when the outputs of this task were already handled
        if dep in done:
            _print(text_offset_ind + colored("outputs already fetched", "yellow"), text_offset_ind)
            continue

        # query for a decision per task when mode is "interactive"; a declined task is not marked
        # as done, so it is queried again in case it shows up another time
        task_mode = None
        if mode == "i":
            task_mode = _query_choice(
                text_offset_ind + "fetch outputs?",
                ["y", "n", "a"],
                default="y",
                descriptions=["yes", "no", "all"],
            )
            if task_mode == "n":
                _print(text_offset_ind + colored("skipped", "yellow"), text_offset_ind)
                continue
        done.add(dep)

        # start the traversing through output structure with a lookup pattern
        for output, odepth, oprefix, ooffset, lookup in _iter_output(
            dep.output(),
            text_offset_ind,
            int(fmt["ind"]) * " ",
        ):
            # best-effort stat: any failure is treated like a target without stat information,
            # which is reported as not existing below unless it is a collection
            try:
                stat = output.stat()  # type: ignore[attr-defined]
            except Exception:
                stat = None

            # print the target repr, followed by the size when known
            target_line = ooffset + oprefix + output.repr(color=True)
            if stat:
                target_line += " ({:.2f} {})".format(*human_bytes(stat.st_size))
            _print(target_line, ooffset + len(oprefix) * " ")
            ooffset += int(fmt["ind"]) * " "

            # skip external targets
            if not include_external and getattr(output, "external", False):
                _print(ooffset + colored("external output, skip", "yellow"), ooffset)
                continue

            # skip missing targets
            if stat is None and not isinstance(output, TargetCollection):
                _print(ooffset + colored("not existing, skip", "yellow"), ooffset)
                continue

            # skip targets without a copy_to_local method
            is_copyable = callable(getattr(output, "copy_to_local", None))
            if not is_copyable and not isinstance(output, TargetCollection):
                _print(ooffset + colored("not a file target, skip", "yellow"), ooffset)
                continue

            # stop here when in dry mode
            if mode == "d":
                _print(ooffset + colored("dry fetched", "yellow"), ooffset)
                continue

            # collect actual outputs to fetch
            to_fetch = [output]
            if mode == "i" and task_mode != "a":
                if isinstance(output, TargetCollection):
                    coll_choice = _query_choice(
                        ooffset + "fetch?",
                        ["y", "n", "i"],
                        default="y",
                        descriptions=["yes", "no", "interactive"],
                    )
                    if coll_choice == "i":
                        # prepend the targets of the collection to the lookup list of the running
                        # iteration so that they are visited, and queried, one by one next
                        lookup[:0] = _flatten_output(output.targets, odepth + 1)
                        continue
                    target_choice = coll_choice
                    to_fetch = list(output._flat_target_list)
                else:
                    target_choice = _query_choice(
                        ooffset + "fetch?",
                        ["y", "n"],
                        default="y",
                        descriptions=["yes", "no"],
                    )
                if target_choice == "n":
                    _print(ooffset + colored("skipped", "yellow"), ooffset)
                    continue

            # flatten all target collections, preserving the order of their targets
            to_fetch_flat = []
            while to_fetch:
                t = to_fetch.pop(0)
                if isinstance(t, TargetCollection):
                    to_fetch[:0] = list(t._flat_target_list)
                else:
                    to_fetch_flat.append(t)

            # actual copy
            for outp in to_fetch_flat:
                # targets within collections might not be copyable either
                if not callable(getattr(outp, "copy_to_local", None)):
                    continue

                # define the basename
                basename: str = outp.basename  # type: ignore[attr-defined]
                if unique_names:
                    basename = f"{dep.live_task_id}__{basename}"

                # copy and log
                outp.copy_to_local(os.path.join(target_dir, basename), retries=0)  # type: ignore[attr-defined]
                _print(
                    ooffset + f"{colored('fetched', 'green', style='bright')} ({basename})",
                    ooffset,
                )