GenSynth Documentation

Creating the Python3 Data Plugin

Except in the simplest cases (for example, if you already have compatible TFRecords), you must write a plugin that adapts your data to a format usable by GenSynth: you must provide a TensorFlow-standard Dataset (tf.data.Dataset) for each of the training, validation, and test sets.

If you already have a data pipeline, the plugin will likely be a thin wrapper around your existing pipeline.

We recommend starting from one of the examples provided by DarwinAI, which you can install and use as templates.

Start by creating Datasets that read the data files; then apply transformations using tf.data.Dataset class methods.

When your dataset module supports sharding, it can be used in a multi-worker or multi-GPU environment. Sharding is the systematic division of a training or validation dataset into multiple pieces (shards) such that each worker consistently uses one portion of the data. Supporting sharding takes only a small amount of effort and is well worth doing.

The easiest way to shard data is to construct a Dataset the same way in each worker, and apply the Dataset.shard() method to give each worker its portion of the data.
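A minimal sketch (the file path is hypothetical; num_shards and shard_index are the values passed in for this worker):

import tensorflow as tf

# Every worker builds the same Dataset...
dataset = tf.data.TFRecordDataset(filenames=['/data/train.tfrecords'])
# ...then keeps only its own portion: every num_shards-th record, starting at shard_index.
dataset = dataset.shard(num_shards, shard_index)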

When you are sharding the data, there are a couple of things to keep in mind:

  • Each worker needs to be able to unambiguously determine which data to use.

  • If you are working with multiple files, ensure that you are working with a sorted list of file names.

  • If you are working with rows from a SQL database, ensure there is an unambiguous ORDER BY clause.

You will want to create the shards as early as possible in the data pipeline to avoid extra work within each worker. For example, if you have many files of data, you can shard by file names, so that each worker only needs to open its subset of files. If you do not have many files, or if you are receiving a stream of data, each worker will have to read all data and discard all but every Nth item.
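For example, sharding by file name might look like the following sketch (the glob pattern is hypothetical); each worker opens only the files in its shard:

import tensorflow as tf

ALL_FILES = sorted(tf.gfile.Glob('/data/train-*.tfrecords'))  # sort for an unambiguous order

def make_shard(num_shards, shard_index):
    # Take every num_shards-th file, starting at shard_index.
    shard_files = ALL_FILES[shard_index::num_shards]
    return tf.data.TFRecordDataset(filenames=shard_files)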

When you create your plugin, it must contain a public class with these instance methods:

Constructor:

__init__(self)

Definition: The constructor may be called multiple times, creating multiple instances of your class. Each instance should provide data from the beginning of the set.

Because this class is constructed each time dataset validation occurs, this method must be lightweight and complete within roughly 10 seconds. For example, the entire dataset should NOT be read at this time; a minimal sketch of a lightweight constructor follows this method description.

Arguments: N/A

Returns: N/A
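A minimal sketch of a lightweight constructor (file paths are hypothetical); it only records where the data lives and defers all reading to the get_*_dataset() methods:

class MyCustomInterface:
    def __init__(self):
        # Record file locations only; no data is read or parsed here.
        self._train_files = ['/data/train.tfrecords']
        self._validation_files = ['/data/validation.tfrecords']
        self._test_files = ['/data/test.tfrecords']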

Method:

get_train_dataset(self, num_shards=1, shard_index=0)

Definition: Creates and returns the training dataset.

The returned dataset is used to train the network.

Arguments:

num_shards (int): The number of shards across the system. (Must have a default of 1 for the non-distributed case.)

shard_index (int): The shard index of the present instance. (Must have a default of 0 for the non-distributed case.)

Together, these arguments ask the interface to return the portion of the training data that belongs to shard shard_index out of num_shards total shards. The shard_index parameter takes values from 0 to num_shards - 1.

Note: The method will be called with positional arguments, either with one argument (self) or all three arguments (self, num_shards, shard_index). The latter is used for the distributed processing case when there are multiple workers or machines.

Returns a tuple of:

tf.data.Dataset: Training dataset.

int: Number of samples in this shard of the dataset.

int: Batch size produced by this dataset. This is the batch size of the shard.

Dict[tensor_name: value] (optional): If this dataset requires values passed at run time, they can be specified here.
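For example, if the graph contains a placeholder that must be fed whenever this dataset is used (such as a dropout keep-probability), the method could return it in the optional fourth element; the tensor name below is hypothetical:

return train_dataset, train_size, TRAIN_BATCH_SIZE, {'keep_prob:0': 1.0}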

Method:

get_test_dataset(self)

Definition: Creates and returns the test dataset.

The returned dataset is used for testing and reporting the results of the network on unseen data and has no impact on the GenSynth or training processes.

Arguments: None.

Returns a tuple of:

tf.data.Dataset: Test dataset.

int: Number of samples in this dataset.

int: Batch size produced by this dataset.

Dict[tensor_name: value] (optional): If this dataset requires values passed at run time, they can be specified here.

Method:

get_validation_dataset(self, num_shards=1, shard_index=0)

Definition: Creates and returns the validation dataset.

The returned dataset is used to prevent overfitting during training and to guide subsequent GenSynth cycles.

Arguments:

num_shards (int): The number of shards across the system. (Must have a default of 1 for the non-distributed case.)

shard_index (int): The shard index of the present instance. (Must have a default of 0 for the non-distributed case.)

Together, these arguments ask the interface to return the portion of the validation data that belongs to shard shard_index out of num_shards total shards. The shard_index parameter takes values from 0 to num_shards - 1.

Note: The method will be called with positional arguments, either with one argument (self) or all three arguments (self, num_shards, shard_index). The latter is used for the distributed processing case when there are multiple workers or machines.

Returns a tuple of:

tf.data.Dataset: Validation dataset.

int: Number of samples in this shard of the dataset.

int: Batch size produced by this dataset. This is the batch size of the shard.

Dict[tensor_name: value] (optional): If this dataset requires values passed at run time, they can be specified here.


Example:

class MyCustomInterface:
    def __init__(self):
        ...

    def get_train_dataset(self, num_shards=1, shard_index=0):
        # Use num_shards and shard_index to select this worker's portion of the data
        ...
        return train_dataset, train_size, TRAIN_BATCH_SIZE

    def get_test_dataset(self):
        ...
        return test_dataset, test_size, TEST_BATCH_SIZE

    def get_validation_dataset(self, num_shards=1, shard_index=0):
        ...
        return val_dataset, val_size, VAL_BATCH_SIZE

Note

Each of the get_ methods returns the same data types, differing only in which test/train/validation data is returned.

The get_ methods have these properties:

  • They return a tuple of:

    • An object of type tf.data.Dataset.

    • The size of the dataset as an integer (the epoch size) for the shard returned.

    • The batch size of the data (which must match the network tensor shapes).

    • An optional dictionary of tensor name and value pairs required to be fed when using this dataset.

  • They may be called multiple times; each time a method is called, it must construct and return a new Dataset instance.

The get_train_dataset() and get_validation_dataset() methods take optional arguments to support multi-worker training. When multiple GPUs or CPUs are used, the dataset interface is passed num_shards and shard_index, indicating that it should return only the data for the specified shard. The default values (for single-GPU training) indicate that the entire dataset should be returned.
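As an illustration only (a sketch, not GenSynth's actual internals), the methods might be invoked like this:

plugin = MyCustomInterface()

# Single-process case: the defaults return the full dataset.
train_ds, train_size, train_batch_size = plugin.get_train_dataset()

# Distributed case with four workers: worker 2 receives shard 2 of 4.
train_ds, shard_size, train_batch_size = plugin.get_train_dataset(4, 2)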

Usually this Python 3 module is placed in the same folder as the data. Remember that any files referenced by the module must use paths that work within the Docker container.

You must provide the code that creates datasets that provide the data (images, classes, bounding boxes, etc.) to the model for each of the train, validation, and test phases.

The dataset iterator must provide a get_next() function that returns a simple dictionary mapping data keys to tensors.

This structure will be needed later when running GenSynth. For example, for a list or tuple, you will need to know the indices of each tensor; for a dictionary you will need to know the key names of each tensor.
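For instance, if the dataset returns dictionaries with (hypothetical) 'image' and 'label' keys, downstream code can pull out each tensor by key name; a sketch using the TF1-style one-shot iterator:

iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()   # a dict of tensors
images = batch['image']       # referenced by key name
labels = batch['label']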

Tip

You must return a dictionary keyed by tensor names; you will need these key names later when configuring GenSynth.

This example snippet shows the construction of a training dataset using a TFRecord parse function that returns a dictionary with image and label keys:


Example:

import tensorflow as tf

def parse_function(serialized):
    features = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    parsed = tf.parse_single_example(serialized, features=features)
    # Decode and reshape the raw image bytes into a 28x28x1 image.
    image = tf.decode_raw(parsed['image'], tf.float32)
    image = tf.reshape(image, [28, 28, 1])
    label = tf.cast(parsed['label'], tf.int64)
    # Provide the keys to be referenced later.
    return {'image': image, 'label': label}

TRAIN_FILES = ['/data/train.tfrecords']

class CustomInterface:
    def get_train_dataset(self, num_shards=1, shard_index=0):
        num_data = get_num_data(TRAIN_FILES)
        dataset = tf.data.TFRecordDataset(filenames=TRAIN_FILES)
        # Shard as early as possible so each worker parses only its own records.
        if num_shards > 1:
            dataset = dataset.shard(num_shards, shard_index)
        dataset = dataset.map(parse_function)
        dataset = dataset.shuffle(500)
        batch_size = 16
        dataset = dataset.batch(batch_size)
        return dataset, num_data // num_shards, batch_size
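The get_num_data() helper used above is not defined in this snippet. One possible implementation (an assumption, not part of the GenSynth API) counts the records in each TFRecord file with the TF1 tf.python_io.tf_record_iterator:

def get_num_data(filenames):
    # Count the total number of records across all TFRecord files.
    return sum(1 for filename in filenames
               for _ in tf.python_io.tf_record_iterator(filename))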

Note

Any Python modules imported by your module must be in the PYTHONPATH.