# coding: utf-8
"""
Helpful decorators to use with tasks.
Example usage:

.. code-block:: python

    class MyTask(law.Task):

        @log
        @safe_output(skip=KeyboardInterrupt)
        def run(self):
            ...

The usage of a decorator without invocation (e.g. ``@log``) is equivalent to the one *with*
invocation (``@log()``). For law to distinguish between the two cases, **always** use keyword
arguments when configuring decorators. Default arguments are applied in either case.
"""
# public decorators exported by this module
__all__ = [
    "factory",
    "log",
    "safe_output",
    "delay",
    "notify",
    "timeit",
    "localize",
    "require_sandbox",
]
import sys
import time
import traceback
import functools
import inspect
import random
import socket
import collections
import uuid
import luigi
import six
from law.task.proxy import ProxyTask
from law.sandbox.base import SandboxTask
from law.parameter import get_param, NotifyParameter
from law.target.file import localize_file_targets
from law.target.local import LocalFileTarget
from law.util import (
no_value, uncolored, make_list, multi_match, human_duration, open_compat, join_generators,
TeeStream, perf_counter, empty_context,
)
from law.logger import get_logger
# module-level logger named after this module, used e.g. by notify() to report failing transports
logger = get_logger(__name__)
class _CompleteTask(luigi.Task):
    """
    Placeholder task that always reports itself as complete and declares no outputs. Instances are
    yielded by the generator wrappers created in :py:func:`factory` when *decorate_run* is set,
    presumably so that luigi treats them as already fulfilled dynamic dependencies and continues
    the run method right away.
    """

    def output(self):
        # no outputs at all
        return

    def complete(self):
        # unconditionally complete, independent of any outputs
        return True
def factory(**default_opts):
    """
    Factory function to create decorators for tasks' run methods. Default options for the decorator
    function can be given in *default_opts*. The returned decorator can be used with or without
    function invocation. Example:

    .. code-block:: python

        @factory(digits=2)
        def runtime(fn, opts, task, *args, **kwargs):
            t0 = time.time()
            try:
                return fn(task, *args, **kwargs)
            finally:
                t1 = time.time()
                diff = round(t1 - t0, opts["digits"])
                print("runtime: {}".format(diff))

        ...

        class MyTask(law.Task):

            @runtime
            def run(self):
                ...

            # or
            @runtime(digits=3)
            def run(self):
                ...

    In most cases, the created decorators are used to decorate run methods. As intended by luigi,
    run methods can become generators by yielding tasks to declare `dynamic dependencies
    <https://luigi.readthedocs.io/en/stable/tasks.html#dynamic-dependencies>`__. As luigi will
    resume the run method from scratch every time a new, incomplete dependency is yielded,
    decorators are required to be idempotent. Therefore, a plain definition as shown in the example
    above is not sufficient. A decorator that accepts generator functions should look like the
    following:

    .. code-block:: python

        @factory(digits=2, accept_generator=True)
        def runtime(fn, opts, task, *args, **kwargs):
            def before_call():
                t0 = time.time()
                return t0

            def call(t0):
                return fn(task, *args, **kwargs)

            def after_call(t0):
                t1 = time.time()
                diff = round(t1 - t0, opts["digits"])
                print("runtime: {}".format(diff))

            # optional:
            def on_error(e, t0):
                # called when an exception occurred in call,
                # return True to prevent the error from being raised
                ...

            return before_call, call, after_call[, on_error]

    ``before_call()`` is invoked only once. It can be used to set up objects, etc., and its return
    value is passed as a single argument to ``call()`` and ``after_call()``, even when *None*.
    ``call()`` should (at least) call the actual wrapped function and return its result, while
    ``after_call()`` is intended to execute custom logic afterwards. The use of a 4th function for
    handling exceptions is optional. It is called when an exception is raised inside ``call()`` with
    the exception instance and the return value of ``before_call`` as arguments. When its return
    value is *True*, the error is not raised and the return value of the wrapped function becomes
    *None*.

    A decorator that accepts generator functions can also be used to decorate plain, non-generator
    functions, but not vice-versa. Decorated functions can be called with a keyword argument
    ``skip_decorators`` set to *True* to directly call the originally wrapped function without the
    stack of decorators.

    Options that are always available in addition to *default_opts*:

    - *accept_generator* (default *False*): whether the decorator returns the callbacks described
      above and is therefore allowed to decorate generator functions.
    - *decorate_run* (default *None*): when *True*, the generator wrappers yield a
      ``_CompleteTask`` placeholder between their steps, otherwise *None*. When *None*, the value is
      guessed from whether the originally wrapped function is named ``run``.
    """
    def make_decorator(decorator):
        @functools.wraps(decorator)
        def configure(fn=None, **opts):
            # merge default and user-defined options without touching the defaults
            _opts = default_opts.copy()
            _opts.update(opts)

            def apply(fn):
                # get some default options
                accept_generator = _opts.setdefault("accept_generator", False)
                decorate_run = _opts.setdefault("decorate_run", None)

                # get the originally wrapped function
                # the attribute exists when fn is already a wrapper created by another decorator
                orig_attr = "__law_decorator_original_fn"
                orig_fn = getattr(fn, orig_attr, no_value)
                if orig_fn is no_value:
                    orig_fn = fn

                # when the original, wrapped function is a generator, check if the decorator is
                # configured to handle them, and raise an exception if not
                is_gen = inspect.isgeneratorfunction(orig_fn)
                if is_gen and not accept_generator:
                    raise Exception("decorator {} is not configured to decorate a generator "
                        "function {}".format(decorator, orig_fn))

                # when decorate_run is None, guess the decision based on the name of the wrapped fn
                if decorate_run is None:
                    decorate_run = orig_fn.__name__ == "run"

                # define a unique attribute to store the result of before_call() (see below)
                state_attr = "__law_decorator_{}_{}_before_call_result_{}".format(
                    decorator.__module__.replace(".", "_"), decorator.__name__, uuid.uuid4().hex)

                @functools.wraps(fn)
                def wrapper(*args, **kwargs):
                    # check if the decorator stack is to be skipped entirely, and if so, call the
                    # originally wrapped function, since fn might be another decorator's wrapper
                    if kwargs.pop("skip_decorators", False):
                        # args[0] is the task
                        return orig_fn(*args, **kwargs)

                    if not accept_generator:
                        # the decorator handles plain callables only, so just call it
                        return decorator(fn, _opts, *args, **kwargs)

                    # when generator functions are accepted, the decorator is expected to return
                    # three callbacks: before_call(), call(state), and after_call(state), plus an
                    # optional on_error(error, state); all but the first take the return value of
                    # before_call() as their state argument
                    callbacks = tuple(decorator(fn, _opts, *args, **kwargs))
                    if len(callbacks) not in (3, 4):
                        raise Exception("decorators accepting generator functions must return "
                            "3 or 4 callbacks, got {}".format(len(callbacks)))

                    # extract the callbacks
                    before_call, call, after_call = callbacks[:3]
                    if len(callbacks) == 4:
                        on_error = callbacks[3]
                    else:
                        def on_error(e, state):
                            return False

                    if not is_gen:
                        # although configured to handle it, the wrapped function is not a
                        # generator, so just invoke the callbacks serially and handle errors
                        state = before_call()
                        try:
                            result = call(state)
                        except (Exception, KeyboardInterrupt) as error:
                            if not on_error(error, state):
                                raise
                            result = None
                        after_call(state)
                        return result

                    # the wrapped function returns a generator, so invoke before_call() only once
                    # for idempotency, and extend the generator to run after_call() and reset()

                    # wrap after_call() as it is required to be a generator
                    def after_call_gen(state):
                        yield _CompleteTask() if decorate_run else None
                        after_call(state)

                    # reset function, clearing the state on the same object it is stored on
                    def reset():
                        yield _CompleteTask() if decorate_run else None
                        setattr(wrapper, state_attr, no_value)

                    # call before_call once
                    # NOTE(review): the state lives on the wrapper function, which is shared by all
                    # instances of the task class, so interleaved runs of different instances in
                    # the same process would share it; consider storing it on the task instead
                    state = getattr(wrapper, state_attr, no_value)
                    if state is no_value:
                        state = before_call()
                        setattr(wrapper, state_attr, state)

                    # wrap on_error() to include the state
                    def _on_error(error):
                        return on_error(error, state)

                    # join the generators, pass the result of before_call
                    return join_generators(call(state), after_call_gen(state), reset(),
                        on_error=_on_error)

                # store the originally wrapped function as an attribute of the wrapper
                setattr(wrapper, orig_attr, orig_fn)

                return wrapper

            # support usage both with and without invocation (@deco vs. @deco(...))
            return apply if fn is None else apply(fn)

        return configure

    return make_decorator
def get_task(task):
    """
    Returns the actual task behind *task*: the wrapped ``task.task`` when *task* is a
    :py:class:`ProxyTask`, and *task* itself otherwise.
    """
    if isinstance(task, ProxyTask):
        return task.task
    return task
@factory(accept_generator=False)
def log(fn, opts, task, *args, **kwargs):
    """ log()
    Wraps a bound method of a task and redirects output of both stdout and stderr to the file
    defined by the task's *log_file* parameter or *default_log_file* attribute. If its value is
    ``"-"`` or *None*, the output is not redirected. Otherwise, both streams are replaced by a
    :py:class:`TeeStream` over the log file and the interpreter's original stdout, tracebacks of
    exceptions raised by the method are printed to it before re-raising, and the previously active
    streams are restored afterwards. Does **not** accept generator functions.
    """
    _task = get_task(task)
    log_file = get_param(_task.log_file, _task.default_log_file)
    if log_file and not isinstance(log_file, LocalFileTarget):
        log_file = str(log_file)

    # no redirection requested
    if log_file == "-" or not log_file:
        return fn(task, *args, **kwargs)

    # use the local target functionality to create the parent directory
    LocalFileTarget(log_file).parent.touch()

    # remember the currently active streams so that they, rather than the interpreter's original
    # ones, are restored afterwards (matters when output is already redirected elsewhere)
    stdout_orig, stderr_orig = sys.stdout, sys.stderr

    with open_compat(log_file, "a", 1) as f:
        tee = TeeStream(f, sys.__stdout__)
        sys.stdout = tee
        sys.stderr = tee
        try:
            return fn(task, *args, **kwargs)
        except BaseException:
            # also covers KeyboardInterrupt and SystemExit, which are re-raised right away
            traceback.print_exc(file=tee)
            raise
        finally:
            sys.stdout = stdout_orig
            sys.stderr = stderr_orig
            tee.flush()
@factory(skip=None, accept_generator=True)
def safe_output(fn, opts, task, *args, **kwargs):
    """ safe_output(skip=None)
    Wraps a bound method of a task and guards its execution. If an exception occurs, and it is not
    an instance of *skip*, the task's output is removed prior to the actual raising. Failures while
    removing individual outputs are logged as warnings so that they do not replace the original
    exception. The exception itself is never suppressed. Accepts generator functions.
    """
    def before_call():
        return None

    def call(state):
        return fn(task, *args, **kwargs)

    def after_call(state):
        return

    def on_error(error, state):
        # errors that are instances of *skip* leave the outputs untouched
        skip = opts["skip"]
        if skip is not None and isinstance(error, skip):
            return

        # remove outputs on a best-effort basis, the original error is raised afterwards since
        # this function never returns True
        for outp in luigi.task.flatten(task.output()):
            try:
                outp.remove()
            except Exception as e:
                logger.warning("failed to remove output {!r} of task {!r}: {}".format(
                    outp, task, e))

    return before_call, call, after_call, on_error
@factory(t=5.0, stddev=0.0, pdf="gauss", accept_generator=True)
def delay(fn, opts, task, *args, **kwargs):
    """ delay(t=5.0, stddev=0.0, pdf="gauss")
    Wraps a bound method of a task and delays its execution by *t* seconds. When *stddev* is
    positive, the delay is randomized according to *pdf*:

    - ``"gauss"``: normal distribution with mean *t* and standard deviation *stddev*,
    - ``"uniform"``: uniform distribution centered at *t* with standard deviation *stddev*, i.e.,
      within ``t +- sqrt(3) * stddev``.

    Other values of *pdf* raise a *ValueError*. Negative delays are clipped to zero. Accepts
    generator functions.
    """
    def before_call():
        return None

    def call(state):
        # NOTE(review): call() is invoked anew each time a generator run method is resumed, so the
        # delay applies on every re-entry; confirm whether a single delay is intended
        mean, stddev, pdf = opts["t"], opts["stddev"], opts["pdf"]
        if stddev <= 0:
            t = mean
        elif pdf == "gauss":
            t = random.gauss(mean, stddev)
        elif pdf == "uniform":
            # a uniform distribution with half-width w has a standard deviation of w / sqrt(3)
            half_width = stddev * 3 ** 0.5
            t = random.uniform(mean - half_width, mean + half_width)
        else:
            raise ValueError("unknown delay decorator pdf '{}'".format(opts["pdf"]))
        time.sleep(max(t, 0))
        return fn(task, *args, **kwargs)

    def after_call(state):
        return

    return before_call, call, after_call
@factory(on_success=True, on_failure=True, accept_generator=True)
def notify(fn, opts, task, *args, **kwargs):
    """ notify(on_success=True, on_failure=True)
    Wraps a bound method of a task and guards its execution. Information about the execution (task
    name, duration, etc) is collected and dispatched to all notification transports registered on
    wrapped task via adding :py:class:`law.NotifyParameter` parameters. Example:

    .. code-block:: python

        class MyTask(law.Task):

            notify_mail = law.NotifyMailParameter()

            @notify
            # or
            @notify(sender="foo@bar.com", recipient="user@host.tld")
            def run(self):
                ...

    When the *notify_mail* parameter is *True*, a notification is sent to the configured email
    address. Also see :ref:`config-notifications`. All decorator options are forwarded as keyword
    arguments to the transport functions. Nothing is sent on *KeyboardInterrupt*, and errors are
    never suppressed. Accepts generator functions.
    """
    _task = get_task(task)

    def before_call():
        # prepare the transports of all enabled notification parameters
        transports = []
        for param_name, param in _task.get_params():
            if isinstance(param, NotifyParameter) and getattr(_task, param_name):
                try:
                    transport = param.get_transport()
                    if transport:
                        transports += make_list(transport)
                except Exception as e:
                    logger.warning("get_transport() failed for '{}' parameter: {}".format(
                        param_name, e))

        # get a timestamp
        t0 = perf_counter()

        return transports, t0

    def call(state):
        return fn(task, *args, **kwargs)

    def send(error, transports, t0):
        # do nothing when there are no transports
        if not transports:
            return

        # do nothing on KeyboardInterrupt, or when on_success / on_failure do not match the status
        success = error is None
        if isinstance(error, KeyboardInterrupt):
            return
        elif success and not opts["on_success"]:
            return
        elif not success and not opts["on_failure"]:
            return

        # prepare message content
        duration = human_duration(seconds=round(perf_counter() - t0, 1))
        status_string = "succeeded" if success else "failed"
        title = "Task {} {}!".format(_task.get_task_family(), status_string)
        parts = collections.OrderedDict([
            ("Host", socket.gethostname()),
            ("Duration", duration),
            ("Last message", "-" if not len(_task._message_cache) else _task._message_cache[-1]),
            ("Task", str(_task)),
        ])
        if not success:
            parts["Traceback"] = traceback.format_exc()
        message = "\n".join("{}: {}".format(*tpl) for tpl in parts.items())

        # dispatch via all transports
        for transport in transports:
            # use a dedicated name to avoid shadowing the wrapped function fn
            func = transport["func"]
            raw = transport.get("raw", False)
            colored = transport.get("colored", False)

            # remove color commands if necessary, keeping the order of raw parts in both cases
            if colored:
                _title = title
                _content = parts.copy() if raw else message
            else:
                _title = uncolored(title)
                if raw:
                    _content = collections.OrderedDict(
                        (k, uncolored(v) if isinstance(v, six.string_types) else v)
                        for k, v in parts.items()
                    )
                else:
                    _content = uncolored(message)

            # invoke the function, failures are logged but do not interrupt other transports
            try:
                func(success, _title, _content, **opts)
            except Exception as e:
                t = traceback.format_exc()
                logger.warning("notification via transport '{}' failed: {}\n{}".format(func, e, t))

    def after_call(state):
        return send(None, *state)

    def on_error(error, state):
        return send(error, *state)

    return before_call, call, after_call, on_error
@factory(accept_generator=True)
def timeit(fn, opts, task, *args, **kwargs):
    """ timeit()
    Wraps a bound method of a task and logs its execution time in a human readable format using the
    task's logger instance in info mode. The duration is logged after both successful and failed
    executions, and errors are not suppressed. Accepts generator functions.
    """
    def before_call():
        # reference time, taken once per execution
        return perf_counter()

    def call(t0):
        return fn(task, *args, **kwargs)

    def log_duration(t0):
        duration = human_duration(seconds=round(perf_counter() - t0, 1))
        task.logger.info("runtime: {}".format(duration))

    def after_call(t0):
        log_duration(t0)

    def on_error(error, t0):
        # log the duration but return None so that the error is still raised
        log_duration(t0)

    return before_call, call, after_call, on_error
@factory(input=True, output=True, input_kwargs=None, output_kwargs=None, accept_generator=False)
def localize(fn, opts, task, *args, **kwargs):
    """ localize(input=True, output=True, input_kwargs=None, output_kwargs=None)
    Wraps a bound method of a task and temporarily changes the input and output methods to return
    localized targets. When *input* (*output*) is *True*, :py:meth:`Task.input`
    (:py:meth:`Task.output`) is adjusted. *input_kwargs* and *output_kwargs* can be dictionaries
    that are passed as keyword arguments to the respective localization method. Their ``mode``
    defaults to ``"r"`` and ``"w"``, respectively, and the passed dictionaries are not modified.
    The original methods are restored afterwards, even when an exception occurs. Does **not** accept
    generator functions.
    """
    # store original input and output methods
    input_orig = None
    output_orig = None
    if opts["input"]:
        input_orig = (
            task.__getattribute__("input", proxy=False)
            if isinstance(task, SandboxTask)
            else task.input
        )
    if opts["output"]:
        output_orig = (
            task.__getattribute__("output", proxy=False)
            if isinstance(task, SandboxTask)
            else task.output
        )

    # wrap input context
    input_context = empty_context
    if opts["input"]:
        def input_context():  # noqa: F811
            input_struct = task.input()
            # copy the configured kwargs so that the shared decorator options are not mutated
            input_kwargs = dict(opts["input_kwargs"] or {})
            input_kwargs.setdefault("mode", "r")
            return localize_file_targets(input_struct, **input_kwargs)

    # wrap output context
    output_context = empty_context
    if opts["output"]:
        def output_context():  # noqa: F811
            output_struct = task.output()
            # copy the configured kwargs so that the shared decorator options are not mutated
            output_kwargs = dict(opts["output_kwargs"] or {})
            output_kwargs.setdefault("mode", "w")
            return localize_file_targets(output_struct, **output_kwargs)

    try:
        # localize both target contexts
        with input_context() as localized_inputs, output_context() as localized_outputs:
            # patch the input method to always return the localized inputs
            if opts["input"]:
                def input_patched(self):
                    return localized_inputs
                task.input = _patch_localized_method(task, input_patched)

            # patch the output method to always return the localized outputs
            if opts["output"]:
                def output_patched(self):
                    return localized_outputs
                task.output = _patch_localized_method(task, output_patched)

            return fn(task, *args, **kwargs)
    finally:
        # restore the methods
        if input_orig is not None:
            task.input = input_orig
        if output_orig is not None:
            task.output = output_orig
def _patch_localized_method(task, func):
    """
    Marks *func* as a patched localization method by setting its ``_patched_localized_method``
    attribute to *True*, and returns it bound to *task* as a method.
    """
    setattr(func, "_patched_localized_method", True)
    bound_method = func.__get__(task)
    return bound_method
def _is_patched_localized_method(func):
    """
    Returns *True* when *func* carries a ``_patched_localized_method`` attribute that is exactly
    *True*, and *False* otherwise (including when the attribute is missing or merely truthy).
    """
    marker = getattr(func, "_patched_localized_method", None)
    return marker is True
@factory(sandbox=None, accept_generator=True)
def require_sandbox(fn, opts, task, *args, **kwargs):
    """ require_sandbox(sandbox=None)
    Wraps a bound method of a sandbox task and throws an exception when the method is called while
    the task is not sandboxed yet. This is intended to prevent undesired results or non-verbose
    error messages when the method is invoked outside the requested sandbox. When *sandbox* is set,
    it can be a (list of) pattern(s) to compare against the task's effective sandbox and an error is
    raised if they don't match. Accepts generator functions.

    :raises TypeError: When *task* is not a :py:class:`SandboxTask`.
    :raises Exception: When the task is not sandboxed, or its effective sandbox does not match
        *sandbox*.
    """
    def before_call():
        # all checks happen before the wrapped method is invoked
        if not isinstance(task, SandboxTask):
            raise TypeError("require_sandbox can only be used to decorate methods of tasks that "
                "inherit from SandboxTask, got '{!r}'".format(task))

        if not task.is_sandboxed():
            raise Exception("the invocation of method {} requires task {!r} to be sandboxed".format(
                fn.__name__, task))

        if opts["sandbox"] and not multi_match(task.effective_sandbox, make_list(opts["sandbox"])):
            raise Exception("the invocation of method {} requires the sandbox of task {!r} to "
                "match '{}'".format(fn.__name__, task, opts["sandbox"]))

        return None

    def call(state):
        return fn(task, *args, **kwargs)

    def after_call(state):
        return

    return before_call, call, after_call