Replicate Version control for machine learning

The Replicate Python library is used to create experiments and checkpoints in your training script. It also has functions for programmatically analyzing the experiments.

Both uses are described below, in the Experiment tracking and Analyze and plot experiments sections. For an introduction to the Python API, see the tutorial.

To install the library, see the installation instructions.

Experiment tracking

These functions implement the core functionality of Replicate, i.e. saving parameters, metrics, and code to the repository. You use these functions during model training.

replicate.init()

Create and return an experiment.

It takes these arguments:

  • path: A path to a file or directory that will be uploaded to the repository, relative to the project directory. This can be used to save your training code, or anything you want. If path is not set, no data will be saved.
  • params: A dictionary of hyperparameters to record along with the experiment.

The path saved is relative to the project directory. The project directory is determined by the directory that contains replicate.yaml. If no replicate.yaml is found in any parent directories, the current working directory will be used.
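
The lookup described above can be sketched in a few lines. This illustrates the documented behavior and is not Replicate's actual implementation. The helper name find_project_dir is hypothetical, and the sketch assumes the search starts in the current working directory before moving up to its parents.

```python
from pathlib import Path
from typing import Optional


def find_project_dir(start: Optional[Path] = None) -> Path:
    """Return the nearest directory at or above `start` containing replicate.yaml.

    Hypothetical helper, for illustration only. Falls back to `start`
    (the current working directory by default) when no replicate.yaml is found.
    """
    start = (start or Path.cwd()).resolve()
    # Check the starting directory first, then each parent up to the root.
    for candidate in (start, *start.parents):
        if (candidate / "replicate.yaml").is_file():
            return candidate
    return start
```

Under this reading, a path such as "weights/" is interpreted relative to whatever directory the function returns, not relative to the script that calls replicate.init().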

If you want to exclude some files from being included, you can create a .replicateignore file alongside replicate.yaml. It is in the same format as .gitignore.

The repository location for this data is determined by the repository option in replicate.yaml. Learn more in the reference documentation.

For example:

>>> import replicate
>>> experiment = replicate.init(
... path=".",
... params={"learning_rate": 0.01},
... )

experiment.checkpoint()

Create a checkpoint within an experiment.

It takes these arguments:

  • path: A path to a file or directory that will be uploaded to the repository, relative to the project directory. This can be used to save weights, Tensorboard logs, and other artifacts produced during the training process. If path is not set, no data will be saved.
  • metrics: A dictionary of metrics to record along with the checkpoint.
  • primary_metric (optional): A tuple (name, goal) that marks one of the metrics as the primary metric to optimize. The goal is either "minimize" or "maximize".
  • step (optional): The iteration number of this checkpoint, such as the epoch number. This is displayed in replicate ls and various other places.

Like replicate.init(), the path saved is relative to the project directory. The project directory is determined by the directory that contains replicate.yaml. If no replicate.yaml is found in any parent directories, the current working directory will be used.

Any keyword arguments passed to the function will also be recorded.

For example:

>>> experiment.checkpoint(
... path="weights/",
... step=5,
... metrics={"train_loss": 0.425, "train_accuracy": 0.749},
... primary_metric=("train_accuracy", "maximize"),
... )

experiment.stop()

Stop an experiment.

Experiments running in a script will eventually time out, but when you're using a notebook, you must call this method to mark an experiment as stopped.

For example:

>>> experiment.stop()
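
In a notebook it is easy to forget this call when a cell raises an error. One possible pattern, sketched here with a hypothetical helper named stop_when_done (it is not part of the Replicate API), wraps any object that has a stop() method in a context manager:

```python
from contextlib import contextmanager


@contextmanager
def stop_when_done(experiment):
    """Yield `experiment`, then call its stop() method even if the body raises.

    Hypothetical helper, for illustration only.
    """
    try:
        yield experiment
    finally:
        experiment.stop()


# Intended usage (not executed here, since it needs a configured project):
# with stop_when_done(replicate.init(path=".", params={"learning_rate": 0.01})) as experiment:
#     ...  # training loop that calls experiment.checkpoint()
```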

Analyze and plot experiments

The Python API also contains a set of functions to analyze and plot the results of experiments. These functions are analogous to the commands found in the CLI.

For interactive examples of how to use these functions, see this Colab notebook.

replicate.experiments.list()

List the experiments in the current project. Returns a list of Experiment objects.

It takes one argument:

  • filter (optional): A function with the signature def filter(exp: replicate.Experiment) -> bool. If this function returns True, the experiment is included in the returned list. If not provided, all experiments are returned.

For example:

>>> experiments = replicate.experiments.list(
... filter=lambda exp: exp.best().step >= 100
... )
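
Because experiment.best() returns None when no primary metric is defined, a filter like the one above fails on such experiments. A more defensive filter might look like the following sketch. The factory name best_step_at_least is hypothetical; it relies only on the best() method and step field described on this page.

```python
def best_step_at_least(min_step):
    """Build a filter that keeps experiments whose best checkpoint is at or past min_step.

    Hypothetical helper, for illustration only. Experiments with no best
    checkpoint (no primary metric defined) are excluded instead of raising.
    """
    def _filter(exp):
        best = exp.best()
        return best is not None and best.step >= min_step
    return _filter


# Intended usage (not executed here, since it needs a configured project):
# experiments = replicate.experiments.list(filter=best_step_at_least(100))
```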

experiments.plot()

The list of experiments returned by replicate.experiments.list() can be plotted and compared graphically. experiments.plot() plots a single metric from the checkpoints of all experiments in a list of experiments.

It takes two arguments:

  • metric (optional): The name of the metric to plot. If omitted, defaults to the primary metric, or raises an error if no primary metric is defined.
  • logy (optional): Whether to use a logarithmic Y-axis. Defaults to false.

For example:

>>> experiments = replicate.experiments.list()
>>> experiments.plot(metric="loss")

experiments.scatter()

This function plots a metric against a hyperparameter for all the experiments in a list of experiments. If a primary metric is defined, the best checkpoint for each experiment is used; otherwise, the latest checkpoint is used.

It takes four arguments:

  • param: The name of the hyperparameter to plot on the X-axis.
  • metric (optional): The name of the metric to plot. If omitted, defaults to the primary metric, or raises an error if no primary metric is defined.
  • logx (optional): Whether to use a logarithmic X-axis. Defaults to false.
  • logy (optional): Whether to use a logarithmic Y-axis. Defaults to false.

For example:

>>> experiments = replicate.experiments.list()
>>> experiments.scatter(param="learning_rate", metric="loss")

experiments.delete()

Delete all experiments in this list of experiments.

For example:

>>> bad_experiments = replicate.experiments.list(
... filter=lambda exp: exp.best().metrics["accuracy"] < 0.5
... )
>>> bad_experiments.delete()

replicate.experiments.get()

Returns a single Experiment object.

It takes a single argument:

  • experiment_id: The ID of the experiment to return. Can either be a full ID or an unambiguous prefix.

For example:

>>> exp = replicate.experiments.get("abcd")
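
To make "unambiguous prefix" concrete, the sketch below resolves a prefix against a list of known IDs. It is an assumption about how such matching can work, not Replicate's implementation. The function name resolve_experiment_id and the choice of LookupError are hypothetical; this page does not say what happens when a prefix matches zero or several experiments.

```python
from typing import Iterable


def resolve_experiment_id(prefix: str, known_ids: Iterable[str]) -> str:
    """Return the single ID that equals `prefix` or starts with it.

    Hypothetical helper, for illustration only.
    """
    ids = list(known_ids)
    # A full ID always wins, even if it happens to prefix another ID.
    if prefix in ids:
        return prefix
    matches = [i for i in ids if i.startswith(prefix)]
    if not matches:
        raise LookupError(f"no experiment ID starts with {prefix!r}")
    if len(matches) > 1:
        raise LookupError(f"prefix {prefix!r} is ambiguous: {len(matches)} experiments match")
    return matches[0]
```

With IDs "abcd1234" and "abce5678", the prefix "abcd" identifies one experiment, while "abc" matches both and is rejected.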

The Experiment class

Experiment objects have the following fields:

  • id: The experiment ID
  • created: The datetime.datetime of when the experiment was created
  • duration: The time between the creation of the experiment and the latest checkpoint
  • user: The username of the person who started the experiment
  • command: The command that the experiment was started with
  • path: The path which was saved to the repository by replicate.init(), relative to replicate.yaml
  • params: The dictionary of parameters that were defined in replicate.init()
  • python_packages: The Python packages, and their versions, that were imported when the experiment started
  • checkpoints: A list of Checkpoint objects (see The Checkpoint class)

experiment.latest()

Returns the latest checkpoint for this experiment.

experiment.best()

Returns the best checkpoint for this experiment, according to the primary metric. If no primary metric is defined, returns None.
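
The selection rule can be illustrated with stand-in objects. The sketch below is not the library's code. CheckpointRecord and pick_best are hypothetical names, and the sketch assumes primary_metric is stored as the (name, goal) tuple passed to experiment.checkpoint() and that all checkpoints in one experiment share the same primary metric.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class CheckpointRecord:
    """Stand-in for a checkpoint, for illustration only."""
    step: int
    metrics: Dict[str, float]
    primary_metric: Optional[Tuple[str, str]] = None  # (name, goal)


def pick_best(checkpoints: List[CheckpointRecord]) -> Optional[CheckpointRecord]:
    """Return the checkpoint with the best primary metric, or None if none is defined."""
    scored = [c for c in checkpoints if c.primary_metric is not None]
    if not scored:
        return None
    name, goal = scored[0].primary_metric
    # Only compare checkpoints that actually recorded the primary metric.
    scored = [c for c in scored if c.primary_metric == (name, goal) and name in c.metrics]
    if not scored:
        return None
    choose = max if goal == "maximize" else min
    return choose(scored, key=lambda c: c.metrics[name])
```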

experiment.delete()

Delete this experiment and its checkpoints.

experiment.refresh()

Update this experiment with the latest data from the repository. This is useful when you are doing analysis of an experiment that is running, and you want to get the latest data.

experiment.plot()

Plot a single metric in this experiment.

It takes two arguments:

  • metric (optional): The name of the metric to plot. If omitted, defaults to the primary metric, or raises an error if no primary metric is defined.
  • logy (optional): Whether to use a logarithmic Y-axis. Defaults to false.

For example:

>>> experiment = replicate.experiments.get("ab1234")
>>> experiment.plot(metric="loss")

The Checkpoint class

Checkpoint objects have the following fields:

  • id: The checkpoint ID
  • created: The datetime.datetime of when the checkpoint was created
  • path: The path which was saved to the repository by experiment.checkpoint(), relative to replicate.yaml
  • step: The step as defined in experiment.checkpoint()
  • metrics: The dictionary of metrics that were defined in experiment.checkpoint()
  • primary_metric: The primary metric as defined in experiment.checkpoint()

checkpoint.checkout()

Similar to replicate checkout in the CLI, checkpoint.checkout() copies the files from the checkpoint and its parent experiment to a local directory.

It takes two arguments:

  • output_directory: The directory to save the files to
  • quiet (optional): Whether to suppress log messages. Defaults to false.

For example:

>>> exp = replicate.experiments.get("abcd")
>>> chk = exp.best()
>>> chk.checkout("/tmp/my-folder")

checkpoint.open()

Returns a file-like object for a single file in this checkpoint or its associated experiment.

It takes a single argument:

  • path: The path to the file to open

For example:

>>> import torch
>>> exp = replicate.experiments.get("abcd")
>>> chk = exp.best()
>>> model_file = chk.open("model.pkl")
>>> model = torch.load(model_file)

experiment.checkpoints.metrics

The checkpoints list on Experiment objects has a metrics attribute that returns a list of metric values for each of the checkpoints in this list. This is a shorthand for the following list comprehension:

>>> [chk.metrics["loss"] for chk in experiment.checkpoints]
[0.99, 0.87, 0.64]
>>> experiment.checkpoints.metrics["loss"]
[0.99, 0.87, 0.64]

experiment.checkpoints.step

Similar to experiment.checkpoints.metrics, this attribute can be used as a shorthand to get the step of each checkpoint in the checkpoint list.

>>> [chk.step for chk in experiment.checkpoints]
[0, 10, 20]
>>> experiment.checkpoints.step
[0, 10, 20]

experiment.checkpoints.plot()

Plot a single metric across all checkpoints in an experiment. This is similar to experiment.plot(), except it allows you to select a range of checkpoints to plot.

For example:

>>> experiment.checkpoints.plot("loss") # equivalent to experiment.plot("loss")
>>> experiment.checkpoints[-20:].plot("loss") # plot the last 20 checkpoints
>>> experiment.checkpoints[::5].plot("loss") # plot every 5th checkpoint
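
For readers curious how a list can offer the metrics and step shorthands and still support slices such as checkpoints[-20:], here is a minimal sketch. It is an assumption about one possible design, not Replicate's implementation. SketchCheckpointList is a hypothetical name, and plotting is left out.

```python
class SketchCheckpointList(list):
    """Stand-in list of checkpoints with `step` and `metrics` shorthands.

    Hypothetical class, for illustration only.
    """

    def __getitem__(self, index):
        result = super().__getitem__(index)
        # Slices return the same list type, so shorthands keep working on them.
        if isinstance(index, slice):
            return SketchCheckpointList(result)
        return result

    @property
    def step(self):
        return [chk.step for chk in self]

    @property
    def metrics(self):
        # Map each metric name to its values, one per checkpoint, in order.
        names = {name for chk in self for name in chk.metrics}
        return {name: [chk.metrics.get(name) for chk in self] for name in names}
```

Returning the same type from a slice is what lets an expression like checkpoints[::5] be followed directly by another shorthand or method call.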

The Project class

If you want to use Replicate without creating a replicate.yaml configuration file, you can configure the repository location using the Project class.

For example, at the training stage:

>>> project = replicate.Project(repository="s3://my-bucket", directory="/Users/andreas/my-project")
>>> experiment = project.experiments.create(path=".", params={"foo": "bar"}) # equivalent to replicate.init()
>>> experiment.checkpoint([...])

Or during analysis:

>>> project = replicate.Project(repository="s3://my-bucket")
>>> experiments = project.experiments.list() # equivalent to replicate.experiments.list()