Replicate Version control for machine learning

This is an experimental feature, and the API may change in the future.

Replicate works with any machine learning framework, but it includes a callback that makes it easier to use with Keras.

ReplicateCallback behaves like TensorFlow's ModelCheckpoint callback, but in addition to exporting a model at the end of each epoch, it also:

  1. Calls replicate.init() at the start of training to create an experiment, and
  2. Calls experiment.checkpoint() after saving the model at the end of each epoch. The checkpoint records all metrics, along with any other data in the logs dictionary passed to the callback's on_epoch_end method.
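
To make the two steps concrete, here is a minimal sketch of the lifecycle that avoids importing Keras or Replicate. FakeExperiment and SketchCallback are hypothetical stand-ins invented for illustration, and the checkpoint() argument names (path, step, metrics) are assumptions rather than a statement of the library's signature. Only the order of operations is taken from the description above.

```python
class FakeExperiment:
    """Hypothetical stand-in for the object returned by replicate.init()."""

    def __init__(self, params):
        self.params = params
        self.checkpoints = []

    def checkpoint(self, path, step, metrics):
        # A real checkpoint would also record the files at `path`;
        # this stand-in only keeps the values in memory.
        self.checkpoints.append(
            {"path": path, "step": step, "metrics": dict(metrics)}
        )


class SketchCallback:
    """Mirrors the hook names of a Keras callback without depending on Keras."""

    def __init__(self, filepath="model.hdf5", params=None):
        self.filepath = filepath
        self.params = params or {}
        self.experiment = None

    def on_train_begin(self, logs=None):
        # Step 1: create the experiment once, at the start of training.
        self.experiment = FakeExperiment(self.params)

    def on_epoch_end(self, epoch, logs=None):
        # Step 2: the model would be saved to self.filepath here, then a
        # checkpoint is recorded with everything in the logs dictionary.
        self.experiment.checkpoint(
            path=self.filepath, step=epoch, metrics=logs or {}
        )


# Drive the hooks by hand, the way a training loop would.
callback = SketchCallback(params={"learning_rate": 0.01})
callback.on_train_begin()
for epoch, loss in enumerate([0.9, 0.5, 0.3]):  # made-up loss values
    callback.on_epoch_end(epoch, logs={"loss": loss})

print(len(callback.experiment.checkpoints))  # 3
```

One checkpoint is recorded per epoch, and each one carries whatever the training loop put into logs for that epoch.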

Here is a simple example:

import tensorflow as tf
from tensorflow import keras
from replicate.keras_callback import ReplicateCallback

dense_size = 784
learning_rate = 0.01

# from https://www.tensorflow.org/guide/keras/custom_callback
model = keras.Sequential()
model.add(keras.layers.Dense(1, input_dim=dense_size))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=learning_rate),
    loss="mean_squared_error",
    metrics=["mean_absolute_error"],
)

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

model.fit(
    x_train[:1000],
    y_train[:1000],
    batch_size=128,
    epochs=20,
    validation_split=0.5,
    callbacks=[
        ReplicateCallback(
            params={"dense_size": dense_size, "learning_rate": learning_rate},
            primary_metric=("mean_absolute_error", "minimize"),
        ),
    ],
)

The ReplicateCallback class takes the following arguments, all optional:

  • filepath: The path where the exported model is saved. This path is also saved by experiment.checkpoint() at the end of each epoch. If it is None, the model is not saved, and the callback just gathers code and metrics. Default: model.hdf5
  • params: A dictionary of hyperparameters that will be recorded to the experiment at the start of training.
  • primary_metric: A tuple in the format (metric_name, goal), where goal is either "minimize" or "maximize". For example, ("mean_absolute_error", "minimize").
  • save_freq: "epoch" or an integer. With "epoch", the callback saves the model after each epoch. With an integer, the callback saves the model at the end of that many batches. Default: "epoch"
  • save_weights_only: If True, only the model's weights are saved (model.save_weights(filepath)); otherwise the full model is saved (model.save(filepath)). Default: False
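
The save_freq and primary_metric arguments can be read as two small rules. The sketch below spells them out in plain Python. The helper names should_save and best_checkpoint are hypothetical, the history values are made up, and using the goal to rank checkpoints is an assumption about how a (metric_name, goal) pair might be consumed, not a description of Replicate's internals.

```python
def should_save(save_freq, batches_seen, epoch_ended):
    """Return True when a save is due under the given save_freq."""
    if save_freq == "epoch":
        return epoch_ended
    if isinstance(save_freq, int) and not isinstance(save_freq, bool) and save_freq > 0:
        # Integer mode: save each time another save_freq batches have finished.
        return batches_seen > 0 and batches_seen % save_freq == 0
    raise ValueError('save_freq must be "epoch" or a positive integer')


def best_checkpoint(checkpoints, primary_metric):
    """Pick the checkpoint whose primary metric is best for the stated goal."""
    name, goal = primary_metric
    if goal not in ("minimize", "maximize"):
        raise ValueError('goal must be "minimize" or "maximize"')
    chooser = min if goal == "minimize" else max
    return chooser(checkpoints, key=lambda c: c["metrics"][name])


# Made-up metrics for three epochs.
history = [
    {"step": 0, "metrics": {"mean_absolute_error": 0.42}},
    {"step": 1, "metrics": {"mean_absolute_error": 0.31}},
    {"step": 2, "metrics": {"mean_absolute_error": 0.35}},
]

best = best_checkpoint(history, ("mean_absolute_error", "minimize"))
print(best["step"])                                             # 1
print(should_save("epoch", batches_seen=8, epoch_ended=True))   # True
print(should_save(5, batches_seen=10, epoch_ended=False))       # True
print(should_save(5, batches_seen=7, epoch_ended=False))        # False
```

With goal set to "maximize" the same history would select step 0, which is why the goal has to accompany the metric name.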