# coding: utf-8
"""
Custom luigi base task definitions.
"""
__all__ = ["Task", "WrapperTask", "ExternalTask"]
import sys
import socket
import logging
from collections import OrderedDict
from contextlib import contextmanager
from abc import ABCMeta, abstractmethod
import inspect
import luigi
import six
from law.config import Config
from law.parameter import NO_STR, CSVParameter
from law.target.file import localize_file_targets
from law.parser import root_task, global_cmdline_values
from law.logger import setup_logger
from law.util import (
no_value, abort, law_run, common_task_params, colored, uncolored, make_list, multi_match,
flatten, BaseStream, human_duration, patch_object, round_discrete, empty_context, perf_counter,
)
from law.logger import get_logger
logger = get_logger(__name__)
getfullargspec = inspect.getfullargspec if six.PY3 else inspect.getargspec
[docs]class BaseRegister(luigi.task_register.Register):
def __new__(metacls, classname, bases, classdict):
# default attributes, irrespective of inheritance
classdict.setdefault("exclude_index", False)
# unite "exclude_params_*" sets with those of all base classes
for base in bases:
for attr, base_params in vars(base).items():
if attr.startswith("exclude_params_") and isinstance(base_params, set):
params = classdict.setdefault(attr, set())
if isinstance(params, set):
params.update(base_params)
# remove those parameter names from "exclude_params_*" sets which are explicitly
# listed in corresponding "include_params_*" sets defined on the class itself
for attr, include_params in classdict.items():
if attr.startswith("include_params_") and isinstance(include_params, set):
exclude_attr = "exclude" + attr[len("include"):]
if exclude_attr in classdict and isinstance(classdict[exclude_attr], set):
classdict[exclude_attr] -= include_params
# create the class
cls = ABCMeta.__new__(metacls, classname, bases, classdict)
# default attributes, apart from inheritance
if getattr(cls, "update_register", None) is None:
cls.update_register = False
# deregister when requested
if cls.update_register:
cls.deregister()
# add to register (mimic luigi.task_register.Register.__new__)
cls._namespace_at_class_time = metacls._get_namespace(cls.__module__)
metacls._reg.append(cls)
return cls
[docs]class BaseTask(six.with_metaclass(BaseRegister, luigi.Task)):
# whether to cache the result of requires() for input() and potentially also other calls
cache_requirements = False
exclude_index = True
exclude_params_index = set()
exclude_params_req = set()
exclude_params_req_set = set()
exclude_params_req_get = set()
prefer_params_cli = set()
@staticmethod
def resource_name(name, host=None):
if host is None:
host = socket.gethostname().partition(".")[0]
return "{}_{}".format(host, name)
[docs] @classmethod
def deregister(cls, task_cls=None):
"""
Removes a task class *task_cls* from the luigi task register. When *None*, *this* class is
used. Task family strings and patterns are accepted as well. *True* is returned when at
least one class was successfully removed, and *False* otherwise.
"""
# always compare task families
if task_cls is None:
task_family = cls.get_task_family()
elif isinstance(task_cls, six.string_types):
task_family = task_cls
else:
task_family = task_cls.get_task_family()
success = False
# remove from the register
i = -1
while True:
i += 1
if i >= len(Register._reg):
break
registered_cls = Register._reg[i]
if multi_match(registered_cls.get_task_family(), task_family, mode=any):
Register._reg.pop(i)
i -= 1
success = True
logger.debug("removed task class {} from register".format(registered_cls))
return success
[docs] @classmethod
def modify_param_args(cls, params, args, kwargs):
"""
Hook to modify command line arguments before they are event created within
:py:meth:`get_param_values`.
"""
return params, args, kwargs
[docs] @classmethod
def modify_param_values(cls, params):
"""
Hook to modify command line arguments before instances of this class are created.
"""
return params
[docs] @classmethod
def get_param_values(cls, params, args, kwargs):
# call the hook optionally modifying the values before values are assigned
params, args, kwargs = cls.modify_param_args(params, args, kwargs)
# assign to actual parameters
values = super(BaseTask, cls).get_param_values(params, args, kwargs)
# parse left-over strings in case values were given programmatically, but unencoded
values = [
(
name,
(getattr(cls, name).parse(value) if isinstance(value, six.string_types) else value),
)
for name, value in values
]
# call the hook optionally modifying the values afterwards,
# but remove temporary objects that might have been placed into it
param_names = {name for name, _ in params}
values = [
(name, value)
for name, value in cls.modify_param_values(OrderedDict(values)).items()
if name in param_names
]
return values
@classmethod
def req(cls, inst, **kwargs):
return cls(**cls.req_params(inst, **kwargs))
@classmethod
def req_params(cls, inst, _exclude=None, _prefer_cli=None, _skip_task_excludes=False,
_skip_task_excludes_get=None, _skip_task_excludes_set=None, **kwargs):
# common/intersection params
params = common_task_params(inst, cls)
# determine parameters to exclude
_exclude = set() if _exclude is None else set(make_list(_exclude))
# also use this class' req and req_get sets
# and the req and req_set sets of the instance's class
# unless explicitly skipped
if _skip_task_excludes_get is None:
_skip_task_excludes_get = _skip_task_excludes
if not _skip_task_excludes_get:
_exclude.update(cls.exclude_params_req, cls.exclude_params_req_get)
if _skip_task_excludes_set is None:
_skip_task_excludes_set = _skip_task_excludes
if not _skip_task_excludes_set:
_exclude.update(inst.exclude_params_req, inst.exclude_params_req_set)
# remove excluded parameters
for name in list(params.keys()):
if multi_match(name, _exclude, any):
del params[name]
# add kwargs
params.update(kwargs)
# remove params that are preferably set via cli class arguments
prefer_cli = set(cls.prefer_params_cli or ()) if _prefer_cli is None else set(_prefer_cli)
if prefer_cli:
cls_args = []
prefix = cls.get_task_family() + "_"
if luigi.cmdline_parser.CmdlineParser.get_instance():
for key in global_cmdline_values().keys():
if key.startswith(prefix):
cls_args.append(key[len(prefix):])
for name in prefer_cli:
if name in params and name in cls_args:
del params[name]
return params
def __init__(self, *args, **kwargs):
super(BaseTask, self).__init__(*args, **kwargs)
# task level logger, created lazily
self._task_logger = None
# attribute for cached requirements if enabled
self._cached_requirements = no_value
[docs] def complete(self):
# create a flat list of all outputs
outputs = flatten(self.output())
if len(outputs) == 0:
logger.warning("task {!r} has no outputs or no custom complete() method".format(self))
return True
return all(t.complete() for t in outputs)
[docs] @abstractmethod
def run(self):
return
def get_logger_name(self):
return self.task_id
def _create_logger(self, name, level=None):
return setup_logger(name, level=level)
@property
def logger(self):
if not self._task_logger:
name = self.get_logger_name()
existing = name in logging.root.manager.loggerDict
self._task_logger = logging.getLogger(name) if existing else self._create_logger(name)
return self._task_logger
@property
def live_task_id(self):
"""
The task id depends on the task family and parameters, and is generated by luigi once in the
constructor. As the latter may change, this property returns to the id with the current set
of parameters.
"""
# create a temporary dictionary of param_kwargs that is patched for the duration of the
# call to create the string representation of the parameters
param_kwargs = {attr: getattr(self, attr) for attr in self.param_kwargs}
# only_public was introduced in luigi 2.8.0, so check if that arg exists
str_params_kwargs = {"only_significant": True}
if "only_public" in getfullargspec(self.to_str_params).args:
str_params_kwargs["only_public"] = True
with patch_object(self, "param_kwargs", param_kwargs):
str_params = self.to_str_params(**str_params_kwargs)
# create the task id
task_id = luigi.task.task_id_str(self.get_task_family(), str_params)
return task_id
def walk_deps(self, max_depth=-1, order="level", yield_last_flag=False):
# see https://en.wikipedia.org/wiki/Tree_traversal
if order not in ("level", "pre"):
raise ValueError("unknown traversal order '{}', use 'level' or 'pre'".format(order))
# yielding the last flag as well is only available in 'pre' order
if order != "pre" and yield_last_flag:
raise ValueError("yield_last_flag can only be used in 'pre' order, but got '{}'".format(
order))
tasks = [(self, 0)]
while len(tasks):
task, depth = tasks.pop(0)
deps = flatten(task.requires())
next_depth = tasks[0][1] if tasks else None
# define the tuple of objects to yield
tpl = (task, deps, depth)
if not yield_last_flag:
yield tpl
# define the next deps, considering the maximum depth if set
deps = (
(d, depth + 1)
for d in deps
if max_depth < 0 or depth < max_depth
)
# add to the tasks run process, depending on the traversal order
if order == "level":
tasks[len(tasks):] = deps
elif order == "pre":
tasks[:0] = deps
# when an additional flag should be yielded that denotes whether the object is the last
# one in its depth, evaluate this decision here and then yield
# note: this assumes that the deps were not changed by the using context
if yield_last_flag:
is_last = next_depth is None or next_depth < depth
yield tpl + (is_last,)
def cli_args(self, exclude=None, replace=None, skip_empty_bools=True):
exclude = set() if exclude is None else set(make_list(exclude))
if replace is None:
replace = {}
args = OrderedDict()
for name, param in self.get_params():
# skip excluded parameters
if multi_match(name, exclude, any):
continue
# get the raw value
raw = replace.get(name, getattr(self, name))
# when the parameter is bool but its value is None, skip it
if skip_empty_bools and isinstance(param, luigi.BoolParameter) and raw is None:
continue
# serialize and add it
args["--" + name.replace("_", "-")] = str(param.serialize(raw))
return args
[docs]class Register(BaseRegister):
def __call__(cls, *args, **kwargs):
inst = super(Register, cls).__call__(*args, **kwargs)
# check for interactive parameters
for param in inst.interactive_params:
value = getattr(inst, param)
if value:
# reset the interactive parameter
setattr(inst, param, ())
# at this point, inst must be the root task so set the global value
root_task(inst)
skip_abort = False
try:
logger.debug("evaluating interactive parameter '{}' with value {}".format(
param, value))
skip_abort = getattr(inst, "_" + param)(value)
except KeyboardInterrupt:
print("\naborted")
# abort the process if not explicitly skipped
if not skip_abort:
abort(exitcode=0)
print("")
return inst
[docs]class Task(six.with_metaclass(Register, BaseTask)):
log_file = luigi.Parameter(
default=NO_STR,
significant=False,
description="a custom log file; default: <task.default_log_file>",
)
print_deps = CSVParameter(
default=(),
significant=False,
description="print task dependencies but do not run any task; this CSV parameter accepts a "
"single integer value which sets the task recursion depth (0 means non-recursive)",
)
print_status = CSVParameter(
default=(),
significant=False,
description="print the task status but do not run any task; this CSV parameter accepts up "
"to three values: 1. the task recursion depth (0 means non-recursive), 2. the depth of the "
"status text of target collections (default: 0), 3. a flag that is passed to the status "
"text creation (default: '')",
)
print_output = CSVParameter(
default=(),
significant=False,
description="print a flat list of output targets but do not run any task; this CSV "
"parameter accepts up to two values: 1. the task recursion depth (0 means non-recursive), "
"2. a boolean flag that decides whether paths of file targets should contain file system "
"schemes (default: True)",
)
remove_output = CSVParameter(
default=(),
significant=False,
description="remove task outputs but do not run any task by default; this CSV parameter "
"accepts up to three values: 1. the task recursion depth (0 means non-recursive), 2. one "
"of the modes 'i' (interactive), 'a' (all), 'd' (dry run) (default: 'i'), 3. a boolean "
"flag that decides whether the task is run after outputs were removed (default: False)",
)
fetch_output = CSVParameter(
default=(),
significant=False,
description="copy all task outputs into a local directory but do not run any task; this "
"CSV parameter accepts up to four values: 1. the task recursion depth (0 means "
"non-recursive), 2. one of the modes 'i' (interactive), 'a' (all), 'd' (dry run) (default: "
"'i'), 3. the target directory (default: '.'), 4. a boolean flag that decides whether "
"external outputs and outputs of external tasks should be fetched (default: False)",
)
interactive_params = [
"print_deps", "print_status", "print_output", "fetch_output", "remove_output",
]
# cache size for published messages
message_cache_size = 10
# force skipping this task when remove_output is set to "all" mode
skip_output_removal = False
exclude_index = True
exclude_params_req = set()
exclude_params_repr = set()
exclude_params_repr_empty = set()
@classmethod
def req_params(cls, inst, _exclude=None, _prefer_cli=None, **kwargs):
_exclude = set() if _exclude is None else set(make_list(_exclude))
# always exclude interactive parameters
_exclude |= set(inst.interactive_params)
return super(Task, cls).req_params(inst, _exclude=_exclude, _prefer_cli=_prefer_cli,
**kwargs)
def __init__(self, *args, **kwargs):
super(Task, self).__init__(*args, **kwargs)
# cache for messages published to the scheduler
self._message_cache = []
# cache for the last progress published to the scheduler
self._last_progress_percentage = None
@property
def default_log_file(self):
return "-"
def is_root_task(self):
return root_task() == self
def publish_message(self, msg, stdout=sys.stdout, scheduler=True, **kwargs):
msg = str(msg)
# write to stdout
if stdout:
stdout.write(msg + "\n")
stdout.flush()
# publish to the scheduler
if scheduler:
self._publish_message(msg, **kwargs)
def _publish_message(self, msg, flush_cache=False, silent=False):
msg = uncolored(str(msg))
# flush the message cache?
if flush_cache:
del self._message_cache[:]
# add to message cache and handle overflow
self._message_cache.append(msg)
if self.message_cache_size >= 0:
end = max(len(self._message_cache) - self.message_cache_size, 0)
del self._message_cache[:end]
# set status message based on the full, current message cache
if callable(getattr(self, "set_status_message", None)):
self.set_status_message("\n".join(self._message_cache))
elif not silent:
logger.warning("set_status_message not set, cannot send task message to scheduler")
def _create_message_stream(self, *args, **kwargs):
return TaskMessageStream(self, *args, **kwargs)
def _create_logger(self, name, level=None, **kwargs):
return setup_logger(name, level=level, add_console_handler={
"handler_kwargs": {"stream": self._create_message_stream(**kwargs)},
})
@contextmanager
def publish_step(self, msg, success_message="done", fail_message="failed", runtime=True,
scheduler=True, flush_cache=False):
self.publish_message(msg, scheduler=scheduler, flush_cache=flush_cache)
success = False
t0 = perf_counter()
try:
yield
success = True
finally:
msg = success_message if success else fail_message
if runtime:
diff = perf_counter() - t0
msg = "{} (took {})".format(msg, human_duration(seconds=diff))
self.publish_message(msg, scheduler=scheduler, flush_cache=flush_cache)
def publish_progress(self, percentage, precision=1):
percentage = int(round_discrete(percentage, precision, "floor"))
if percentage != self._last_progress_percentage:
self._last_progress_percentage = percentage
if callable(getattr(self, "set_progress_percentage", None)):
self.set_progress_percentage(percentage)
else:
logger.warning("set_progress_percentage not set, cannot send task progress to "
"scheduler")
def create_progress_callback(self, n_total, reach=(0, 100), precision=1):
def make_callback(n, start, end):
def callback(i):
self.publish_progress(start + (i + 1) / float(n) * (end - start), precision)
return callback
if isinstance(n_total, (list, tuple)):
width = 100.0 / len(n_total)
reaches = [(width * i, width * (i + 1)) for i in range(len(n_total))]
return n_total.__class__(make_callback(n, *r) for n, r in zip(n_total, reaches))
else:
return make_callback(n_total, *reach)
def iter_progress(self, iterable, n_total, reach=(0, 100), precision=1, msg=None):
# create a progress callback with all arguments
progress_callback = self.create_progress_callback(n_total, reach=reach, precision=precision)
# when msg is set, place the iteration in an outer context
context = (lambda: self.publish_step(msg)) if msg else empty_context
# iterate and invoke the callback
with context():
for i, val in enumerate(iterable):
yield val
progress_callback(i)
def cli_args(self, exclude=None, replace=None):
exclude = set() if exclude is None else set(make_list(exclude))
# always exclude interactive parameters
exclude |= set(self.interactive_params)
return super(Task, self).cli_args(exclude=exclude, replace=replace)
def __repr__(self):
color = Config.instance().get_expanded_bool("task", "colored_repr")
return self.repr(color=color)
def __str__(self):
color = Config.instance().get_expanded_bool("task", "colored_str")
return self.repr(color=color)
def repr(self, all_params=False, color=None, **kwargs):
if color is None:
color = Config.instance().get_expanded_bool("task", "colored_repr")
family = self._repr_family(self.get_task_family(), color=color, **kwargs)
parts = [
self._repr_param(name, value, color=color, **kwargs)
for name, value in six.iteritems(self._repr_params(all_params=all_params))
] + [
self._repr_flag(flag, color=color, **kwargs)
for flag in self._repr_flags()
]
return "{}({})".format(family, ", ".join(parts))
def _repr_params(self, all_params=False):
# determine parameters to exclude
exclude = set()
if not all_params:
exclude |= self.exclude_params_repr
exclude |= set(self.interactive_params)
# build a map "name -> value" for all significant parameters
params = OrderedDict()
for name, param in self.get_params():
value = getattr(self, name)
include = (
param.significant and
not multi_match(name, exclude) and
(name not in self.exclude_params_repr_empty or value)
)
if include:
params[name] = value
return params
def _repr_flags(self):
return []
def _repr_family(self, family, color=False, **kwargs):
return colored(family, "green") if color else family
def _repr_param(self, name, value, color=False, serialize=True, **kwargs):
# try to serialize first unless explicitly disabled
if serialize:
param = getattr(self.__class__, name, no_value)
if param != no_value:
value = param.serialize(value) if isinstance(param, luigi.Parameter) else param
return "{}={}".format(colored(name, color="blue", style="bright") if color else name, value)
def _repr_flag(self, name, color=False, **kwargs):
return colored(name, color="magenta") if color else name
def _print_deps(self, args):
return print_task_deps(self, *args)
def _print_status(self, args):
return print_task_status(self, *args)
def _print_output(self, args):
return print_task_output(self, *args)
def _remove_output(self, args):
return remove_task_output(self, *args)
def _fetch_output(self, args):
return fetch_task_output(self, *args)
@classmethod
def _law_run_inst(cls, inst, _exclude=None, _replace=None, _global=None, _run_kwargs=None):
# get the cli arguments
args = inst.cli_args(exclude=_exclude, replace=_replace)
# prepend a space to values starting with "-"
for key, value in args.items():
if value.startswith("-"):
args[key] = " {}".format(value)
# flatten them
flat_args = sum((make_list(tpl) for tpl in args.items()), [])
# add global parameters when given
if _global:
flat_args.extend([str(arg) for arg in make_list(_global)])
# build the full command
cmd = [cls.get_task_family()] + flat_args
# run it
return law_run(cmd, **(_run_kwargs or {}))
@classmethod
def law_run_inst(cls, _exclude=None, _replace=None, _global=None, _run_kwargs=None, **kwargs):
# create a new instance
inst = cls(**kwargs)
return cls._law_run_inst(inst, _exclude=_exclude, _replace=_replace, _global=_global,
_run_kwargs=_run_kwargs)
def law_run(self, _exclude=None, _replace=None, _global=None, _run_kwargs=None, **kwargs):
# when kwargs are given, create a new instance
inst = self.req(self, **kwargs) if kwargs else self
return self._law_run_inst(inst, _exclude=_exclude, _replace=_replace, _global=_global,
_run_kwargs=_run_kwargs)
def localize_input(self, *args, **kwargs):
return localize_file_targets(self.input(), *args, **kwargs)
def localize_output(self, *args, **kwargs):
return localize_file_targets(self.output(), *args, **kwargs)
[docs]class WrapperTask(Task):
"""
Use for tasks that only wrap other tasks and that by definition are done
if all their requirements exist.
"""
exclude_index = True
def _repr_flags(self):
return super(WrapperTask, self)._repr_flags() + ["wrapper"]
[docs] def complete(self):
# get potentially cached requirements
if self.cache_requirements:
if self._cached_requirements is no_value:
self._cached_requirements = self.requires()
reqs = self._cached_requirements
else:
reqs = self.requires()
return all(task.complete() for task in flatten(reqs))
[docs] def output(self):
return self.input()
[docs]class ExternalTask(Task):
exclude_index = True
run = None
def _repr_flags(self):
return super(ExternalTask, self)._repr_flags() + ["external"]
class TaskMessageStream(BaseStream):
def __init__(self, task, stdout=sys.stdout, scheduler=True, flush_cache=False, **kwargs):
super(TaskMessageStream, self).__init__(**kwargs)
self.task = task
self.stdout = stdout
self.scheduler = scheduler
self.flush_cache = flush_cache
def _write(self, msg):
# foward to publish_message
self.task.publish_message(msg.rstrip("\n"), stdout=self.stdout, scheduler=self.scheduler,
flush_cache=self.flush_cache, silent=True)
# trailing imports
from law.task.interactive import (
print_task_deps, print_task_status, print_task_output, remove_task_output, fetch_task_output,
)