Source code for law.contrib.keras.formatter
# coding: utf-8
"""
Keras target formatters.
"""
from __future__ import annotations
__all__ = ["KerasModelFormatter", "KerasWeightsFormatter"]
import pathlib
from law.target.formatter import Formatter
from law.target.file import FileSystemFileTarget, get_path
from law.logger import get_logger
from law.util import no_value
from law._types import Any
logger = get_logger(__name__)
[docs]
class KerasModelFormatter(Formatter):
name = "keras_model"
@classmethod
def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
return get_path(path).endswith((".hdf5", ".h5", ".json", ".yaml", ".yml"))
@classmethod
def load(cls, path: str | pathlib.Path | FileSystemFileTarget, *args, **kwargs) -> Any:
import keras # type: ignore[import-untyped, import-not-found]
path = get_path(path)
# the method for loading the model depends on the file extension
if path.endswith(".json"):
with open(path, "r") as f:
return keras.models.model_from_json(f.read(), *args, **kwargs)
if path.endswith((".yml", ".yaml")):
with open(path, "r") as f:
return keras.models.model_from_yaml(f.read(), *args, **kwargs)
# .hdf5, .h5, bundle
return keras.models.load_model(path, *args, **kwargs)
@classmethod
def dump(cls, path: str | pathlib.Path | FileSystemFileTarget, model, *args, **kwargs) -> Any:
_path = get_path(path)
perm = kwargs.pop("perm", no_value)
# the method for saving the model depends on the file extension
ret = None
if _path.endswith(".json"):
with open(_path, "w") as f:
f.write(model.to_json(*args, **kwargs))
elif _path.endswith((".yml", ".yaml")):
with open(_path, "w") as f:
f.write(model.to_yaml(*args, **kwargs))
else: # .hdf5, .h5, bundle
ret = model.save(_path, *args, **kwargs)
if perm != no_value:
cls.chmod(path, perm)
return ret
[docs]
class KerasWeightsFormatter(Formatter):
name = "keras_weights"
@classmethod
def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
return get_path(path).endswith((".hdf5", ".h5"))
@classmethod
def load(
cls,
path: str | pathlib.Path | FileSystemFileTarget,
model: Any,
*args,
**kwargs,
) -> Any:
return model.load_weights(get_path(path), *args, **kwargs)
@classmethod
def dump(
cls,
path: str | pathlib.Path | FileSystemFileTarget,
model: Any,
*args,
**kwargs,
) -> Any:
perm = kwargs.pop("perm", no_value)
ret = model.save_weights(get_path(path), *args, **kwargs)
if perm != no_value:
cls.chmod(path, perm)
return ret