# Source code for law.target.formatter

# coding: utf-8

"""
Formatter classes for file targets.
"""

__all__ = ["AUTO_FORMATTER", "Formatter", "get_formatter", "find_formatters", "find_formatter"]


import os
import json
import zipfile
import gzip
import tarfile
from collections import OrderedDict

import six

from law.util import make_list, import_file
from law.logger import get_logger


# module-level logger, used for registration debug messages and deprecation warnings
logger = get_logger(__name__)


# reserved formatter name that requests selection of a formatter based on the file path;
# formatter classes themselves must not use this name
AUTO_FORMATTER = "auto"


class FormatterRegister(type):
    """
    Metaclass that records every class created with it in :py:attr:`formatters`, keyed by the
    *name* attribute of the class. Classes named ``"_base"`` are created but not registered.
    """

    # registered formatter classes by name, in order of their definition
    formatters = OrderedDict()

    def __new__(metacls, classname, bases, classdict):
        new_cls = type.__new__(metacls, classname, bases, classdict)
        fmt_name = new_cls.name

        # names must be unique and must not clash with the reserved auto-selection name
        if fmt_name in metacls.formatters:
            msg = "duplicate formatter name '{}' for class {}"
            raise ValueError(msg.format(fmt_name, new_cls))
        if fmt_name == AUTO_FORMATTER:
            msg = "formatter class {} must not be named '{}'"
            raise ValueError(msg.format(new_cls, AUTO_FORMATTER))

        # classes using the placeholder base name are not stored
        if fmt_name == "_base":
            return new_cls

        # store classes by name
        metacls.formatters[fmt_name] = new_cls
        logger.debug("registered target formatter '{}'".format(fmt_name))

        return new_cls
def get_formatter(name, silent=False):
    """
    Looks up the registered formatter class whose name attribute equals *name* and returns it.
    If there is no such class, *None* is returned when *silent* is *True*, and an exception is
    raised otherwise.
    """
    formatter = FormatterRegister.formatters.get(name)

    # complain only when nothing was found and the caller did not ask for silence
    if not formatter and not silent:
        raise Exception("cannot find formatter '{}'".format(name))

    return formatter
def find_formatters(path, mode, silent=True):
    """
    Collects all registered formatter classes that accept the file given by *path* for the
    requested *mode*, which should be either ``"load"`` or ``"dump"``, and returns them as a list
    in registration order. If no class accepts the file, an empty list is returned when *silent*
    is *True*, and an exception is raised otherwise.
    """
    path = get_path(path)

    # ask every registered formatter, preserving the registration order
    matches = []
    for formatter in six.itervalues(FormatterRegister.formatters):
        if formatter.accepts(path, mode):
            matches.append(formatter)

    if not matches and not silent:
        raise Exception("cannot find any '{}' formatter for {}".format(mode, path))

    return matches
def find_formatter(path, mode, name=AUTO_FORMATTER):
    """
    Returns a single formatter class. When *name* differs from *AUTO_FORMATTER*, the class
    registered under that name is returned via :py:func:`get_formatter`. Otherwise, the first
    class reported by :py:func:`find_formatters` for *path* and *mode* is returned. In both cases,
    an exception is raised when no formatter is found.
    """
    # explicit name, plain lookup
    if name != AUTO_FORMATTER:
        return get_formatter(name, silent=False)

    # automatic selection, the first accepting formatter wins
    candidates = find_formatters(path, mode, silent=False)
    return candidates[0]
class Formatter(six.with_metaclass(FormatterRegister, object)):
    """
    Base class of all formatters. Subclasses set a unique *name* and implement the class methods
    :py:meth:`accepts`, :py:meth:`load` and :py:meth:`dump`, all of which raise
    *NotImplementedError* here.
    """

    # placeholder name, classes with this name are not added to the formatter register
    name = "_base"

    # modes
    LOAD = "load"
    DUMP = "dump"

    @classmethod
    def accepts(cls, path, mode):
        """
        Should return whether this formatter can handle the file at *path* for *mode*.
        """
        raise NotImplementedError()

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Should load and return the content of the file at *path*.
        """
        raise NotImplementedError()

    @classmethod
    def dump(cls, path, *args, **kwargs):
        """
        Should write content to the file at *path*.
        """
        raise NotImplementedError()
class TextFormatter(Formatter):
    """
    Formatter for plain text files ending with ``.txt``.
    """

    name = "text"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.txt``, independent of *mode*.
        """
        file_path = get_path(path)
        return file_path.endswith(".txt")

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Opens the file at *path* in text mode and returns what its ``read()`` method yields. All
        *args* and *kwargs* are forwarded to ``read()``.
        """
        with open(get_path(path), "r") as handle:
            text = handle.read(*args, **kwargs)
        return text

    @classmethod
    def dump(cls, path, content, *args, **kwargs):
        """
        Converts *content* to a string and writes it to the file at *path*. All *args* and *kwargs*
        are forwarded to ``write()``.
        """
        with open(get_path(path), "w") as handle:
            handle.write(str(content), *args, **kwargs)
class JSONFormatter(Formatter):
    """
    Formatter for JSON files ending with ``.json``, based on the :py:mod:`json` module.
    """

    name = "json"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.json``, independent of *mode*.
        """
        file_path = get_path(path)
        return file_path.endswith(".json")

    @classmethod
    def load(_cls, path, *args, **kwargs):
        """
        Parses the file at *path* with ``json.load`` and returns the result. All *args* and
        *kwargs* are forwarded.
        """
        # the class argument is named _cls since kwargs might contain *cls*
        with open(get_path(path), "r") as handle:
            obj = json.load(handle, *args, **kwargs)
        return obj

    @classmethod
    def dump(_cls, path, obj, *args, **kwargs):
        """
        Serializes *obj* into the file at *path* with ``json.dump`` and returns its return value.
        All *args* and *kwargs* are forwarded.
        """
        # the class argument is named _cls since kwargs might contain *cls*
        with open(get_path(path), "w") as handle:
            result = json.dump(obj, handle, *args, **kwargs)
        return result
class PickleFormatter(Formatter):
    """
    Formatter for pickle files ending with ``.pkl``, ``.pickle`` or ``.p``, based on
    ``six.moves.cPickle``.
    """

    name = "pickle"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* has one of the supported pickle suffixes, independent of *mode*.
        """
        suffixes = (".pkl", ".pickle", ".p")
        return get_path(path).endswith(suffixes)

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Unpickles and returns the object stored in the file at *path*. All *args* and *kwargs* are
        forwarded to ``cPickle.load``.
        """
        # NOTE(review): unpickling can execute arbitrary code, callers should only load files
        # from sources they trust
        with open(get_path(path), "rb") as handle:
            obj = six.moves.cPickle.load(handle, *args, **kwargs)
        return obj

    @classmethod
    def dump(cls, path, obj, *args, **kwargs):
        """
        Pickles *obj* into the file at *path* and returns the return value of ``cPickle.dump``.
        All *args* and *kwargs* are forwarded.
        """
        with open(get_path(path), "wb") as handle:
            result = six.moves.cPickle.dump(obj, handle, *args, **kwargs)
        return result
class YAMLFormatter(Formatter):
    """
    Formatter for YAML files ending with ``.yaml`` or ``.yml``. The ``yaml`` package is imported
    only when files are actually loaded or dumped.
    """

    name = "yaml"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.yaml`` or ``.yml``, independent of *mode*.
        """
        suffixes = (".yaml", ".yml")
        return get_path(path).endswith(suffixes)

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Parses the file at *path* with ``yaml.safe_load`` and returns the result. All *args* and
        *kwargs* are forwarded.
        """
        import yaml

        with open(get_path(path), "r") as handle:
            obj = yaml.safe_load(handle, *args, **kwargs)
        return obj

    @classmethod
    def dump(cls, path, obj, *args, **kwargs):
        """
        Serializes *obj* into the file at *path* with ``yaml.dump`` and returns its return value.
        All *args* and *kwargs* are forwarded.
        """
        import yaml

        with open(get_path(path), "w") as handle:
            result = yaml.dump(obj, handle, *args, **kwargs)
        return result
class ZipFormatter(Formatter):
    """
    Formatter for zip archives ending with ``.zip``, based on the :py:mod:`zipfile` module.
    """

    name = "zip"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.zip``, independent of *mode*.
        """
        return get_path(path).endswith(".zip")

    @classmethod
    def load(cls, path, dst, *args, **kwargs):
        """
        Extracts the archive at *path* into the directory *dst*. The mode for opening the archive
        defaults to ``"r"`` and can be changed through the first positional argument or a *mode*
        keyword argument. *extractall_kwargs* can be a dictionary of options passed to
        ``ZipFile.extractall``. All other *args* and *kwargs* are forwarded to the
        ``zipfile.ZipFile`` constructor.
        """
        # assume read mode, but also check args and kwargs
        mode = "r"
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")

        # arguments passed to extractall()
        extractall_kwargs = kwargs.pop("extractall_kwargs", None) or {}

        # open zip file and extract to dst
        with zipfile.ZipFile(get_path(path), mode, *args, **kwargs) as f:
            f.extractall(get_path(dst), **extractall_kwargs)

    @classmethod
    def dump(cls, path, src, *args, **kwargs):
        """
        Creates the archive at *path* from *src*. When *src* is a file, it is stored under its
        basename. Otherwise, *src* is treated as a directory whose entire content, including files
        in nested sub directories, is stored with paths relative to *src*. The mode for opening
        the archive defaults to ``"w"`` and can be changed through the first positional argument
        or a *mode* keyword argument. *write_kwargs* can be a dictionary of options passed to each
        ``ZipFile.write`` call. All other *args* and *kwargs* are forwarded to the
        ``zipfile.ZipFile`` constructor.
        """
        # assume write mode, but also check args and kwargs
        mode = "w"
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")

        # arguments passed to write()
        write_kwargs = kwargs.pop("write_kwargs", None) or {}

        # open a new zip file and add all files in src
        with zipfile.ZipFile(get_path(path), mode, *args, **kwargs) as f:
            src = get_path(src)
            if os.path.isfile(src):
                f.write(src, os.path.basename(src), **write_kwargs)
                return

            for elem in os.listdir(src):
                elem_path = os.path.join(src, elem)
                f.write(elem_path, elem, **write_kwargs)

                # ZipFile.write() only stores the entry of a directory itself, so its content
                # needs to be added explicitly; symlinked directories are not descended into to
                # avoid following cycles
                if not os.path.isdir(elem_path) or os.path.islink(elem_path):
                    continue
                for dirpath, dirnames, filenames in os.walk(elem_path):
                    for name in dirnames + filenames:
                        nested_path = os.path.join(dirpath, name)
                        arcname = os.path.relpath(nested_path, src)
                        f.write(nested_path, arcname, **write_kwargs)
class GZipFormatter(Formatter):
    """
    Formatter for single gzip-compressed files ending with ``.gz``, based on the :py:mod:`gzip`
    module. Gzipped tar archives (``.tar.gz``) are not accepted automatically as they are handled
    by the tar formatter.
    """

    name = "gzip"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.gz`` but not with ``.tar.gz``, independent of
        *mode*.
        """
        path = get_path(path)
        # formatters are probed in registration order and this class is defined before the tar
        # formatter, so accepting ".tar.gz" here would shadow it during automatic selection and
        # hand archive-style arguments (e.g. a destination directory) to gzip.open()
        return path.endswith(".gz") and not path.endswith(".tar.gz")

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Opens the file at *path* with ``gzip.open`` and returns the result of its ``read()``
        method. The mode defaults to ``"r"`` and can be changed through the first positional
        argument or a *mode* keyword argument. *read_kwargs* can be a dictionary of options passed
        to ``read()``. All other *args* and *kwargs* are forwarded to ``gzip.open``.
        """
        # assume read mode, but also check args and kwargs
        mode = "r"
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")

        # arguments passed to read()
        read_kwargs = kwargs.pop("read_kwargs", None) or {}

        # open with gzip and return content
        with gzip.open(get_path(path), mode, *args, **kwargs) as f:
            return f.read(**read_kwargs)

    @classmethod
    def dump(cls, path, obj, *args, **kwargs):
        """
        Writes *obj* into a gzip file at *path* and returns the return value of ``write()``. The
        mode defaults to ``"w"`` and can be changed through the first positional argument or a
        *mode* keyword argument. *write_kwargs* can be a dictionary of options passed to
        ``write()``. All other *args* and *kwargs* are forwarded to ``gzip.open``.
        """
        # assume write mode, but also check args and kwargs
        mode = "w"
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")

        # arguments passed to write()
        write_kwargs = kwargs.pop("write_kwargs", None) or {}

        # write into a new gzip file
        with gzip.open(get_path(path), mode, *args, **kwargs) as f:
            return f.write(obj, **write_kwargs)
class TarFormatter(Formatter):
    """
    Formatter for compressed tar archives, based on the :py:mod:`tarfile` module. The compression
    is inferred from the file extension, see :py:meth:`infer_compression`.
    """

    name = "tar"

    @classmethod
    def infer_compression(cls, path):
        """
        Returns the compression (``"gz"``, ``"bz2"`` or ``"xz"``) matching the suffix of *path*,
        or *None* when the suffix is not known.
        """
        path = get_path(path)
        if path.endswith((".tar.gz", ".tgz")):
            return "gz"
        if path.endswith((".tar.bz2", ".tbz2", ".bz2")):
            return "bz2"
        if path.endswith((".tar.xz", ".txz", ".lzma")):
            return "xz"
        return None

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when a compression can be inferred from *path*, independent of *mode*.
        """
        return cls.infer_compression(path) is not None

    @classmethod
    def load(cls, path, dst, *args, **kwargs):
        """
        Extracts the archive at *path* into the directory *dst*. The mode for opening the archive
        is taken from the first positional argument or a *mode* keyword argument, and otherwise
        defaults to read mode with the compression inferred from *path*. *extractall_kwargs* can
        be a dictionary of options passed to ``TarFile.extractall``. All other *args* and *kwargs*
        are forwarded to ``tarfile.open``.
        """
        # get the mode from args and kwargs, default to read mode with inferred compression
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")
        else:
            compression = cls.infer_compression(path)
            mode = "r" if not compression else "r:" + compression

        # arguments passed to extractall()
        extractall_kwargs = kwargs.pop("extractall_kwargs", None) or {}

        # open the tar archive and extract to dst
        with tarfile.open(get_path(path), mode, *args, **kwargs) as f:
            f.extractall(get_path(dst), **extractall_kwargs)

    @classmethod
    def dump(cls, path, src, *args, **kwargs):
        """
        Creates the archive at *path* from *src*, which can be a single path or a sequence of
        paths. Members are named relative to the deepest directory shared by all sources. A
        single directory is stored as the archive root, and a single file is stored under its
        basename. The mode for opening the archive is taken from the first positional argument or
        a *mode* keyword argument, and otherwise defaults to write mode with the compression
        inferred from *path*. *add_kwargs* can be a dictionary of options passed to each
        ``TarFile.add`` call. Passing *filter* directly is deprecated in favor of
        ``add_kwargs=dict(filter=callback)``. All other *args* and *kwargs* are forwarded to
        ``tarfile.open``.
        """
        # get the mode from args and kwargs, default to write mode with inferred compression
        if args:
            mode = args[0]
            args = args[1:]
        elif "mode" in kwargs:
            mode = kwargs.pop("mode")
        else:
            compression = cls.infer_compression(path)
            mode = "w" if not compression else "w:" + compression

        # arguments passed to add(), copied so that the dict of the caller is never modified
        add_kwargs = dict(kwargs.pop("add_kwargs", None) or {})

        # backwards compatibility
        _filter = kwargs.pop("filter", None)
        if _filter is not None:
            logger.warning_once(
                "passing filter=callback' to TarFormatter.dump is deprecated and will be removed "
                "in a future release; please use 'add_kwargs=dict(filter=callback)' instead",
            )
            add_kwargs["filter"] = _filter

        # open a new tar archive and add all files in src
        with tarfile.open(get_path(path), mode, *args, **kwargs) as f:
            srcs = [os.path.abspath(get_path(src)) for src in make_list(src)]

            # os.path.commonprefix() compares character-wise and may end in the middle of a path
            # component (e.g. "/a/foo" for "/a/foo1" and "/a/foo2"), which would lead to member
            # names such as "../foo1"; appending a separator to each source and taking the
            # dirname of the result cuts the prefix at a component boundary
            common_base = os.path.dirname(os.path.commonprefix([s + os.sep for s in srcs]))

            # a file cannot be the archive root, so name it relative to its directory instead
            if os.path.isfile(common_base):
                common_base = os.path.dirname(common_base)

            for src in srcs:
                _add_kwargs = {"arcname": os.path.relpath(src, common_base)}
                # explicitly passed options take precedence over the inferred arcname
                _add_kwargs.update(add_kwargs)
                f.add(src, **_add_kwargs)
class PythonFormatter(Formatter):
    """
    Formatter for python files ending with ``.py``. Only loading is implemented, which is
    delegated to ``law.util.import_file``.
    """

    name = "python"

    @classmethod
    def accepts(cls, path, mode):
        """
        Returns *True* when *path* ends with ``.py``, independent of *mode*.
        """
        file_path = get_path(path)
        return file_path.endswith(".py")

    @classmethod
    def load(cls, path, *args, **kwargs):
        """
        Passes the file at *path* to ``import_file`` and returns its result. All *args* and
        *kwargs* are forwarded.
        """
        file_path = get_path(path)
        return import_file(file_path, *args, **kwargs)


# trailing imports
# NOTE(review): presumably placed at the end of the module to avoid a circular import with
# law.target.file, confirm before moving
from law.target.file import get_path