GenSynth Documentation

Using Keras Models in TensorFlow

This section explains how to prepare Keras models for use in GenSynth. If you skip these steps, you may run into problems that are difficult to diagnose, such as unsatisfactory performance.

Note

‘Keras’ refers to any version of tf.keras with the TensorFlow backend.

Ensuring Graph Collections are Populated

In Keras, metric variables and update operations are stored as attributes of the keras.Model and are not added to tf.Graph collections such as tf.GraphKeys.UPDATE_OPS or tf.GraphKeys.METRIC_VARIABLES.

This causes problems when you run Keras models in TensorFlow without Keras’ compile or fit methods, which handle these objects for you. For a list of built-in graph collections, please refer to TensorFlow’s documentation.

These collections are needed because TensorFlow 1.x, and applications built on TensorFlow 1.x such as GenSynth, rely on them during operation. If they are not populated, some functionality will not work correctly, such as accumulating metric results and updating batch normalization statistics during training.

You can check if the model you are creating with Keras has the necessary collections populated by inspecting the tf.Graph object. For example, to check the tf.GraphKeys.UPDATE_OPS collection:

In TensorFlow:

tf_graph = tf.get_default_graph()  # The existing tf.Graph object.
print(tf_graph.get_collection(tf.GraphKeys.UPDATE_OPS))

In Keras:

print(tf.keras.backend.get_session().graph.get_collection(tf.GraphKeys.UPDATE_OPS))

If your model uses FusedBatchNormV3 operations, or if you have compiled it with an accuracy metric, this collection should be populated. If it is empty, populate it manually as described in the next section.
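
To check several collections at once, a small diagnostic helper like the one below can be used. It is a sketch, not part of GenSynth or TensorFlow: the name report_empty_collections is hypothetical, and it relies only on the get_collection(name) method that tf.Graph provides. The string keys are the values of tf.GraphKeys.UPDATE_OPS, tf.GraphKeys.METRIC_VARIABLES, and tf.GraphKeys.LOCAL_VARIABLES in TensorFlow 1.x; you can pass the tf.GraphKeys constants instead.

```python
from typing import Iterable, List, Protocol


class HasCollections(Protocol):
    """Anything exposing get_collection(name), such as a TensorFlow 1.x tf.Graph."""

    def get_collection(self, name: str) -> list: ...


# String values of the tf.GraphKeys constants discussed in this section.
DEFAULT_KEYS = ("update_ops", "metric_variables", "local_variables")


def report_empty_collections(graph: HasCollections,
                             keys: Iterable[str] = DEFAULT_KEYS) -> List[str]:
    """Print the size of each collection and return the keys that are empty.

    Hypothetical helper: an empty return value means every requested
    collection contains at least one item.
    """
    empty = []
    for key in keys:
        items = graph.get_collection(key)
        print(f"{key}: {len(items)} item(s)")
        if not items:
            empty.append(key)
    return empty
```

For example, report_empty_collections(tf.get_default_graph()) prints one line per collection. An empty collection only matters if the model actually needs it, for instance because it contains batch normalization layers or metrics.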

Manually Create Graph Collections

The TensorFlow graph collections are straightforward to manipulate. The following code snippet populates most of the required collections using <tf.Graph>.add_to_collection(tf.GraphKeys.<KEY>, value).

You must populate these collections before saving your MetaGraph; otherwise, the MetaGraph will be missing the required collections when it is loaded in GenSynth. This can result in batch normalization statistics not accumulating correctly, or in metrics becoming corrupted or being removed from the graph.

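import tensorflow as tf  # TensorFlow 1.x API (tf.GraphKeys, graph collections)


def init_variable_collections_from_keras(graph: tf.Graph, keras_model: tf.keras.Model):
    """Populate missing tf.Graph collections from the attributes of a Keras model."""
    for metric in keras_model.metrics:
        # Ops that update the metric's internal state (e.g. running totals).
        for update_op in metric.updates:
            graph.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
        # Metric state is held in non-trainable weights; register each one as both
        # a metric variable and a local variable. Note: _non_trainable_weights is a
        # private Keras attribute and may change between versions.
        for weight in metric._non_trainable_weights:
            graph.add_to_collection(tf.GraphKeys.METRIC_VARIABLES, weight)
            graph.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, weight)
    # Model-level update ops, such as batch normalization moving-average updates.
    for update_op in keras_model.updates:
        graph.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)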
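
tf.Graph.add_to_collection() appends to a collection without checking for duplicates, so calling init_variable_collections_from_keras() more than once (for example, when re-running a notebook cell) registers the same ops and variables twice. If that is a concern, a guarded variant such as the hypothetical add_to_collection_once() below can be used instead. It is a sketch that assumes only the get_collection() and add_to_collection() methods of tf.Graph.

```python
from typing import Any, Protocol


class MutableCollections(Protocol):
    """Anything with tf.Graph-style collection methods."""

    def get_collection(self, name: str) -> list: ...

    def add_to_collection(self, name: str, value: Any) -> None: ...


def add_to_collection_once(graph: MutableCollections, key: str, value: Any) -> bool:
    """Add value to the named collection unless it is already there (hypothetical helper).

    Returns True if value was added, False if it was already present.
    Membership is checked by identity (is), not equality (==).
    """
    if any(item is value for item in graph.get_collection(key)):
        return False
    graph.add_to_collection(key, value)
    return True
```

To use it, replace each graph.add_to_collection(key, value) call in init_variable_collections_from_keras() with add_to_collection_once(graph, key, value). The identity check is deliberate: with TensorFlow 2.x behavior enabled, == on tensors and variables is element-wise, so a plain "value in collection" test is not a reliable membership check.
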

Save Your Model as a TensorFlow Checkpoint

Currently, if you save a Keras model in the HDF5 format using keras.models.save_model(), you can use it with GenSynth, provided it has no custom objects. Unfortunately, many Keras models use layers, loss functions, or metrics that are not built into Keras. In these cases, the model cannot be loaded without code references to those custom objects.

Since Keras uses TensorFlow as its backend, the best solution for using GenSynth with Keras models that contain custom objects is to save them as TensorFlow checkpoints.

Note

Make sure to update your graph collections prior to saving the model.

This code snippet shows how to convert a Keras model object into a compatible TensorFlow checkpoint:

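import tensorflow as tf  # TensorFlow 1.x API

# NOTE: define your keras_model object here and compile() it...
keras_model = ...
keras_model.compile(...)

# Keras builds the model in the graph of its global session.
sess = tf.keras.backend.get_session()
# Populate the graph collections (function defined in the previous snippet).
init_variable_collections_from_keras(sess.graph, keras_model)
with sess.graph.as_default():
    # Create the Saver inside this block so it collects the variables of the
    # Keras session's graph rather than whichever graph is currently default.
    saver = tf.train.Saver()
    # Writes the checkpoint files, including the my_keras_model.meta MetaGraph.
    saver.save(sess, 'my_keras_model')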
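
Before loading the checkpoint into GenSynth, it can help to confirm that the save produced a complete set of files. With default arguments, tf.train.Saver.save() in TensorFlow 1.x writes a .index file, one or more .data-* shard files, and a .meta file (the MetaGraph) alongside the given prefix. The helper below, missing_checkpoint_files, is a hypothetical sketch that only inspects the file system; it does not validate file contents.

```python
import glob
import os
from typing import List


def missing_checkpoint_files(prefix: str) -> List[str]:
    """Return the expected checkpoint files that are missing for a prefix.

    Hypothetical helper. Assumes the TensorFlow 1.x V2 checkpoint layout
    written by tf.train.Saver.save(sess, prefix) with default arguments:
    <prefix>.index, <prefix>.data-*, and <prefix>.meta.
    An empty list means all expected files were found.
    """
    missing = []
    for suffix in (".index", ".meta"):
        if not os.path.isfile(prefix + suffix):
            missing.append(prefix + suffix)
    # Data shards are named like <prefix>.data-00000-of-00001.
    if not glob.glob(glob.escape(prefix) + ".data-*"):
        missing.append(prefix + ".data-*")
    return missing
```

For the snippet above, missing_checkpoint_files('my_keras_model') should return an empty list if the save succeeded.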

Note

When you define a Keras model with the TensorFlow backend, Keras adds all of the TensorFlow graph objects to the graph of an internal, global tf.Session object. You can access this session via tf.keras.backend.get_session().