Source code for law.sandbox.base

# coding: utf-8

"""
Abstract definitions that enable task sandboxing.
"""

__all__ = ["Sandbox", "SandboxTask"]


import os
import sys
import shlex
from abc import ABCMeta, abstractmethod, abstractproperty
from contextlib import contextmanager
from fnmatch import fnmatch
from collections import OrderedDict

import luigi
import six

from law.config import Config
from law.task.proxy import ProxyTask, ProxyAttributeTask, ProxyCommand
from law.target.local import LocalDirectoryTarget
from law.target.collection import TargetCollection
from law.parameter import NO_STR
from law.parser import root_task
from law.util import (
    colored, is_pattern, multi_match, mask_struct, map_struct, interruptable_popen, patch_object,
    flatten,
)
from law.logger import get_logger


logger = get_logger(__name__)

# sandbox state of the current process, read once at import time from the LAW_SANDBOX_*
# environment variables (the same variables that Sandbox._get_env exports)

# LAW_SANDBOX may hold several comma-separated sandbox keys; splitting an unset / empty value
# yields [""], so the list always has at least one item
_current_sandbox = os.environ.get("LAW_SANDBOX", "").split(",")

# True when LAW_SANDBOX_SWITCHED is exactly "1"
_sandbox_switched = os.environ.get("LAW_SANDBOX_SWITCHED", "") == "1"

# raw string values, empty when the variable is not set
_sandbox_task_id = os.environ.get("LAW_SANDBOX_TASK_ID", "")

_sandbox_worker_id = os.environ.get("LAW_SANDBOX_WORKER_ID", "")

_sandbox_worker_first_task_id = os.environ.get("LAW_SANDBOX_WORKER_FIRST_TASK_ID", "")

# True when LAW_SANDBOX_IS_ROOT_TASK is exactly "1"
_sandbox_is_root_task = os.environ.get("LAW_SANDBOX_IS_ROOT_TASK", "") == "1"

# staging directories, empty when no staging directory was announced
_sandbox_stagein_dir = os.environ.get("LAW_SANDBOX_STAGEIN_DIR", "")

_sandbox_stageout_dir = os.environ.get("LAW_SANDBOX_STAGEOUT_DIR", "")


# once switched into a sandbox, the identifying variables are mandatory, so fail at import time
# on the first one that is missing
if _sandbox_switched:
    if not (_current_sandbox and _current_sandbox[0]):
        raise Exception("LAW_SANDBOX must not be empty in a sandbox")
    elif not _sandbox_task_id:
        raise Exception("LAW_SANDBOX_TASK_ID must not be empty in a sandbox")
    elif not _sandbox_worker_id:
        raise Exception("LAW_SANDBOX_WORKER_ID must not be empty in a sandbox")
    elif not _sandbox_worker_first_task_id:
        raise Exception("LAW_SANDBOX_WORKER_FIRST_TASK_ID must not be empty in a sandbox")


class StageInfo(object):
    """
    Plain container bundling the three pieces of information that describe a staging step: the
    original *targets*, the directory *stage_dir* they are staged through, and the
    *staged_targets* that represent them inside that directory. *stage_dir* needs to provide a
    ``path`` attribute, which is used by the string representation.
    """

    def __init__(self, targets, stage_dir, staged_targets):
        super(StageInfo, self).__init__()

        self.targets = targets
        self.stage_dir = stage_dir
        self.staged_targets = staged_targets

    def __repr__(self):
        # same text as the informal string representation
        return self.__str__()

    def __str__(self):
        klass = self.__class__
        address = hex(id(self))

        tmpl = "{}.{} object at {}:\n targets : {}\n stage_dir : {}\n staged_targets: {}"
        return tmpl.format(
            klass.__module__,
            klass.__name__,
            address,
            self.targets,
            self.stage_dir.path,
            self.staged_targets,
        )
class SandboxVariables(object):
    """
    Container for the variables encoded in the *name* part of a sandbox key. The base class only
    stores the name itself. Subclasses can list additional variable names in :py:attr:`fields`,
    in which case :py:meth:`parse_name` extracts them from the name, and they can extend
    :py:attr:`eq_fields` to let those variables take part in comparisons and hashing.
    """

    # names of variables that parse_name may extract, in the order used for positional items
    fields = ()

    # names of attributes that define equality and the hash value
    eq_fields = ("name",)

    @classmethod
    def from_name(cls, name):
        """
        Creates an instance from a sandbox *name*. The values returned by :py:meth:`parse_name`
        are passed as keyword arguments to the constructor. A :py:class:`ValueError` is raised
        when *name* is empty.
        """
        if not name:
            raise ValueError("cannot create {} from empty name '{}'".format(cls.__name__, name))

        # default implementation
        return cls(name, **cls.parse_name(name))

    @classmethod
    def parse_name(cls, name):
        """
        Splits *name* at ``Sandbox.delimiter`` and returns a dictionary of the variables found.
        Items without ``"="`` are assigned to :py:attr:`fields` by position, items of the form
        ``key=value`` by key. A :py:class:`ValueError` is raised for surplus positional items,
        unknown keys and keys that occur more than once. When :py:attr:`fields` is empty, an
        empty dictionary is returned without inspecting *name*.
        """
        if not cls.fields:
            return {}

        values = {}
        for i, pair in enumerate(name.split(Sandbox.delimiter)):
            if "=" not in pair:
                # positional item, the key is given by the position
                if i >= len(cls.fields):
                    raise ValueError("invalid format of {} item '{}'".format(cls.__name__, pair))
                key = cls.fields[i]
                value = pair
            else:
                # explicit key=value item, only the first "=" separates key and value
                key, value = pair.split("=", 1)
                if key not in cls.fields:
                    raise ValueError("invalid {} key '{}'".format(cls.__name__, key))

            if key in values:
                raise ValueError("duplicate {} key '{}'".format(cls.__name__, key))
            values[key] = value

        return values

    def __init__(self, name):
        super(SandboxVariables, self).__init__()

        # intended to be read-only
        self._name = str(name)

    @property
    def name(self):
        return self._name

    def __str__(self):
        return self._name

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            all(getattr(self, attr) == getattr(other, attr) for attr in self.eq_fields)
        )

    def __ne__(self, other):
        # python 2 does not derive != from __eq__, so define the inverse explicitly
        return not self.__eq__(other)

    def __hash__(self):
        # defining __eq__ alone makes instances unhashable in python 3, so provide a hash that is
        # built from exactly the attributes that __eq__ compares
        return hash(tuple(getattr(self, attr) for attr in self.eq_fields))
class Sandbox(six.with_metaclass(ABCMeta, object)):
    """
    Sandbox definition. The config section used by instances of this or inheriting classes is
    constructed using :py:attr:`config_section_prefix` followed by ``"_sandbox"`` and optional
    postfixes. The minimal set of options in the main section are:

        - ``"stagein_dir_name"`` (usually ``"stagein"``)
        - ``"stageout_dir_name"`` (usually ``"stageout"``)
        - ``"law_executable"`` (usually ``"law"``)
    """

    # separator between sandbox type and name within a sandbox key
    delimiter = "::"

    # class used by create_variables to interpret the name of a sandbox
    variable_cls = SandboxVariables

    # cached envs
    _envs = {}

    @classmethod
    def check_key(cls, key, silent=False):
        """
        Returns *True* when *key* has a valid format, i.e., when it contains no comma. For an
        invalid key, *False* is returned when *silent* is set, and a :py:class:`ValueError` is
        raised otherwise.
        """
        # the LAW_SANDBOX env variable is allowed to contain multiple, comma-separated sandbox
        # keys, so a comma within a single key would make them inseparable
        if "," not in key:
            return True

        if not silent:
            raise ValueError("invalid sandbox key format '{}'".format(key))
        return False

    @classmethod
    def split_key(cls, key):
        """
        Splits *key* at the first occurrence of :py:attr:`delimiter` and returns the two parts
        in a tuple. A :py:class:`ValueError` is raised when the delimiter is missing or when one
        of the parts consists only of whitespace.
        """
        parts = str(key).split(cls.delimiter, 1)

        well_formed = len(parts) == 2 and all(p.strip() for p in parts)
        if not well_formed:
            raise ValueError("invalid sandbox key '{}'".format(key))

        return tuple(parts)

    @classmethod
    def join_key(cls, _type, name):
        """ join_key(type, name)
        """
        return cls.delimiter.join((str(_type), str(name)))

    @classmethod
    def create_variables(cls, name):
        # delegate the interpretation of the name to the configured variable class
        return cls.variable_cls.from_name(name)

    @classmethod
    def new(cls, key, *args, **kwargs):
        """
        Creates a sandbox instance for *key*. The key is validated and split into type and name,
        and the subclasses of this class are searched breadth-first for the first one whose
        ``sandbox_type`` attribute equals the type. That class is instantiated with the name and
        all *args* and *kwargs*. An exception is raised when no subclass matches.
        """
        # check for key format
        cls.check_key(key, silent=False)

        # split the key into the sandbox type and name
        _type, name = cls.split_key(key)

        # walk through the subclass tree, subclasses of rejected candidates are queued at the end
        candidates = list(cls.__subclasses__())
        while candidates:
            candidate = candidates.pop(0)
            if getattr(candidate, "sandbox_type", None) == _type:
                return candidate(name, *args, **kwargs)
            candidates.extend(candidate.__subclasses__())

        raise Exception("no sandbox with type '{}' found".format(_type))

    def __init__(self, name, task=None, env_cache_path=None):
        super(Sandbox, self).__init__()

        # when a task is set, it must be a SandboxTask instance
        if task and not isinstance(task, SandboxTask):
            raise TypeError("sandbox task must be a SandboxTask instance, got {}".format(task))

        # expand user and variables in the cache path and make it absolute, empty means no cache
        if env_cache_path:
            expanded_path = os.path.expandvars(os.path.expanduser(str(env_cache_path)))
            cache_path = os.path.abspath(expanded_path)
        else:
            cache_path = None

        self.variables = self.create_variables(name)
        self.task = task
        self.env_cache_path = cache_path

        # target staging info
        self.stagein_info = None
        self.stageout_info = None

    def __str__(self):
        return self.key

    @property
    def name(self):
        return self.variables.name

    def is_active(self):
        # active means that the key is among those listed in LAW_SANDBOX of the current process
        return self.key in _current_sandbox

    @property
    def key(self):
        return self.join_key(self.sandbox_type, self.name)

    def scheduler_on_host(self):
        # whether the configured scheduler host refers to the local machine
        scheduler_host = luigi.interface.core().scheduler_host
        return multi_match(scheduler_host, ["0.0.0.0", "127.0.0.1", "localhost"])

    def force_local_scheduler(self):
        return False

    @abstractproperty
    def config_section_prefix(self):
        return

    @abstractproperty
    def env_cache_key(self):
        return

    @abstractmethod
    def create_env(self):
        return

    @abstractmethod
    def cmd(self, proxy_cmd):
        return

    @property
    def env(self):
        # environments are created once per (sandbox type, cache key) pair and then shared
        # through the class-level cache
        cache = self._envs
        cache_key = (self.sandbox_type, self.env_cache_key)

        if cache_key not in cache:
            cache[cache_key] = self.create_env()

        return cache[cache_key]

    def run(self, cmd, stdout=None, stderr=None):
        """
        Runs *cmd* through bash via ``interruptable_popen`` using the sandbox :py:attr:`env` and
        returns its result. *stdout* and *stderr* default to the streams of the current process.
        """
        out_stream = sys.stdout if stdout is None else stdout
        err_stream = sys.stderr if stderr is None else stderr

        return interruptable_popen(
            cmd,
            shell=True,
            executable="/bin/bash",
            stdout=out_stream,
            stderr=err_stream,
            env=self.env,
        )

    def get_custom_config_section_postfix(self):
        return self.name

    def get_config_section(self, postfix=None):
        """
        Returns the name of the config section to use, optionally extended by *postfix*. A
        section that is additionally suffixed with the value returned by
        :py:meth:`get_custom_config_section_postfix` is preferred when it exists in the config.
        """
        section_parts = [self.config_section_prefix + "_sandbox"]
        if postfix:
            section_parts.append(postfix)
        section = "_".join(section_parts)

        custom_section = "{}_{}".format(section, self.get_custom_config_section_postfix())

        if Config.instance().has_section(custom_section):
            return custom_section
        return section

    def _get_env(self):
        """
        Returns an ordered dictionary of environment variables to set in the sandbox, composed
        of the default ``LAW_SANDBOX*`` variables, variables of the ``env`` config section and
        variables defined by the task, in that order of precedence (later entries win).
        """
        # environment variables to set
        env = OrderedDict()

        # default sandboxing variables, dollar signs in the key are escaped
        env["LAW_SANDBOX"] = self.key.replace("$", r"\$")
        env["LAW_SANDBOX_SWITCHED"] = "1"

        task = self.task
        if task:
            env["LAW_SANDBOX_TASK_ID"] = task.live_task_id
            env["LAW_SANDBOX_ROOT_TASK_ID"] = root_task().task_id
            env["LAW_SANDBOX_IS_ROOT_TASK"] = str(int(task.is_root_task()))

            # worker information is only forwarded when present on the task
            worker_id = getattr(task, "_worker_id", None)
            if worker_id:
                env["LAW_SANDBOX_WORKER_ID"] = worker_id
            worker_first_task_id = getattr(task, "_worker_first_task_id", None)
            if worker_first_task_id:
                env["LAW_SANDBOX_WORKER_FIRST_TASK_ID"] = worker_first_task_id

        # extend by variables from the config file
        cfg = Config.instance()
        section = self.get_config_section(postfix="env")
        for name, value in cfg.items(section):
            # a pattern selects all matching variables of the current environment
            if is_pattern(name):
                var_names = [key for key in os.environ if fnmatch(key, name)]
            else:
                var_names = [name]

            for var_name in var_names:
                # when there is only a key present, i.e., no value is set,
                # get it from the current environment
                if value is not None:
                    env[var_name] = value
                else:
                    env[var_name] = os.getenv(var_name, "")

        # extend by variables defined on task level
        if task:
            task_env = task.sandbox_env(env)
            if task_env:
                env.update(task_env)

        return env

    def _get_volumes(self):
        """
        Returns an ordered dictionary of volumes, taken from the ``volumes`` config section and
        updated with the volumes defined by the task.
        """
        # start with volumes from the config file
        section = self.get_config_section(postfix="volumes")
        volumes = OrderedDict(Config.instance().items(section))

        # extend by volumes defined on task level
        if self.task:
            task_volumes = self.task.sandbox_volumes(volumes)
            if task_volumes:
                volumes.update(task_volumes)

        return volumes

    def _expand_volume(self, vol, bin_dir=None, python_dir=None):
        """
        Replaces the placeholders ``{{LAW_FORWARD_BIN}}`` and ``{{LAW_FORWARD_PY}}`` in *vol* by
        *bin_dir* and *python_dir*, respectively, and returns the result. A placeholder is left
        untouched when its directory is empty.
        """
        # the double braces are literal as the placeholder is built by concatenation
        for name, directory in (("BIN", bin_dir), ("PY", python_dir)):
            if directory:
                vol = vol.replace("{{LAW_FORWARD_" + name + "}}", str(directory))

        return vol

    def _build_export_commands(self, env):
        # one export statement per variable, values are wrapped in double quotes
        return ["export {}=\"{}\"".format(key, value) for key, value in env.items()]  # noqa: Q003

    def _build_pre_setup_cmds(self, env=None):
        # commands that run before the setup is performed: exports first, then task commands
        cmds = []

        if env:
            cmds.extend(self._build_export_commands(env))

        if self.task:
            cmds.extend(self.task.sandbox_pre_setup_cmds())

        return cmds

    def _build_post_setup_cmds(self, env=None):
        # commands that run after the setup is performed: exports first, then task commands
        cmds = []

        if env:
            cmds.extend(self._build_export_commands(env))

        if self.task:
            cmds.extend(self.task.sandbox_post_setup_cmds())

        return cmds
class SandboxProxy(ProxyTask):
    """
    Proxy task that executes the task it wraps inside the sandbox of that task. Its
    :py:meth:`run` method optionally stages inputs into a temporary directory, runs the sandbox
    command and stages outputs back.
    """

    def output(self):
        return None

    @property
    def sandbox_inst(self):
        # the sandbox instance is owned by the wrapped task
        return self.task.sandbox_inst

    def create_proxy_cmd(self):
        # command that re-runs the wrapped task with the law executable of the sandbox, the
        # "workers" global argument is not forwarded
        task = self.task
        return ProxyCommand(
            task,
            exclude_task_args=task.exclude_params_sandbox,
            exclude_global_args=["workers"],
            executable=task.sandbox_law_executable(),
        )

    def run(self):
        """
        Runs the wrapped task in its sandbox. The steps are: pre-run hook, stage-in, preparation
        of the stage-out, execution of the sandbox command, actual stage-out, post-run hook. An
        exception is raised when the sandbox command exits with a non-zero code.
        """
        # pre_run hook
        if callable(self.task.sandbox_pre_run):
            self.task.sandbox_pre_run()

        # create a temporary directory for file staging
        tmp_dir = LocalDirectoryTarget(is_tmp=True)
        tmp_dir.touch()

        # stage-in input files and tell the sandbox about it
        stagein_info = self.stagein(tmp_dir)
        if stagein_info:
            self.sandbox_inst.stagein_info = stagein_info
            logger.debug("configured sandbox stage-in data")

        # prepare stage-out and tell the sandbox about it
        stageout_info = self.prepare_stageout(tmp_dir)
        if stageout_info:
            self.sandbox_inst.stageout_info = stageout_info
            logger.debug("configured sandbox stage-out data")

        # create the actual command to run
        cmd = self.sandbox_inst.cmd(self.create_proxy_cmd())

        # run with log section before and after actual run call
        with self._run_context(cmd):
            code, out, err = self.sandbox_inst.run(cmd)
            if code != 0:
                raise Exception(
                    "sandbox '{}' failed with exit code {}, please see the error inside the "
                    "sandboxed context above for details".format(self.sandbox_inst.key, code),
                )

        # actual stage_out
        if stageout_info:
            self.stageout(stageout_info)

        # post_run hook
        if callable(self.task.sandbox_post_run):
            self.task.sandbox_post_run()

    def stagein(self, tmp_dir):
        """
        Copies the inputs selected by the task's stage-in mask into a sub directory of *tmp_dir*
        and returns a :py:class:`StageInfo` object describing the operation. *None* is returned
        when no stage-in directory name is configured or when the mask selects nothing.
        """
        # check if the stage-in dir is set
        dir_name = Config.instance().get_expanded(
            self.sandbox_inst.get_config_section(),
            "stagein_dir_name",
        )
        if not dir_name:
            return None

        # determine inputs as seen by the sandbox
        with patch_object(os, "environ", self.task.env, lock=True):
            inputs = self.task.input()

        # get the sandbox stage-in mask
        mask = self.task.sandbox_stagein(inputs)
        if not mask:
            return None

        # apply the mask
        inputs = mask_struct(mask, inputs)
        if not inputs:
            return None

        # create the stage-in directory
        stage_dir = tmp_dir.child(dir_name, type="d")
        stage_dir.touch()

        # create localized sandbox input representations
        staged = create_staged_target_struct(stage_dir, inputs)

        # perform the actual stage-in via copying, both lists are consumed in lockstep and
        # collections are unpacked in place at the front
        pending_inputs = flatten(inputs)
        pending_staged = flatten(staged)
        while pending_inputs:
            src = pending_inputs.pop(0)
            dst = pending_staged.pop(0)

            if isinstance(src, TargetCollection):
                pending_inputs = src._flat_target_list + pending_inputs
                pending_staged = dst._flat_target_list + pending_staged
            else:
                logger.debug("stage-in {} to {}".format(src.path, dst.path))
                src.copy_to_local(dst)

        logger.info("staged-in {} file(s)".format(len(stage_dir.listdir())))

        return StageInfo(inputs, stage_dir, staged)

    def prepare_stageout(self, tmp_dir):
        """
        Creates the stage-out directory inside *tmp_dir* for the outputs selected by the task's
        stage-out mask and returns a :py:class:`StageInfo` object describing them. Nothing is
        copied at this point. *None* is returned when no stage-out directory name is configured
        or when the mask selects nothing.
        """
        # check if the stage-out dir is set
        dir_name = Config.instance().get_expanded(
            self.sandbox_inst.get_config_section(),
            "stageout_dir_name",
        )
        if not dir_name:
            return None

        # determine outputs as seen by the sandbox
        with patch_object(os, "environ", self.task.env, lock=True):
            outputs = self.task.output()

        # get the sandbox stage-out mask
        mask = self.task.sandbox_stageout(outputs)
        if not mask:
            return None

        # apply the mask
        outputs = mask_struct(mask, outputs)
        if not outputs:
            return None

        # create the stage-out directory
        stage_dir = tmp_dir.child(dir_name, type="d")
        stage_dir.touch()

        # create localized sandbox output representations
        staged = create_staged_target_struct(stage_dir, outputs)

        return StageInfo(outputs, stage_dir, staged)

    def stageout(self, stageout_info):
        """
        Copies the staged outputs described by *stageout_info* to their actual targets. Staged
        outputs that do not exist are skipped with a warning.
        """
        # perform the actual stage-out via copying, both lists are consumed in lockstep and
        # collections are unpacked in place at the front
        pending_outputs = flatten(stageout_info.targets)
        pending_staged = flatten(stageout_info.staged_targets)
        while pending_outputs:
            dst = pending_outputs.pop(0)
            src = pending_staged.pop(0)

            if isinstance(dst, TargetCollection):
                pending_outputs = dst._flat_target_list + pending_outputs
                pending_staged = src._flat_target_list + pending_staged
                continue

            logger.debug("stage-out {} to {}".format(src.path, dst.path))
            if not src.exists():
                logger.warning(
                    "could not find output target at {} for stage-out".format(src.path),
                )
            else:
                dst.copy_from_local(src)

        logger.info("staged-out {} file(s)".format(len(stageout_info.stage_dir.listdir())))

    @contextmanager
    def _run_context(self, cmd=None):
        """
        Context manager that prints a banner when entering and when leaving the context, the
        latter also in case of an exception. When *cmd* is set, it is logged on debug level.
        """
        def show_banner(msg, color):
            print("")
            print(colored(" {} ".format(msg).center(80, "="), color=color))
            print(colored("task : ", color=color) + colored(self.task.task_id, style="bright"))
            print(colored("sandbox: ", color=color) + colored(self.sandbox_inst.key, style="bright"))
            print(colored("=" * 80, color=color))
            print("")

        # start banner
        show_banner("entering sandbox", "magenta")

        # log the command
        if cmd:
            logger.debug("sandbox command:\n{}".format(cmd))

        # flush so that the banner precedes any output of the sandboxed command
        sys.stdout.flush()

        try:
            yield
        finally:
            # end banner
            show_banner("leaving sandbox", "cyan")
            sys.stdout.flush()
class SandboxTask(ProxyAttributeTask):
    """
    Task whose execution can be moved into a sandbox. The sandbox is selected through the
    *sandbox* attribute, and the objects required for sandboxing (effective sandbox key, sandbox
    instance and sandbox proxy) are created lazily on first access. Subclasses customize the
    behavior through the ``sandbox_*`` hooks defined at the end of this class.
    """

    sandbox = luigi.Parameter(
        default=_current_sandbox[0] or NO_STR,
        description="name of the sandbox to run the task in; default: $LAW_SANDBOX when set, "
        "otherwise empty",
    )

    # whether an empty sandbox is accepted instead of raising an exception
    allow_empty_sandbox = False

    # patterns that a sandbox passed as parameter is matched against
    valid_sandboxes = ["*"]

    # task parameters that are not forwarded to the sandboxed command
    exclude_params_sandbox = {"sandbox", "log_file"}

    def __init__(self, *args, **kwargs):
        super(SandboxTask, self).__init__(*args, **kwargs)

        # store whether sandbox objects have been setup, which is done lazily,
        # and predefine all attributes that are set by it
        self._sandbox_initialized = False
        self._effective_sandbox = None
        self._sandbox_inst = None
        self._sandbox_proxy = None

    def _initialize_sandbox(self, force=False):
        """
        Determines the effective sandbox and, when the task is not already running in it,
        creates the sandbox instance and the sandbox proxy. The work is done only once unless
        *force* is set.
        """
        if self._sandbox_initialized and not force:
            return
        # mark as initialized right away, before any of the objects below is created
        self._sandbox_initialized = True

        # reset values
        self._effective_sandbox = None
        self._sandbox_inst = None
        self._sandbox_proxy = None

        if _sandbox_switched:
            # when we are already in a sandbox, this task is placed inside it, i.e., there is
            # no nesting
            effective = _current_sandbox[0]
        elif not isinstance(self.__class__.sandbox, luigi.Parameter):
            # the sandbox is hard-coded on class level, just use it
            effective = self.sandbox
        elif multi_match(self.sandbox, self.valid_sandboxes, mode=any):
            # the sandbox is set via a parameter and is among the valid sandboxes
            effective = self.sandbox
        else:
            # the sandbox is set via a parameter but not valid, determine the fallback
            effective = self.fallback_sandbox(self.sandbox)
        self._effective_sandbox = effective

        # at this point, the sandbox must be set unless it is explicitly allowed to be empty
        if effective in (None, NO_STR):
            if not self.allow_empty_sandbox:
                raise Exception("task {!r} requires the sandbox parameter to be set".format(self))
            self._effective_sandbox = NO_STR
            return

        # create the sandbox proxy when required, i.e., when the sandbox is not yet active
        sandbox_inst = Sandbox.new(effective, self)
        if sandbox_inst.is_active():
            return

        self._sandbox_inst = sandbox_inst
        self._sandbox_proxy = SandboxProxy(task=self)
        logger.debug(
            "created sandbox proxy instance of type '{}'".format(self._effective_sandbox),
        )

    @property
    def effective_sandbox(self):
        self._initialize_sandbox()
        return self._effective_sandbox

    @property
    def sandbox_inst(self):
        self._initialize_sandbox()
        return self._sandbox_inst

    @property
    def sandbox_proxy(self):
        self._initialize_sandbox()
        return self._sandbox_proxy

    def is_sandboxed(self):
        # returns whether the task requires no additional sandboxing, i.e., if it is already in
        # its desired sandbox
        if self.effective_sandbox == NO_STR:
            return True
        return not self.sandbox_inst

    def is_sandboxed_task(self):
        # returns whether the task is the _one_ task whose execution is actually sandboxed
        return _sandbox_task_id == self.live_task_id

    def is_root_task(self):
        # returns whether the task is the root task of the initial "law run" invocation,
        # potentially outside a sandbox
        is_root = super(SandboxTask, self).is_root_task()

        # inside a sandbox, the information forwarded from the outside must agree as well
        if _sandbox_switched:
            return is_root and _sandbox_is_root_task
        return is_root

    def _proxy_staged_input(self):
        # whether the input attribute should be forwarded to _staged_input
        # (see get_proxy_attribute used in ProxyAttributeTask.__getattribute__)
        return _sandbox_stagein_dir and self.is_sandboxed() and self.is_sandboxed_task()

    def _proxy_staged_output(self):
        # whether the output attribute should be forwarded to _staged_output
        # (see get_proxy_attribute used in ProxyAttributeTask.__getattribute__)
        return _sandbox_stageout_dir and self.is_sandboxed() and self.is_sandboxed_task()

    def _staged_input(self):
        """
        Returns the inputs of the task with those selected by the stage-in mask replaced by
        their representations inside the stage-in directory. An exception is raised when no
        stage-in directory is set.
        """
        from law.decorator import _is_patched_localized_method

        if not _sandbox_stagein_dir:
            raise Exception(
                "LAW_SANDBOX_STAGEIN_DIR must not be empty in a sandbox when target "
                "stage-in is required",
            )

        # get the original inputs, bypassing the attribute proxy
        original_input = self.__getattribute__("input", proxy=False)
        inputs = original_input()

        # when the input method is a patched method from a localization decorator, just return
        # the inputs since the decorator already triggered the stage-in
        if _is_patched_localized_method(original_input):
            return inputs

        # create the struct of staged inputs and apply the stage-in mask
        staged = create_staged_target_struct(_sandbox_stagein_dir, inputs)
        return mask_struct(self.sandbox_stagein(inputs), staged, replace=inputs)

    def _staged_output(self):
        """
        Returns the outputs of the task with those selected by the stage-out mask replaced by
        their representations inside the stage-out directory. An exception is raised when no
        stage-out directory is set.
        """
        from law.decorator import _is_patched_localized_method

        if not _sandbox_stageout_dir:
            raise Exception(
                "LAW_SANDBOX_STAGEOUT_DIR must not be empty in a sandbox when target "
                "stage-out is required",
            )

        # get the original outputs, bypassing the attribute proxy
        original_output = self.__getattribute__("output", proxy=False)
        outputs = original_output()

        # when the output method is a patched method from a localization decorator, just return
        # the outputs since the decorator already triggered the stage-out
        if _is_patched_localized_method(original_output):
            return outputs

        # create the struct of staged outputs and apply the stage-out mask
        staged = create_staged_target_struct(_sandbox_stageout_dir, outputs)
        return mask_struct(self.sandbox_stageout(outputs), staged, replace=outputs)

    @property
    def env(self):
        # the environment of the current process when no further sandboxing is needed,
        # otherwise the environment of the sandbox
        if self.is_sandboxed():
            return os.environ
        return self.sandbox_inst.env

    def fallback_sandbox(self, sandbox):
        # sandbox to use when the one passed as parameter is not valid, none by default
        return None

    def sandbox_user(self):
        """
        Returns the user and group id in a 2-tuple. The ids of the current process are used
        unless the config section of the sandbox defines ``"uid"`` or ``"gid"``.
        """
        uid = os.getuid()
        gid = os.getgid()

        # check if there is a config section that defines the user and group ids
        sandbox_inst = self.sandbox_inst
        if sandbox_inst:
            cfg = Config.instance()
            section = sandbox_inst.get_config_section()
            uid = cfg.get_expanded_int(section, "uid", default=uid)
            gid = cfg.get_expanded_int(section, "gid", default=gid)

        return uid, gid

    def sandbox_stagein(self, inputs):
        # disable stage-in by default
        return False

    def sandbox_stageout(self, outputs):
        # disable stage-out by default
        return False

    def sandbox_env(self, env):
        # additional environment variables
        return {}

    def sandbox_volumes(self, volumes):
        # additional volumes to mount
        return {}

    def sandbox_pre_setup_cmds(self):
        # list of commands that are run before the sandbox is set up
        return []

    def sandbox_post_setup_cmds(self):
        # list of commands that are run after the sandbox is set up
        return []

    def sandbox_law_executable(self):
        # law executable that is used inside the sandbox, split into a list of arguments
        executable = "law"

        sandbox_inst = self.sandbox_inst
        if sandbox_inst:
            section = sandbox_inst.get_config_section()
            executable = Config.instance().get_expanded(section, "law_executable")

        if not executable:
            return []
        return shlex.split(executable)

    def sandbox_pre_run(self):
        # method that is invoked before the run method of the sandbox proxy is called
        return

    def sandbox_post_run(self):
        # method that is invoked after the run method of the sandbox proxy is called
        return
def create_staged_target_struct(stage_dir, struct):
    """
    Returns a copy of *struct* in which every target is replaced by its staged representation
    inside *stage_dir* (see :py:func:`create_staged_target`). Target collections are rebuilt
    with the same class and copy kwargs around their staged targets.
    """
    def stage_single(target):
        return create_staged_target(stage_dir, target)

    def stage_collection(func, collection, **kwargs):
        # map the targets contained in the collection, then wrap them in a new collection
        inner = map_struct(func, collection.targets, **kwargs)
        collection_cls = collection.__class__
        return collection_cls(inner, **collection._copy_kwargs())

    mappings = {TargetCollection: stage_collection}
    return map_struct(stage_single, struct, custom_mappings=mappings)


def create_staged_target(stage_dir, target):
    """
    Returns the representation of *target* inside *stage_dir*, i.e., a child of the directory
    named after the unique basename of *target*, with the same type and copy kwargs. *stage_dir*
    can be a :py:class:`LocalDirectoryTarget` or anything whose string form is a path.
    """
    if isinstance(stage_dir, LocalDirectoryTarget):
        directory = stage_dir
    else:
        directory = LocalDirectoryTarget(str(stage_dir))

    return directory.child(target.unique_basename, type=target.type, **target._copy_kwargs())