# Source code for law.contrib.tensorflow.formatter
# coding: utf-8
"""
TensorFlow target formatters.
"""
from __future__ import annotations
__all__ = [
"TFGraphFormatter", "TFSavedModelFormatter", "TFKerasModelFormatter", "TFKerasWeightsFormatter",
]
import os
import pathlib
from law.target.formatter import Formatter
from law.target.file import FileSystemFileTarget, get_path
from law.util import no_value
from law._types import ModuleType, Any, Sequence
[docs]
class TFGraphFormatter(Formatter):
    """
    Formatter for TensorFlow graphs stored in protobuf files. Paths ending in ``".pb"``,
    ``".pbtxt"`` or ``".pb.txt"`` are accepted, where the latter two are treated as human-readable
    text files by default.
    """

    name = "tf_graph"

    @classmethod
    def import_tf(cls) -> tuple[ModuleType, ModuleType | None, tuple[str, str, str]]:
        """
        Imports TensorFlow and returns three values: the ``tensorflow`` module itself, a reference
        to the v1 API, and the version string split at its first two dots. The v1 reference is the
        module itself when the major version is ``"1"``, ``tf.compat.v1`` when that attribute
        exists, and *None* otherwise.
        """
        import tensorflow as tf  # type: ignore[import-untyped, import-not-found]

        # keep a reference to the v1 API as long as v2 provides compatibility
        tf1 = None
        tf_version = tf.__version__.split(".", 2)
        if tf_version[0] == "1":
            tf1 = tf
        elif getattr(tf, "compat", None) is not None and getattr(tf.compat, "v1", None) is not None:
            tf1 = tf.compat.v1

        return tf, tf1, tf_version

    @classmethod
    def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
        """
        Returns *True* when *path* ends with one of the protobuf extensions ``".pb"``,
        ``".pbtxt"`` or ``".pb.txt"``. *mode* is not considered.
        """
        return get_path(path).endswith((".pb", ".pbtxt", ".pb.txt"))

    @classmethod
    def load(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        create_session: bool | None = None,
        as_text: bool | None = None,
    ) -> Any | tuple[Any, Any]:
        """
        Reads a saved TensorFlow graph from *path* and returns it. When *create_session* is *True*,
        a session object (compatible with the v1 API) is created and returned as the second value of
        a 2-tuple. The default value of *create_session* is *True* when TensorFlow v1 is detected,
        and *False* otherwise. When *as_text* is either *True*, or *None* and the file extension is
        ``".pbtxt"`` or ``".pb.txt"``, the content of the file at *path* is expected to be a
        human-readable text file. Otherwise, it is read as a binary protobuf file. Example:

        .. code-block:: python

            graph = TFGraphFormatter.load("path/to/model.pb", create_session=False)

            graph, session = TFGraphFormatter.load("path/to/model.pb", create_session=True)

        A *NotImplementedError* is raised when *create_session* is *True* but no v1 API is
        available.
        """
        tf, tf1, tf_version = cls.import_tf()
        path = get_path(path)

        # default create_session value
        if create_session is None:
            create_session = tf_version[0] == "1"

        # sessions only exist in the v1 API, so it must be present when one is requested
        if create_session and not tf1:
            raise NotImplementedError(
                "the v1 compatibility layer of TensorFlow v2 is missing, but required when "
                "create_session is True",
            )

        # default as_text value
        if as_text is None:
            as_text = str(path).endswith((".pbtxt", ".pb.txt"))

        graph = tf.Graph()
        with graph.as_default():
            graph_def = graph.as_graph_def()

            if as_text:
                # use a simple pb reader to load the file into graph_def
                from google.protobuf import text_format  # type: ignore[import-untyped, import-not-found] # noqa
                with open(path, "rb") as f:
                    text_format.Merge(f.read(), graph_def)
            else:
                # use the gfile api depending on the TF version
                if tf_version[0] == "1":
                    from tensorflow.python.platform import gfile  # type: ignore[import-untyped, import-not-found] # noqa
                    with gfile.FastGFile(path, "rb") as f:
                        graph_def.ParseFromString(f.read())
                else:
                    with tf.io.gfile.GFile(path, "rb") as f:
                        graph_def.ParseFromString(f.read())

            # import the graph_def (pb object) into the actual graph
            tf.import_graph_def(graph_def, name="")

        if create_session:
            session = tf1.Session(graph=graph)  # type: ignore[union-attr]
            return graph, session

        return graph

    @classmethod
    def dump(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        obj: Any,
        variables_to_constants: bool = False,
        output_names: Sequence[str] | None = None,
        *args,
        **kwargs,
    ) -> Any:
        """
        Extracts a TensorFlow graph from an object *obj* and saves it at *path*. The graph is
        optionally transformed into a simpler representation with all its variables converted to
        constants when *variables_to_constants* is *True*. The saved file contains the graph as a
        protobuf. The accepted types of *obj* greatly depend on the available API versions.

        When the v1 API is found (which is also the case when ``tf.compat.v1`` is available in v2),
        ``Graph``, ``GraphDef`` and ``Session`` objects are accepted. However, when
        *variables_to_constants* is *True*, *obj* must be a session and *output_names* should refer
        to names of operations whose subgraphs are extracted (usually just one).

        For TensorFlow v2, *obj* can also be a compiled keras model, or either a polymorphic or
        concrete function as returned by ``tf.function``. Polymorphic functions either must have a
        defined input signature (``tf.function(input_signature=(...,))``) or they must accept no
        arguments in the first place. See the TensorFlow documentation on `concrete functions
        <https://www.tensorflow.org/guide/concrete_function>`__ for more info.

        *args* and *kwargs* are forwarded to ``tf.train.write_graph`` (v1) or ``tf.io.write_graph``
        (v2), whose return value is returned. Unless set explicitly, the *as_text* keyword argument
        defaults to *True* when *path* ends with ``".pbtxt"`` or ``".pb.txt"``. A *perm* keyword
        argument is removed from *kwargs* and, when given, passed to :py:meth:`chmod` together with
        *path* after writing.

        A *ValueError* is raised when a required input signature or *output_names* is missing, and
        a *TypeError* when *obj* cannot be converted to constants.
        """
        tf, tf1, tf_version = cls.import_tf()
        _path = get_path(path)
        perm = kwargs.pop("perm", no_value)

        graph_dir, graph_name = os.path.split(_path)

        # default as_text value
        kwargs.setdefault("as_text", str(_path).endswith((".pbtxt", ".pb.txt")))

        # convert keras models and polymorphic functions to concrete functions, v2 only
        if tf_version[0] != "1":
            from tensorflow.python.keras.saving import saving_utils  # type: ignore[import-untyped, import-not-found] # noqa
            from tensorflow.python.eager.def_function import Function  # type: ignore[import-untyped, import-not-found] # noqa
            from tensorflow.python.eager.function import ConcreteFunction  # type: ignore[import-untyped, import-not-found] # noqa

            if isinstance(obj, tf.keras.Model):
                learning_phase_orig = tf.keras.backend.learning_phase()
                tf.keras.backend.set_learning_phase(False)
                try:
                    model_func = saving_utils.trace_model_call(obj)
                    if model_func.function_spec.arg_names and not model_func.input_signature:
                        raise ValueError(
                            "when obj is a keras model callable accepting arguments, its input "
                            "signature must be frozen by building the model",
                        )
                    obj = model_func.get_concrete_function()
                finally:
                    # the learning phase is global state, so restore it even when tracing fails
                    tf.keras.backend.set_learning_phase(learning_phase_orig)

            elif isinstance(obj, Function):
                if obj.function_spec.arg_names and not obj.input_signature:
                    raise ValueError(
                        "when obj is a polymorphic function accepting arguments, its input "
                        "signature must be frozen",
                    )
                obj = obj.get_concrete_function()

        # convert variables to constants
        if variables_to_constants:
            if tf1 and isinstance(obj, tf1.Session):
                if not output_names:
                    raise ValueError(
                        "when variables_to_constants is true, output_names must contain operations "
                        f"to export, got '{output_names}' instead",
                    )
                obj = tf1.graph_util.convert_variables_to_constants(
                    obj,
                    obj.graph.as_graph_def(),
                    output_names,
                )

            elif tf_version[0] != "1":
                from tensorflow.python.framework import convert_to_constants  # type: ignore[import-untyped, import-not-found] # noqa

                # keras models and polymorphic functions were already made concrete above
                if not isinstance(obj, ConcreteFunction):
                    raise TypeError(
                        "when variables_to_constants is true, obj must be a concrete or "
                        f"polymorphic function, got '{obj}' instead",
                    )
                obj = convert_to_constants.convert_variables_to_constants_v2(obj)

            else:
                raise TypeError(
                    f"cannot convert variables to constants for object '{obj}', type not "
                    f"understood for TensorFlow version {tf.__version__}",
                )

        # extract the graph
        if tf1 and isinstance(obj, tf1.Session):
            graph = obj.graph
        elif tf_version[0] != "1" and isinstance(obj, ConcreteFunction):
            graph = obj.graph
        else:
            # anything else is handed to write_graph unchanged
            graph = obj

        # write it
        if tf_version[0] == "1":
            ret = tf1.train.write_graph(graph, graph_dir, graph_name, *args, **kwargs)  # type: ignore[union-attr] # noqa
        else:
            ret = tf.io.write_graph(graph, graph_dir, graph_name, *args, **kwargs)

        if perm != no_value:
            cls.chmod(path, perm)

        return ret
class TFSavedModelFormatter(Formatter):
    """
    Formatter that loads and saves objects through ``tf.saved_model``. Only paths without a file
    extension are accepted.
    """

    name = "tf_saved_model"

    @classmethod
    def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
        """
        Returns *True* when *path* has no file extension. *mode* is not considered.
        """
        # accept paths where basenames refer to directories, likely without any file extension
        extension = os.path.splitext(get_path(path))[1]
        return extension == ""

    @classmethod
    def load(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        *args,
        **kwargs,
    ) -> Any:
        """
        Loads the object stored at *path* via ``tf.saved_model.load`` and returns it. *args* and
        *kwargs* are forwarded to that function.
        """
        import tensorflow as tf

        loader = tf.saved_model.load
        return loader(get_path(path), *args, **kwargs)

    @classmethod
    def dump(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        model: Any,
        *args,
        **kwargs,
    ) -> Any:
        """
        Saves *model* at *path* via ``tf.saved_model.save`` and returns the result of that call.
        *args* and *kwargs* are forwarded, except for a *perm* keyword argument which, when given,
        is passed to :py:meth:`chmod` together with *path* after saving.
        """
        import tensorflow as tf

        # perm is consumed here and must not reach tf.saved_model.save
        permission = kwargs.pop("perm", no_value)

        result = tf.saved_model.save(model, get_path(path), *args, **kwargs)

        if perm_given := (permission != no_value):
            cls.chmod(path, permission)

        return result
[docs]
class TFKerasModelFormatter(Formatter):
    """
    Formatter for keras models. The way a model is read or written is selected by the file
    extension of the path: ``".json"``, ``".yml"`` / ``".yaml"``, or anything else, which is handed
    to ``tf.keras.models.load_model`` and ``model.save`` respectively.
    """

    name = "tf_keras_model"

    @classmethod
    def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
        """
        Returns *True* when the file extension of *path* is one of ``".hdf5"``, ``".h5"``,
        ``".json"``, ``".yaml"``, ``".yml"``, or empty. *mode* is not considered.
        """
        extension = os.path.splitext(get_path(path))[1]
        return extension in (".hdf5", ".h5", ".json", ".yaml", ".yml", "")

    @classmethod
    def load(cls, path: str | pathlib.Path | FileSystemFileTarget, *args, **kwargs) -> Any:
        """
        Loads a keras model from *path* and returns it. *args* and *kwargs* are forwarded to
        ``model_from_json``, ``model_from_yaml`` or ``load_model`` of ``tf.keras.models``, depending
        on the file extension.
        """
        import tensorflow as tf

        path = get_path(path)
        keras_models = tf.keras.models

        # the method for loading the model depends on the file extension
        is_json = str(path).endswith(".json")
        if is_json:
            with open(path, "r") as json_file:
                return keras_models.model_from_json(json_file.read(), *args, **kwargs)

        is_yaml = str(path).endswith((".yml", ".yaml"))
        if is_yaml:
            with open(path, "r") as yaml_file:
                return keras_models.model_from_yaml(yaml_file.read(), *args, **kwargs)

        # .hdf5, .h5, bundle
        return keras_models.load_model(path, *args, **kwargs)

    @classmethod
    def dump(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        model: Any,
        *args,
        **kwargs,
    ) -> Any:
        """
        Saves the keras *model* at *path*. For ``".json"`` and ``".yml"`` / ``".yaml"`` paths, the
        output of ``model.to_json()`` or ``model.to_yaml()`` is written and *None* is returned. For
        all other paths, *args* and *kwargs* are forwarded to ``model.save`` and its return value
        is returned. A *perm* keyword argument is removed from *kwargs* and, when given, passed to
        :py:meth:`chmod` together with *path* after saving.
        """
        _path = get_path(path)
        perm = kwargs.pop("perm", no_value)

        # the method for saving the model depends on the file extension
        result = None
        if str(_path).endswith(".json"):
            with open(_path, "w") as out:
                out.write(model.to_json())
        elif str(_path).endswith((".yml", ".yaml")):
            with open(_path, "w") as out:
                out.write(model.to_yaml())
        else:
            # .hdf5, .h5, bundle
            result = model.save(_path, *args, **kwargs)

        if perm != no_value:
            cls.chmod(path, perm)

        return result
[docs]
class TFKerasWeightsFormatter(Formatter):
    """
    Formatter for the weights of keras models, stored in files ending in ``".hdf5"`` or ``".h5"``.
    """

    name = "tf_keras_weights"

    @classmethod
    def accepts(cls, path: str | pathlib.Path | FileSystemFileTarget, mode: str) -> bool:
        """
        Returns *True* when *path* ends with ``".hdf5"`` or ``".h5"``. *mode* is not considered.
        """
        weights_path = get_path(path)
        return weights_path.endswith((".hdf5", ".h5"))

    @classmethod
    def load(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        model: Any,
        *args,
        **kwargs,
    ) -> Any:
        """
        Calls ``model.load_weights`` with *path* and returns its result. *args* and *kwargs* are
        forwarded to that call.
        """
        weights_path = get_path(path)
        return model.load_weights(weights_path, *args, **kwargs)

    @classmethod
    def dump(
        cls,
        path: str | pathlib.Path | FileSystemFileTarget,
        model: Any,
        *args,
        **kwargs,
    ) -> Any:
        """
        Calls ``model.save_weights`` with *path* and returns its result. *args* and *kwargs* are
        forwarded, except for a *perm* keyword argument which, when given, is passed to
        :py:meth:`chmod` together with *path* after saving.
        """
        # perm is consumed here and must not reach model.save_weights
        permission = kwargs.pop("perm", no_value)

        result = model.save_weights(get_path(path), *args, **kwargs)

        if permission != no_value:
            cls.chmod(path, permission)

        return result