Source code for law.target.remote.cache

# coding: utf-8

"""
Cache for remote files on local disk.
"""

from __future__ import annotations

__all__ = ["RemoteCache"]

import os
import shutil
import time
import pathlib
import tempfile
import weakref
import contextlib
import atexit

from law.config import Config
import law.target.remote.base as _remote_base
from law.util import (
    makedirs, human_bytes, parse_bytes, parse_duration, create_hash, user_owns_file, io_lock,
)
from law.logger import get_logger
from law._types import Any, Callable, Iterator, AbstractContextManager


logger = get_logger(__name__)


class RemoteCache(object):
    """
    Cache for files of a remote file system, kept in a directory on local disk.

    A remote path *rpath* is mapped to a cache path inside :py:attr:`base` via
    :py:meth:`cache_path`. Access to cached files is coordinated through empty lock files that
    carry the :py:attr:`lock_postfix` suffix, plus an optional global lock file per cache.
    Methods with a leading underscore operate on cache paths (*cpath*), their public counterparts
    on remote paths (*rpath*).
    """

    # value of *root* that requests a fresh temporary directory as the cache base
    TMP = "__TMP__"

    # suffix appended to a path to obtain the path of its lock file
    lock_postfix = ".lock"

    # every instance ever created, consumed by cleanup_all()
    _instances: list[RemoteCache] = []

    def __new__(cls, *args, **kwargs) -> RemoteCache:
        """
        Creates the instance and registers it in :py:attr:`_instances`. Note that registration
        happens before ``__init__`` runs, so registered instances may be only partially set up.
        """
        inst = super().__new__(cls)

        # cache instances
        cls._instances.append(inst)

        return inst

    @classmethod
    def cleanup_all(cls) -> None:
        """
        Runs :py:meth:`_cleanup` on all registered instances. Failures of single instances are
        logged on debug level and do not prevent the cleanup of the remaining ones.
        """
        # iterate over a copy so that the registry may change while cleaning up
        for inst in list(cls._instances):
            try:
                inst._cleanup()
            except Exception as e:
                # best effort only, but let KeyboardInterrupt and SystemExit pass through
                logger.debug(f"cleanup of a {cls.__name__} instance failed: {e!r}")

    @classmethod
    def parse_config(
        cls,
        section: str,
        config: dict[str, Any] | None = None,
        *,
        overwrite: bool = False,
    ) -> dict[str, Any]:
        """
        Reads the ``cache_*`` options of a law config *section* and returns them, parsed, in a
        dictionary whose keys are the option names without the ``cache_`` prefix.

        When *config* is given, it is updated in place and returned. Keys that already exist in
        *config* are only replaced when *overwrite* is *True*. Options that are missing or none
        in the law config are skipped.
        """
        from law.sandbox.base import _sandbox_switched

        # reads a law config section and returns parsed file system configs
        cfg = Config.instance()

        if config is None:
            config = {}

        # helper to add a config value if it exists, extracted with a config parser method
        def add(option: str, func: Callable[[str, str], Any]) -> None:
            cache_option = "cache_" + option
            if cfg.is_missing_or_none(section, cache_option):
                return
            if option not in config or overwrite:
                config[option] = func(section, cache_option)

        # sizes are interpreted and stored in MB
        def get_size(section, cache_option):
            value = cfg.get_expanded(section, cache_option)
            return parse_bytes(value, input_unit="MB", unit="MB")

        # durations are interpreted and stored in seconds
        def get_time(section, cache_option):
            value = cfg.get_expanded(section, cache_option)
            return parse_duration(value, input_unit="s", unit="s")

        add("root", cfg.get_expanded)
        add("cleanup", cfg.get_expanded_bool)
        add("max_size", get_size)
        add("mtime_patience", cfg.get_expanded_float)
        add("file_perm", cfg.get_expanded_int)
        add("dir_perm", cfg.get_expanded_int)
        add("wait_delay", get_time)
        add("max_waits", cfg.get_expanded_int)
        add("global_lock", cfg.get_expanded_bool)

        # inside sandboxes, never cleanup since the outer process will do that if needed
        if _sandbox_switched:
            config["cleanup"] = False

        return config

    def __init__(
        self,
        fs: _remote_base.RemoteFileSystem,
        *,
        root: str | pathlib.Path | None = TMP,
        cleanup: bool = False,
        max_size: int | float = 0,  # in MB
        mtime_patience: int | float = 1.0,  # in seconds
        file_perm: int = 0o0660,
        dir_perm: int = 0o0770,
        wait_delay: int | float = 5.0,  # in seconds
        max_waits: int = 120,
        global_lock: bool = False,
    ) -> None:
        """
        Sets up the cache directory for the remote file system *fs*, which is only weakly
        referenced.

        *root* is the directory that will contain the cache directory. User and environment
        variables in it are expanded. When it is *None*, empty or equal to :py:attr:`TMP` (and
        no such path exists), a new temporary directory inside the ``tmp_dir`` option of the
        ``target`` config section is used as the cache base and *cleanup* is forced to *True*.

        *cleanup* decides whether the whole cache directory is removed in :py:meth:`_cleanup`.
        *max_size* is the maximum cache size in MB, with values <= 0 meaning that only the free
        disk space limits the cache. *mtime_patience* is the tolerated difference in seconds in
        :py:meth:`check_mtime`. *file_perm* and *dir_perm* are the permissions applied to cached
        files and the cache directory. *wait_delay* (seconds) and *max_waits* control how long
        locks are awaited. *global_lock* makes :py:meth:`_await` also wait for the global lock.
        """
        super().__init__()
        # max_size is in MB, wait_delay is in seconds

        # create a unique name based on fs attributes
        name = f"{fs.__class__.__name__}_{create_hash(fs.base[0])}"

        # create the root dir, handle tmp
        # (a root of None must not be stringified to "None", treat it like the TMP sentinel)
        root = "" if root is None else str(root)
        root = os.path.expandvars(os.path.expanduser(root)) or self.TMP
        if not os.path.exists(root) and root == self.TMP:
            cfg = Config.instance()
            tmp_dir = cfg.get_expanded("target", "tmp_dir")
            base = tempfile.mkdtemp(dir=tmp_dir)
            cleanup = True
        else:
            base = os.path.join(root, name)
            makedirs(base, dir_perm)

        # save attributes and configs
        self.root = root
        self.fs_ref = weakref.ref(fs)
        self.base = base
        self.name = name
        self.cleanup = cleanup
        self.max_size = float(max_size)
        self.mtime_patience = float(mtime_patience)
        self.dir_perm = dir_perm
        self.file_perm = file_perm
        self.wait_delay = float(wait_delay)
        self.max_waits = max_waits
        self.global_lock = global_lock

        # path to the global lock file which should guard global actions such as cache allocations
        self._global_lock_path = self._lock_path(os.path.join(base, "global"))

        # currently locked cache paths, only used to clean up broken files during cleanup
        self._locked_cpaths: set[str] = set()

        logger.debug(f"created {self.__class__.__name__} at '{self.base}'")

    def __del__(self) -> None:
        # a destructor must never raise, so cleanup problems are swallowed here
        try:
            self._cleanup()
        except (OSError, TypeError, AttributeError):
            pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} '{self.base}' at {hex(id(self))}>"

    def __contains__(self, rpath: str | pathlib.Path) -> bool:
        """
        Returns *True* when a cached file exists for the remote path *rpath*.
        """
        return os.path.exists(self.cache_path(rpath))

    @property
    def fs(self) -> _remote_base.RemoteFileSystem:
        """
        The weakly referenced file system, or *None* once it was garbage collected.
        """
        return self.fs_ref()  # type: ignore[return-value]

    def _cleanup(self) -> None:
        """
        Removes the whole cache directory when :py:attr:`cleanup` is set. Otherwise, cached files
        that are still marked as locked by this instance are unlocked and removed, and the global
        lock file is removed. Instances whose ``__init__`` did not complete are handled
        gracefully.
        """
        base = getattr(self, "base", None)
        if base is None:
            # __init__ failed before a cache directory was assigned, nothing to clean up
            return

        # full cleanup or remove open locks
        if getattr(self, "cleanup", False):
            if os.path.exists(base):
                shutil.rmtree(base)
        else:
            locked_cpaths = getattr(self, "_locked_cpaths", None)
            if locked_cpaths:
                # iterate over a copy since _remove() modifies the set through _lock()
                for cpath in set(locked_cpaths):
                    self._unlock(cpath)
                    self._remove(cpath)
                locked_cpaths.clear()
            if hasattr(self, "_global_lock_path"):
                self._unlock_global()

        logger.debug(f"cleanup RemoteCache at '{base}'")

    def cache_path(self, rpath: str | pathlib.Path) -> str:
        """
        Returns the path of the cached file for *rpath*: a file inside :py:attr:`base` named by
        the hash of *rpath* followed by an underscore and the basename of *rpath*.
        """
        rpath = str(rpath)
        basename = f"{create_hash(rpath)}_{os.path.basename(rpath)}"
        return os.path.join(self.base, basename)

    def _lock_path(self, cpath: str | pathlib.Path) -> str:
        """
        Returns the path of the lock file that belongs to *cpath*.
        """
        return f"{cpath}{self.lock_postfix}"

    def is_locked_global(self) -> bool:
        """
        Returns *True* when the global lock file exists.
        """
        return os.path.exists(self._global_lock_path)

    def _is_locked(self, cpath: str | pathlib.Path) -> bool:
        """
        Returns *True* when the lock file of the cache path *cpath* exists.
        """
        return os.path.exists(self._lock_path(cpath))

    def is_locked(self, rpath: str | pathlib.Path) -> bool:
        """
        Returns *True* when the cached file of the remote path *rpath* is locked.
        """
        return self._is_locked(self.cache_path(rpath))

    def _unlock_global(self) -> None:
        """
        Removes the global lock file. A missing or unremovable file is silently ignored.
        """
        try:
            os.remove(self._global_lock_path)
        except OSError:
            pass

    def _unlock(self, cpath: str | pathlib.Path) -> None:
        """
        Removes the lock file of *cpath*. A missing or unremovable file is silently ignored.
        """
        try:
            os.remove(self._lock_path(cpath))
        except OSError:
            pass

    def _await_global(
        self,
        *,
        delay: int | float | None = None,
        max_waits: int | None = None,
        silent: bool = False,
    ) -> bool:
        """
        Blocks until the global lock file is gone, sleeping *delay* seconds between checks, and
        returns *True*. After *max_waits* unsuccessful checks, an exception is raised or, when
        *silent* is *True*, *False* is returned. *delay* and *max_waits* default to
        :py:attr:`wait_delay` and :py:attr:`max_waits`.
        """
        delay = delay if delay is not None else self.wait_delay
        max_waits = max_waits if max_waits is not None else self.max_waits
        _max_waits = max_waits

        while self.is_locked_global():
            if max_waits <= 0:
                if not silent:
                    raise Exception(
                        f"max_waits of {_max_waits} exceeded while waiting for global lock",
                    )
                return False

            time.sleep(delay)
            max_waits -= 1

        return True

    def _await(
        self,
        cpath: str | pathlib.Path,
        *,
        delay: int | float | None = None,
        max_waits: int | None = None,
        silent: bool = False,
        global_lock: bool | None = None,
    ) -> bool:
        """
        Blocks until the cache path *cpath* is no longer locked and, when *global_lock* is set
        (default: :py:attr:`global_lock`), until the global lock is released as well. Returns
        *True* on success.

        Between checks, *delay* seconds are slept. The wait counter is only decreased while the
        size of the file at *cpath* does not change. Once *max_waits* is exhausted, an exception
        is raised or, when *silent* is *True*, *False* is returned.
        """
        cpath = str(cpath)
        delay = delay if delay is not None else self.wait_delay
        max_waits = max_waits if max_waits is not None else self.max_waits
        _max_waits = max_waits
        global_lock = self.global_lock if global_lock is None else global_lock

        # strategy: wait as long the file is locked and if the file size did not change, reduce
        # max_waits per iteration and raise when 0 is reached
        last_size = -1
        while self._is_locked(cpath) or (global_lock and self.is_locked_global()):
            if max_waits <= 0:
                if not silent:
                    raise Exception(
                        f"max_waits of {_max_waits} exceeded while waiting for file '{cpath}'",
                    )
                return False

            time.sleep(delay)

            # only reduce max_waits when the file size did not change
            # otherwise, set it to its original value again
            if os.path.exists(cpath):
                size = os.stat(cpath).st_size
                if size != last_size:
                    last_size = size
                    # +1 compensates the decrement right below
                    max_waits = _max_waits + 1
            max_waits -= 1

        return True

    @contextlib.contextmanager
    def _lock_global(self, **kwargs) -> Iterator[None]:
        """
        Context manager that waits for the global lock (*kwargs* are forwarded to
        :py:meth:`_await_global`), creates the global lock file for the duration of the context
        and removes it on exit.
        """
        self._await_global(**kwargs)

        try:
            # io_lock comes from law.util, presumably serializing file operations within the
            # process (NOTE(review): confirm)
            with io_lock:
                with open(self._global_lock_path, "w") as f:
                    f.write("")
                os.utime(self._global_lock_path, None)
            yield
        finally:
            self._unlock_global()

    @contextlib.contextmanager
    def _lock(self, cpath: str | pathlib.Path, **kwargs) -> Iterator[None]:
        """
        Context manager that waits until *cpath* is unlocked (*kwargs* are forwarded to
        :py:meth:`_await`), creates its lock file for the duration of the context and removes it
        on exit. When anything is raised inside the context, the cached file itself is deleted
        before the exception is re-raised.
        """
        cpath = str(cpath)
        lock_path = self._lock_path(cpath)

        self._await(cpath, **kwargs)

        try:
            with io_lock:
                with open(lock_path, "w") as f:
                    f.write("")
                self._locked_cpaths.add(cpath)
                try:
                    os.utime(lock_path, None)
                except OSError:
                    pass
            yield
        except BaseException:
            # when something went really wrong, conservatively delete the cached file
            self._remove(cpath, lock=False)
            raise
        finally:
            # unlock again
            self._unlock(cpath)
            self._locked_cpaths.discard(cpath)

    def lock(self, rpath: str | pathlib.Path) -> AbstractContextManager:
        """
        Returns a context manager that locks the cached file of the remote path *rpath*.
        """
        return self._lock(self.cache_path(rpath))

    def allocate(self, size: int | float) -> bool:
        """
        Makes sure that *size* bytes fit into the cache and returns *True* on success.

        The cache may grow up to :py:attr:`max_size` (in MB), but never beyond its current size
        plus 90% of the free disk space. When the limit would be exceeded, unlocked cached files
        are deleted in order of their access time, oldest first, until enough space is freed.
        *False* is returned when this is not sufficient.
        """
        def _human_bytes(size: int | float) -> tuple[float, str]:
            return human_bytes(size)  # type: ignore[return-value]

        logger.debug("allocating {0[0]:.2f} {0[1]} in cache '{1}'".format(_human_bytes(size), self))

        # determine stats and current cache size
        file_stats = []
        for elem in os.listdir(self.base):
            if elem.endswith(self.lock_postfix):
                continue
            cpath = os.path.join(self.base, elem)
            try:
                file_stats.append((cpath, os.stat(cpath)))
            except OSError:
                # the file vanished after listing, e.g. removed by a concurrent process
                continue
        current_size = sum(stat.st_size for _, stat in file_stats)

        # get the available space of the disk that contains the cache in bytes, leave 10%
        fs_stat = os.statvfs(self.base)
        free_size = fs_stat.f_frsize * fs_stat.f_bavail * 0.9

        # determine the maximum size of the cache
        # make sure it is always smaller than what is available
        if self.max_size <= 0:
            max_size = current_size + free_size
        else:
            max_size = min(self.max_size * 1024**2, current_size + free_size)

        # determine the size of files that need to be deleted
        delete_size = current_size + size - max_size
        if delete_size <= 0:
            logger.debug(
                "cache space sufficient, {0[0]:.2f} {0[1]} remaining".format(
                    _human_bytes(-delete_size),
                ),
            )
            return True

        logger.info(
            "need to delete {0[0]:.2f} {0[1]} from cache".format(
                _human_bytes(delete_size),
            ),
        )

        # delete files, ordered by their access time, skip locked ones
        for cpath, cstat in sorted(file_stats, key=lambda tpl: tpl[1].st_atime):
            if self._is_locked(cpath):
                continue
            self._remove(cpath)
            delete_size -= cstat.st_size
            if delete_size <= 0:
                return True

        logger.warning(
            "could not allocate remaining {0[0]:.2f} {0[1]} in cache".format(
                _human_bytes(delete_size),
            ),
        )

        return False

    def _touch(
        self,
        cpath: str | pathlib.Path,
        times: tuple[int | float, int | float] | None = None,
    ) -> None:
        """
        Updates the access and modification *times* of an existing file at *cpath* (the current
        time when *None*). When the user owns the file, :py:attr:`file_perm` is applied first.
        Nothing happens when the file does not exist.
        """
        cpath = str(cpath)
        if os.path.exists(cpath):
            if user_owns_file(cpath):
                os.chmod(cpath, self.file_perm)
            os.utime(cpath, times)

    def touch(
        self,
        rpath: str | pathlib.Path,
        times: tuple[int | float, int | float] | None = None,
    ) -> None:
        """
        Same as :py:meth:`_touch`, but for the cached file of the remote path *rpath*.
        """
        return self._touch(self.cache_path(rpath), times=times)

    def _mtime(self, cpath: str | pathlib.Path) -> float:
        """
        Returns the modification time of the file at *cpath*.
        """
        return os.stat(str(cpath)).st_mtime

    def mtime(self, rpath: str | pathlib.Path) -> float:
        """
        Returns the modification time of the cached file of the remote path *rpath*.
        """
        return self._mtime(self.cache_path(rpath))

    def check_mtime(self, rpath: str | pathlib.Path, rmtime: int | float) -> bool:
        """
        Returns *True* when the modification time of the cached file of *rpath* differs from
        *rmtime* by at most :py:attr:`mtime_patience` seconds. A negative patience disables the
        check, in which case *True* is returned without accessing the file.
        """
        if self.mtime_patience < 0:
            return True

        return abs(self.mtime(rpath) - rmtime) <= self.mtime_patience

    def _remove(self, cpath: str | pathlib.Path, lock: bool = True) -> None:
        """
        Deletes the file at *cpath*, ignoring errors such as a missing file. When *lock* is
        *True*, the file is locked while it is deleted.
        """
        def remove() -> None:
            try:
                os.remove(str(cpath))
            except OSError:
                pass

        if lock:
            with self._lock(cpath):
                remove()
        else:
            remove()

    def remove(self, rpath: str | pathlib.Path, lock: bool = True) -> None:
        """
        Same as :py:meth:`_remove`, but for the cached file of the remote path *rpath*.
        """
        return self._remove(self.cache_path(rpath), lock=lock)
# run the cleanup of all RemoteCache instances created so far when the interpreter exits
atexit.register(RemoteCache.cleanup_all)