Source code for law.contrib.numpy.formatter
# coding: utf-8
"""
NumPy target formatters.
"""
from __future__ import annotations
__all__ = ["NumpyFormatter"]
import pathlib
from law.target.formatter import Formatter
from law.target.file import FileSystemFileTarget, get_path
from law.logger import get_logger
from law.util import no_value
from law._types import Any, Callable
# module-level logger; NumpyFormatter.dump uses it to warn when a non-bool 'savez_compressed' argument is passed
logger = get_logger(__name__)
[docs]
class NumpyFormatter(Formatter):
    """
    Formatter for NumPy files, choosing the NumPy function by file extension.

    - ``.npy``: :func:`numpy.load` / :func:`numpy.save`
    - ``.npz``: :func:`numpy.load` / :func:`numpy.savez` (or :func:`numpy.savez_compressed`)
    - ``.txt``: :func:`numpy.loadtxt` / :func:`numpy.savetxt`
    """

    name = "numpy"

    @classmethod
    def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
        """
        Return whether *path* ends with one of the extensions handled by this formatter
        (``.npy``, ``.npz`` or ``.txt``). *mode* does not influence the decision.
        """
        # get_path may return a non-str object such as a pathlib.Path, which has no
        # endswith method, so convert to a string first (as load and dump do below)
        return str(get_path(path)).endswith((".npy", ".npz", ".txt"))

    @classmethod
    def load(cls, path: str | pathlib.Path | FileSystemFileTarget, *args, **kwargs) -> Any:
        """
        Load and return the content of *path*.

        ``.txt`` files are read with :func:`numpy.loadtxt`, all other files with
        :func:`numpy.load`. Additional *args* and *kwargs* are forwarded to that function.
        """
        import numpy as np  # type: ignore[import-untyped, import-not-found]

        _path = get_path(path)
        func = np.loadtxt if str(_path).endswith(".txt") else np.load
        return func(_path, *args, **kwargs)  # type: ignore[operator]

    @classmethod
    def dump(cls, path: str | pathlib.Path | FileSystemFileTarget, *args, **kwargs) -> Any:
        """
        Save data to *path* and return whatever the selected NumPy function returns.

        ``.txt`` files are written with :func:`numpy.savetxt`, ``.npz`` files with
        :func:`numpy.savez` and all other files with :func:`numpy.save`. Additional *args*
        and *kwargs* are forwarded to that function, except for:

        - ``perm``: when given, it is removed from *kwargs* and passed to :meth:`chmod`
          after writing.
        - ``savez_compressed`` (``.npz`` only): when it is a bool, it is removed from
          *kwargs* and selects :func:`numpy.savez_compressed` instead of
          :func:`numpy.savez`. Any other value is left in *kwargs* (and thus forwarded)
          after a warning is logged.
        """
        import numpy as np

        _path = get_path(path)
        perm = kwargs.pop("perm", no_value)

        func: Callable
        if str(_path).endswith(".txt"):
            func = np.savetxt
        elif str(_path).endswith(".npz"):
            compress_flag = "savez_compressed"
            compress = False
            if compress_flag in kwargs:
                if isinstance(kwargs[compress_flag], bool):
                    compress = kwargs.pop(compress_flag)
                else:
                    # a non-bool value is not consumed here, so it is still forwarded to savez
                    logger.warning(f"the '{compress_flag}' argument is reserved to set compression")
            func = np.savez_compressed if compress else np.savez
        else:
            func = np.save

        ret = func(_path, *args, **kwargs)

        # chmod receives the original argument (possibly a target object), not the
        # resolved path
        if perm is not no_value:
            cls.chmod(path, perm)

        return ret