# coding: utf-8
"""
Helpful decorators to use with tasks.
Example usage:
.. code-block:: python
class MyTask(law.Task):
@log
@safe_output(skip=KeyboardInterrupt)
def run(self):
...
The usage of a decorator without invocation (e.g. ``@log``) is equivalent to the one *with*
invocation (``@log()``), for law to distuinguish between the two cases **always** use keyword
arguments when configuring decorators. Default arguments are applied in either case.
"""
from __future__ import annotations
__all__ = [
"factory", "log", "safe_output", "delay", "notify", "timeit", "localize", "require_sandbox",
]
import sys
import time
import traceback
import functools
import inspect
import random
import socket
import collections
import uuid
import luigi # type: ignore[import-untyped]
from law.task.base import Task
from law.task.proxy import ProxyTask
from law.sandbox.base import SandboxTask
from law.parameter import get_param, NotifyParameter
from law.target.file import localize_file_targets
from law.target.local import LocalFileTarget
from law.util import (
NoValue, no_value, uncolored, make_list, multi_match, human_duration, join_generators,
empty_context, TeeStream,
)
from law.logger import get_logger
from law._types import Callable, Any, T
logger = get_logger(__name__)
class _CompleteTask(luigi.Task):
def complete(self) -> bool:
return True
def output(self) -> None:
return None
[docs]
def factory(**default_opts) -> Callable:
"""
Factory function to create decorators for tasks' run methods. Default options for the decorator
function can be given in *default_opts*. The returned decorator can be used with or without
function invocation. Example:
.. code-block:: python
@factory(digits=2)
def runtime(fn, opts, task, *args, **kwargs):
t0 = time.time()
try:
return fn(task, *args, **kwargs)
finally:
t1 = time.time()
diff = round(t1 - t0, opts["digits"])
print("runtime: {}".format(diff))
...
class MyTask(law.Task):
@runtime
def run(self):
...
# or
@runtime(digits=3):
def run(self):
...
In most cases, the created decorators are used to decorate run methods. As intended by luigi,
run methods can become generators by yielding tasks to declare `dynamic dependencies
<https://luigi.readthedocs.io/en/stable/tasks.html#dynamic-dependencies>`__. As luigi will
resume the run method from scratch everytime a new, incomplete dependency is yielded,
decorators are required to be idempotent. Therefore, a plain definition as shown in the example
above is not sufficient. A decorator that accepts generator functions should look like the
following:
.. code-block:: python
@factory(digits=2, accept_generator=True)
def runtime(fn, opts, task, *args, **kwargs):
def before_call():
t0 = time.time()
return t0
def call(t0):
return fn(task, *args, **kwargs)
def after_call(t0):
t1 = time.time()
diff = round(t1 - t0, opts["digits"])
print("runtime: {}".format(diff))
# optional:
def on_error(e, t0):
# called when an exception occured in call,
# return True to prevent the error from being raised
...
return before_call, call, after_call[, on_error]
``before_call()`` is invoked only once. It can be used to setup objects, etc, and its return
value is passed as a single argument to both ``call()`` and ``after_call()``, even when *None*.
The former function should (at least) call the actual wrapped function and return its result
while the latter is intended to execute custom logic afterwards. The use of a 4th function for
handling exceptions is optional. It is called when an exception is raised inside ``call()`` with
the exception instance and the return value of ``before_call`` as arguments. When its return
value is *True*, the error is not raised and the return value of the wrapped function becomes
*None*.
A decorator that accepts generator functions can also be used to decorate plain, non-generator
functions, but not vice-versa. Decorated functions can be called with a keyword argument
``skip_decorators`` set to *True* to directly call the originally wrapped function without the
stack of decorators.
"""
def wrapper(decorator: Callable) -> Callable:
@functools.wraps(decorator)
def wrapper(fn: Callable | None = None, **opts) -> Callable:
_opts = default_opts.copy()
_opts.update(opts)
def wrapper(fn: Callable) -> Callable:
# get some default options
accept_generator = _opts.setdefault("accept_generator", False)
decorate_run = _opts.setdefault("decorate_run", None)
# get the originally wrapper function
# the attribute exists when fn is already a wrapper created by another decorator
orig_attr = "__law_decorator_original_fn"
orig_fn: Callable | NoValue = getattr(fn, orig_attr, no_value)
if isinstance(orig_fn, NoValue):
orig_fn = fn
# when the orignal, wrapped function is a generator, check if the decorator is
# configured to handle them, and raise a exception if not
is_gen = inspect.isgeneratorfunction(orig_fn)
if is_gen and not accept_generator:
raise Exception(
f"decorator {decorator} is not configured to decorate a generator "
"function {orig_fn}",
)
# when decorator_run is None, guess the decision based on the name of the wrapped fn
if decorate_run is None:
decorate_run = orig_fn.__name__ == "run"
# define a unique attribute to store the result of before_call() (see below)
state_attr = "__law_decorator_{}_{}_before_call_result_{}".format(
decorator.__module__.replace(".", "_"),
decorator.__name__,
uuid.uuid4().hex,
)
@functools.wraps(fn)
def wrapper(*args, **kwargs) -> Any:
# check if the decorator stack is to be skipped entirey
if kwargs.pop("skip_decorators", False):
# args[0] is the task
return fn(*args, **kwargs)
if accept_generator:
# when generator functions are accepted, the decorator is excepted to return
# three callbacks: before_call(), call(state), and after_call(state)
# the latter two take the return value of the first one as a single argument
callbacks = tuple(decorator(fn, _opts, *args, **kwargs))
if len(callbacks) not in (3, 4):
raise Exception(
"decorators accepting generator functions must return 3 or 4 "
f"callbacks, got {len(callbacks)}",
)
# extract the callbacks
before_call, call, after_call = callbacks[:3]
if len(callbacks) == 4:
on_error = callbacks[3]
else:
def on_error(e, state):
return False
# when then wrapped function returns a generator, invoke before_call() once
# for idempotency, and extend the generator to run after_call() and reset()
if is_gen:
# wrap after_call() as it is required to be a generator
def after_call_gen(state):
yield _CompleteTask() if decorate_run else None
after_call(state)
# reset function
def reset():
yield _CompleteTask() if decorate_run else None
setattr(fn, state_attr, no_value)
# call before_call once
state = getattr(wrapper, state_attr, no_value)
if state == no_value:
state = before_call()
setattr(wrapper, state_attr, state)
# wrap on_error() to include the state
def _on_error(error):
return on_error(error, state)
# join the generators, pass the result of before_call
return join_generators(
call(state),
after_call_gen(state),
reset(),
on_error=_on_error,
)
else:
# although configured to handle it, the wrapped function is not a
# generator, so just invoke the callbacks serially and handle errors
state = before_call()
try:
result = call(state)
except (Exception, KeyboardInterrupt) as error:
if not on_error(error, state):
raise
result = None
after_call(state)
return result
else:
# the wrapped function is a plain callable, so just call it
return decorator(fn, _opts, *args, **kwargs)
# store the originally wrapped function as an attribute of the wrapper
setattr(wrapper, orig_attr, orig_fn)
return wrapper
return wrapper if fn is None else wrapper(fn) # type: ignore[return-value]
return wrapper
return wrapper
def get_task(task: Task | ProxyTask) -> Task:
return task.task if isinstance(task, ProxyTask) else task # type: ignore[return-value]
[docs]
@factory(accept_generator=False)
def log(
fn: Callable[..., T],
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> T:
""" log()
Wraps a bound method of a task and redirects output of both stdout and stderr to the file
defined by the tasks's *log_file* parameter or *default_log_file* attribute. If its value is
``"-"`` or *None*, the output is not redirected. Does **not** accept generator functions.
"""
_task = get_task(task)
log = get_param(_task.log_file, _task.default_log_file)
if log and not isinstance(log, LocalFileTarget):
log = str(log)
if log == "-" or not log:
return fn(task, *args, **kwargs)
# use the local target functionality to create the parent directory
LocalFileTarget(log).parent.touch() # type: ignore[call-arg, union-attr]
with open(log, "a", 1) as f:
tee = TeeStream(f, sys.__stdout__)
sys.stdout = tee # type: ignore[assignment]
sys.stderr = tee # type: ignore[assignment]
try:
ret = fn(task, *args, **kwargs)
except:
traceback.print_exc(file=tee)
raise
finally:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
tee.flush()
return ret
[docs]
@factory(skip=None, optional=True, accept_generator=True)
def safe_output(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable, Callable]:
""" safe_output(skip=None, optional=True)
Wraps a bound method of a task and guards its execution. If an exception occurs, and it is not
an instance of *skip*, the task's output is removed prior to the actual raising. If *optional*
is *False*, optional targets are not removed. Accepts generator functions.
"""
def before_call() -> None:
return None
def call(state: None) -> Any:
return fn(task, *args, **kwargs)
def after_call(state: None) -> None:
return
def on_error(error: Exception, state: None) -> None:
if opts["skip"] is None or not isinstance(error, opts["skip"]):
for outp in luigi.task.flatten(task.output()):
# skip optional targets
if not opts["optional"] and getattr(outp, "optional", False):
continue
# remove the target
outp.remove()
return before_call, call, after_call, on_error
[docs]
@factory(t=5.0, stddev=0.0, pdf="gauss", accept_generator=True)
def delay(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable]:
""" delay(t=5.0, stddev=0.0, pdf="gauss")
Wraps a bound method of a task and delays its execution by *t* seconds. Accepts generator
functions.
"""
def before_call() -> None:
return None
def call(state: None) -> Any:
if opts["stddev"] <= 0:
t = opts["t"]
elif opts["pdf"] == "gauss":
t = random.gauss(opts["t"], opts["stddev"])
elif opts["pdf"] == "uniform":
t = random.uniform(opts["t"], opts["stddev"])
else:
raise ValueError(f"unknown delay decorator pdf '{opts['pdf']}'")
time.sleep(max(t, 0))
return fn(task, *args, **kwargs)
def after_call(state: None) -> None:
return
return before_call, call, after_call
[docs]
@factory(on_success=True, on_failure=True, accept_generator=True)
def notify(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable, Callable]:
""" notify(on_success=True, on_failure=True)
Wraps a bound method of a task and guards its execution. Information about the execution (task
name, duration, etc) is collected and dispatched to all notification transports registered on
wrapped task via adding :py:class:`law.NotifyParameter` parameters. Example:
.. code-block:: python
class MyTask(law.Task):
notify_mail = law.NotifyMailParameter()
@notify
# or
@notify(sender="foo@bar.com", recipient="user@host.tld")
def run(self):
...
When the *notify_mail* parameter is *True*, a notification is sent to the configured email
address. Also see :ref:`config-notifications`. Accepts generator functions.
"""
_task = get_task(task)
def before_call() -> tuple[list[dict], float]:
# prepare notification transports
transports = []
for param_name, param in _task.get_params():
if isinstance(param, NotifyParameter) and getattr(_task, param_name):
try:
transport = param.get_transport()
if transport:
transports += make_list(transport)
except Exception as e:
logger.warning(f"get_transport() failed for '{param_name}' parameter: {e}")
# get a timestamp
t0 = time.perf_counter()
return transports, t0
def call(state: tuple[list[dict], float]) -> Any:
return fn(task, *args, **kwargs)
def send(error: Exception | None, transports: list[dict], t0: float) -> None:
# do nothing when there are no transports
if not transports:
return
# do nothing on KeyboardInterrupt, or when on_success / on_failure do not match the status
success = error is None
if isinstance(error, KeyboardInterrupt):
return
if success and not opts["on_success"]:
return
if not success and not opts["on_failure"]:
return
# prepare message content
duration = human_duration(seconds=round(time.perf_counter() - t0, 1))
status_string = "succeeded" if success else "failed"
title = "Task {} {}!".format(_task.get_task_family(), status_string)
parts = collections.OrderedDict([
("Task", str(_task)),
("Host", socket.gethostname()),
("Duration", duration),
("Last message", "-" if not len(_task._message_cache) else _task._message_cache[-1]),
])
if not success:
parts["Traceback"] = traceback.format_exc()
message = "\n".join("{}: {}".format(*tpl) for tpl in parts.items())
# dispatch via all transports
for transport in transports:
fn = transport["func"]
raw = transport.get("raw", False)
colored = transport.get("colored", False)
# remove color commands if necessary
_content: dict[str, Any] | str
if not colored:
_title = uncolored(title)
if raw:
_content = {
k: (uncolored(v) if isinstance(v, str) else v)
for k, v in parts.items()
}
else:
_content = uncolored(message)
else:
_title = title
_content = parts.copy() if raw else message
# invoke the function
try:
fn(success, _title, _content, **opts)
except Exception as e:
t = traceback.format_exc()
logger.warning(f"notification via transport '{fn}' failed: {e}\n{t}")
def after_call(state: tuple[list[dict], float]) -> None:
return send(None, *state)
def on_error(error: Exception, state: tuple[list[dict], float]) -> None:
return send(error, *state)
return before_call, call, after_call, on_error
[docs]
@factory(accept_generator=True)
def timeit(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable, Callable]:
""" timeit()
Wraps a bound method of a task and logs its execution time in a human readable format using the
task's logger instance in info mode. Accepts generator functions.
"""
def before_call() -> float:
t0 = time.perf_counter()
return t0
def call(t0: float) -> Any:
return fn(task, *args, **kwargs)
def log_duration(t0: float) -> None:
duration = human_duration(seconds=round(time.perf_counter() - t0, 1))
task.logger.info("runtime: {}".format(duration))
def after_call(t0: float) -> None:
log_duration(t0)
def on_error(error: Exception, t0: float) -> None:
log_duration(t0)
return before_call, call, after_call, on_error
[docs]
@factory(input=True, output=True, input_kwargs=None, output_kwargs=None, accept_generator=False)
def localize(
fn: Callable[..., T],
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> T:
""" localize(input=True, output=True, input_kwargs=None, output_kwargs=None)
Wraps a bound method of a task and temporarily changes the input and output methods to return
localized targets. When *input* (*output*) is *True*, :py:meth:`Task.input`
(:py:meth:`Task.output`) is adjusted. *input_kwargs* and *output_kwargs* can be dictionaries
that are passed as keyword arguments to the respective localization method. Does **not** accept
generator functions.
"""
# store original input and output methods
input_orig = None
output_orig = None
if opts["input"]:
input_orig = (
task.__getattribute__("input", proxy=False) # type: ignore[call-arg]
if isinstance(task, SandboxTask)
else task.input
)
if opts["output"]:
output_orig = (
task.__getattribute__("output", proxy=False) # type: ignore[call-arg]
if isinstance(task, SandboxTask)
else task.output
)
# wrap input context
input_context = empty_context
if opts["input"]:
def input_context(): # type: ignore[misc] # noqa: F811
input_struct = task.input()
input_kwargs = opts["input_kwargs"] or {}
input_kwargs.setdefault("mode", "r")
return localize_file_targets(input_struct, **input_kwargs)
# wrap output context
output_context = empty_context
if opts["output"]:
def output_context(): # type: ignore[misc] # noqa: F811
output_struct = task.output()
output_kwargs = opts["output_kwargs"] or {}
output_kwargs.setdefault("mode", "w")
return localize_file_targets(output_struct, **output_kwargs)
try:
# localize both target contexts
with input_context() as localized_inputs, output_context() as localized_outputs:
# patch the input method to always return the localized inputs
if opts["input"]:
def input_patched(self):
return localized_inputs
task.input = _patch_localized_method(task, input_patched) # type: ignore[method-assign] # noqa
# patch the output method to always return the localized outputs
if opts["output"]:
def output_patched(self):
return localized_outputs
task.output = _patch_localized_method(task, output_patched)
return fn(task, *args, **kwargs)
finally:
# restore the methods
if input_orig is not None:
task.input = input_orig # type: ignore[method-assign]
if output_orig is not None:
task.output = output_orig
def _patch_localized_method(task: Task, func: Callable) -> Callable:
# add a flag to func
func._patched_localized_method = True # type: ignore[attr-defined]
# bind to task
return func.__get__(task)
def _is_patched_localized_method(func: Callable) -> bool:
return getattr(func, "_patched_localized_method", False) is True
[docs]
@factory(sandbox=None, accept_generator=True)
def require_sandbox(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable]:
""" require_sandbox(sandbox=None)
Wraps a bound method of a sandbox task and throws an exception when the method is called while
the task is not sandboxed yet. This is intended to prevent undesired results or non-verbose
error messages when the method is invoked outside the requested sandbox. When *sandbox* is set,
it can be a (list of) pattern(s) to compare against the task's effective sandbox and in error is
raised if they don't match. Accepts generator functions.
"""
def before_call() -> None:
if not isinstance(task, SandboxTask):
raise TypeError(
"require_sandbox can only be used to decorate methods of tasks that inherit from "
f"SandboxTask, got '{task!r}'",
)
if not task.is_sandboxed():
raise Exception(
f"the invocation of method {fn.__name__} requires task {task!r} to be sandboxed",
)
if opts["sandbox"] and not multi_match(task.effective_sandbox, make_list(opts["sandbox"])):
raise Exception(
f"the invocation of method {fn.__name__} requires the sandbox of task {task!r} to "
f"match '{opts['sandbox']}'",
)
return None
def call(state: Any) -> None:
return fn(task, *args, **kwargs)
def after_call(state: None) -> None:
return
return before_call, call, after_call