tensorflow

TensorFlow contrib functionality.

Class TFGraphFormatter

class TFGraphFormatter

Bases: Formatter

classmethod load(path, create_session=None, as_text=None)

Reads a saved TensorFlow graph from path and returns it. When create_session is True, a session object (compatible with the v1 API) is also created, and a 2-tuple (graph, session) is returned. When create_session is None, it defaults to True if TensorFlow v1 is detected and to False otherwise. The file at path is parsed as a human-readable text protobuf when as_text is True, or when as_text is None and the file extension is ".pbtxt" or ".pb.txt". Otherwise, it is read as a binary protobuf file. Example:

graph = TFGraphFormatter.load("path/to/model.pb", create_session=False)

graph, session = TFGraphFormatter.load("path/to/model.pb", create_session=True)
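The rules above for the return value and the file format can be restated in plain Python. This is a minimal sketch: resolve_as_text and default_create_session are hypothetical helpers that only mirror the documented behavior. They are not part of this module's API, and the real load() may detect TensorFlow v1 differently (here it is assumed to be a simple boolean flag).

```python
def resolve_as_text(path, as_text=None):
    """Decide whether the file at path is read as a text protobuf.

    Hypothetical helper restating the documented rule of load().
    """
    if as_text is not None:
        # An explicit True or False always wins over the file extension.
        return bool(as_text)
    # With as_text=None, only these suffixes mark a human-readable file.
    return path.endswith((".pbtxt", ".pb.txt"))


def default_create_session(is_tf_v1):
    """Default for create_session=None: True only when TensorFlow v1 is detected.

    Hypothetical helper; is_tf_v1 stands in for the real version detection.
    """
    return bool(is_tf_v1)


if __name__ == "__main__":
    print(resolve_as_text("path/to/model.pb"))                  # binary
    print(resolve_as_text("path/to/model.pb.txt"))              # text
    print(resolve_as_text("path/to/model.pbtxt", as_text=False))  # forced binary
    print(default_create_session(is_tf_v1=False))
```

Note that an explicit as_text=False forces binary parsing even for a ".pbtxt" file, which matches the "Otherwise" branch of the description.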

classmethod dump(path, obj, variables_to_constants=False, output_names=None, *args, **kwargs)

Extracts a TensorFlow graph from obj and saves it at path as a protobuf. When variables_to_constants is True, the graph is first transformed into a simpler representation in which all variables are converted to constants. Which types of obj are accepted depends on the available TensorFlow API versions.

When the v1 API is available (which is also the case in TensorFlow v2 through tf.compat.v1), Graph, GraphDef and Session objects are accepted. However, when variables_to_constants is True, obj must be a Session, and output_names should name the operations whose subgraphs are extracted (usually just one).

For TensorFlow v2, obj can also be a compiled Keras model, a polymorphic function as returned by tf.function, or a concrete function derived from one (e.g. via get_concrete_function()). A polymorphic function must either have a defined input signature (tf.function(input_signature=(...,))) or accept no arguments at all. See the TensorFlow documentation on concrete functions for more info.
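As a rough summary of the two paragraphs above, the following sketch lists which kinds of obj dump() documents as accepted when variables_to_constants is False. accepted_obj_kinds is a hypothetical helper that uses strings in place of real TensorFlow types; it restates the documentation and does not reproduce the formatter's actual type checks.

```python
def accepted_obj_kinds(v1_api_available, tf2_api_available):
    """Return the obj kinds dump() documents as accepted.

    Hypothetical model for variables_to_constants=False only; with
    variables_to_constants=True under the v1 API, obj must be a Session.
    """
    kinds = set()
    if v1_api_available:
        # Also the case under TensorFlow v2 when tf.compat.v1 is available.
        kinds.update({"Graph", "GraphDef", "Session"})
    if tf2_api_available:
        # Polymorphic functions need an input_signature or no arguments.
        kinds.update({"keras.Model", "tf.function", "ConcreteFunction"})
    return kinds


if __name__ == "__main__":
    print(sorted(accepted_obj_kinds(v1_api_available=True, tf2_api_available=False)))
    print(sorted(accepted_obj_kinds(v1_api_available=True, tf2_api_available=True)))
```

Under TensorFlow v2, tf.compat.v1 is normally importable, so both flags are usually True and all six kinds apply.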

args and kwargs are forwarded to tf.train.write_graph (v1) or tf.io.write_graph (v2).

Class TFKerasModelFormatter

class TFKerasModelFormatter

Bases: Formatter

Class TFKerasWeightsFormatter

class TFKerasWeightsFormatter

Bases: Formatter