---
title: Distributed training with Keras 3
date-created: 2023/11/07
last-modified: 2023/11/07
description: Complete guide to the distribution API for multi-backend Keras.
output: rmarkdown::html_vignette
backend: jax
vignette: >
  %\VignetteIndexEntry{Distributed training with Keras 3}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

## Introduction

The Keras distribution API is an interface designed to facilitate
distributed deep learning across backends such as JAX, TensorFlow, and
PyTorch. It provides tools for data and model parallelism, so that deep
learning models can scale across multiple accelerators and hosts. Whether you
use GPUs or TPUs, the API covers initializing distributed environments,
defining device meshes, and specifying the layout of tensors across
computational resources. Classes like `DataParallel` and `ModelParallel`
abstract away much of the complexity of parallel computation, making it easier
to accelerate machine learning workflows.

## How it works

The Keras distribution API provides a global programming model that allows
developers to compose applications that operate on tensors in a global context
(as if working with a single device) while
automatically managing distribution across many devices. The API leverages the
underlying framework (e.g. JAX) to distribute the program and tensors according to the
sharding directives through a procedure called single program, multiple data
(SPMD) expansion.

By decoupling the application from sharding directives, the API enables running
the same application on a single device, multiple devices, or even multiple
clients, while preserving its global semantics.
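
As a rough illustration of this idea in plain base R (not the Keras or JAX API), the sketch below shards the rows of a global input across four hypothetical devices, runs the same function on every shard, and checks that the reassembled result matches the single-device result. The names `program`, `n_devices`, and `shards` are made up for this example.

```r
# A "global" input of 8 samples with 3 features, plus a weight matrix that
# every hypothetical device holds a full copy of.
set.seed(1)
x <- matrix(rnorm(8 * 3), nrow = 8)
w <- matrix(rnorm(3 * 2), nrow = 3)

# The single program that every device runs on its own shard.
program <- function(x_shard, w) pmax(x_shard %*% w, 0)

# Single-device view: run the program on the whole input.
global_result <- program(x, w)

# SPMD view: split the rows across 4 hypothetical devices, run the same
# program on each shard, and stitch the pieces back together.
n_devices <- 4
device_id <- rep(seq_len(n_devices), each = nrow(x) / n_devices)
shards <- lapply(seq_len(n_devices), function(d) {
  x[device_id == d, , drop = FALSE]
})
sharded_result <- do.call(rbind, lapply(shards, program, w = w))

# The sharded computation reproduces the global one.
all.equal(global_result, sharded_result)
```

The application code (`program`) never mentions devices; only the sharding step does, which is the separation the API aims for.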

## Setup


``` r
# This guide assumes there are 8 GPUs available for testing. If you don't have
# 8 GPUs available locally, you can set the following environment variables to
# make XLA initialize the CPU as 8 devices, to enable local testing.
Sys.setenv("CUDA_VISIBLE_DEVICES" = "")
Sys.setenv("XLA_FLAGS" = "--xla_force_host_platform_device_count=8")
```



``` r
library(keras3)

# The distribution API is only implemented for the JAX backend for now.
use_backend("jax", FALSE)
jax <- reticulate::import("jax")

library(tfdatasets, exclude = "shape") # For dataset input.
```


## `DeviceMesh` and `TensorLayout`

The `keras$distribution$DeviceMesh` class in the Keras distribution API represents a cluster of
computational devices configured for distributed computation. It aligns with
similar concepts in [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) and
[`tf.dtensor.Mesh`](https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh),
where it's used to map the physical devices to a logical mesh structure.

The `TensorLayout` class then specifies how tensors are distributed across the
`DeviceMesh`, detailing the sharding of tensors along specified axes that
correspond to the names of the axes in the `DeviceMesh`.

You can find more detailed concept explainers in the
[TensorFlow DTensor guide](https://www.tensorflow.org/guide/dtensor_overview#dtensors_model_of_distributed_tensors).


``` r
# Retrieve the locally available devices. Pass "gpu" to `jax$devices()` to
# list only GPUs.
devices <- jax$devices()
str(devices)
```

```
## List of 8
##  $ :TFRT_CPU_0
##  $ :TFRT_CPU_1
##  $ :TFRT_CPU_2
##  $ :TFRT_CPU_3
##  $ :TFRT_CPU_4
##  $ :TFRT_CPU_5
##  $ :TFRT_CPU_6
##  $ :TFRT_CPU_7
```

``` r
# Define a 2x4 device mesh with data and model parallel axes
mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it is mapped to the physical
# devices on the mesh.
layout_2d <- keras$distribution$TensorLayout(
  axes = c("model", "data"),
  device_mesh = mesh
)

# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d <- keras$distribution$TensorLayout(
  axes = list("data", NULL, NULL, NULL),
  device_mesh = mesh
)
```
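
To make the mapping more concrete, here is a minimal base R sketch (not the Keras API) of how 8 device ids could be arranged into the 2 x 4 logical mesh, and how a layout determines the size of the piece each device holds. The helper `shard_shape()` and the row-by-row fill order are assumptions made for illustration.

```r
# Logical 2 x 4 mesh over 8 device ids, filled row by row (assumed to mirror
# a row-major reshape of the device list).
mesh_ids <- matrix(
  0:7, nrow = 2, byrow = TRUE,
  dimnames = list(data = c("0", "1"), model = c("0", "1", "2", "3"))
)
mesh_shape <- c(data = 2, model = 4)

# Hypothetical helper: per-device shard shape of a tensor under a layout.
# `axes` has one entry per tensor dimension: a mesh axis name, or NULL for a
# dimension that is not sharded.
shard_shape <- function(global_shape, axes, mesh_shape) {
  divisor <- vapply(axes, function(axis) {
    if (is.null(axis)) 1 else mesh_shape[[axis]]
  }, numeric(1))
  stopifnot(global_shape %% divisor == 0)
  global_shape / divisor
}

# 2D layout ("model", "data") applied to an 8 x 6 tensor: the first dimension
# is split 4 ways and the second 2 ways.
shard_shape(c(8, 6), list("model", "data"), mesh_shape)

# 4D layout ("data", NULL, NULL, NULL) applied to a batch of 16 images: only
# the batch dimension is split.
shard_shape(c(16, 28, 28, 1), list("data", NULL, NULL, NULL), mesh_shape)
```

Dimensions marked `NULL` keep their full size on every device, which is why a layout like `("data", NULL, NULL, NULL)` suits data parallelism for image batches.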


## `Distribution`

The `Distribution` class in Keras serves as a foundational abstract class designed
for developing custom distribution strategies. It encapsulates the core logic
needed to distribute a model's variables, input data, and intermediate
computations across a device mesh. As an end user, you won't have to interact
directly with this class; instead, you will use its subclasses, such as
`DataParallel` or `ModelParallel`.

## `DataParallel`

The `DataParallel` class in the Keras distribution API is designed for the
data parallelism strategy in distributed training, where the model weights are
replicated across all devices in the `DeviceMesh`, and each device processes a
portion of the input data.

Here is a sample usage of this class.

``` r
# Create DataParallel with a list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all locally available devices.
# E.g. data_parallel <- keras$distribution$DataParallel()
data_parallel <- keras$distribution$DataParallel(devices = devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d <- keras$distribution$DeviceMesh(
  shape = shape(8),
  axis_names = list("data"),
  devices = devices
)
data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d)

inputs <- random_normal(c(128, 28, 28, 1))
labels <- random_normal(c(128, 10))
dataset <- tensor_slices_dataset(c(inputs, labels)) |>
  dataset_batch(16)

# Set the global distribution.
keras$distribution$set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model |> fit()` or
# `model |> evaluate()` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE, activation = "relu") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10, activation = "softmax")

model <- keras_model(inputs = inputs, outputs = outputs)

model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
```

```
## Epoch 1/3
## 8/8 - 0s - 40ms/step - loss: 1.1536
## Epoch 2/3
## 8/8 - 0s - 5ms/step - loss: 1.0540
## Epoch 3/3
## 8/8 - 0s - 6ms/step - loss: 1.0072
```

``` r
model |> evaluate(dataset)
```

```
## 8/8 - 0s - 9ms/step - loss: 0.9620
```

```
## $loss
## [1] 0.9620273
```
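
One way to see why no manual loss aggregation is needed when the batch is split evenly: for a mean-based loss such as MSE, averaging the per-shard losses over equal-sized shards gives the same value as computing the loss once over the whole batch. The base R sketch below checks this identity on random data. It illustrates the arithmetic only and does not show how Keras or JAX implement it; `mse`, `n_devices`, and `device_id` are names made up for this example.

```r
set.seed(42)
y_true <- matrix(rnorm(16 * 10), nrow = 16)  # one global batch of 16 samples
y_pred <- matrix(rnorm(16 * 10), nrow = 16)

mse <- function(y_true, y_pred) mean((y_true - y_pred)^2)

# Loss computed once over the whole batch (the "global" view).
global_loss <- mse(y_true, y_pred)

# Loss computed separately on 8 equal shards of 2 samples, then averaged.
n_devices <- 8
device_id <- rep(seq_len(n_devices), each = nrow(y_true) / n_devices)
per_device_loss <- vapply(seq_len(n_devices), function(d) {
  rows <- device_id == d
  mse(y_true[rows, , drop = FALSE], y_pred[rows, , drop = FALSE])
}, numeric(1))

# Equal shard sizes make the mean of means equal to the global mean.
all.equal(global_loss, mean(per_device_loss))
```

The identity relies on the shards being the same size, which matches the even split on the batch dimension described above.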


## `ModelParallel` and `LayoutMap`

`ModelParallel` is mostly useful when model weights are too large to fit
on a single accelerator. This setting allows you to split your model weights or
activation tensors across all the devices on the `DeviceMesh`, and enables
horizontal scaling for large models.

Unlike the `DataParallel` model where all weights are fully replicated,
the weights layout under `ModelParallel` usually needs some customization for
best performance. We introduce `LayoutMap` to let you specify the
`TensorLayout` for any weights and intermediate tensors from a global perspective.

`LayoutMap` is a dict-like object that maps a string to `TensorLayout`
instances. It behaves differently from a normal dict in that the string
key is treated as a regex when retrieving the value. The class allows you to
define the naming schema of `TensorLayout` and then retrieve the corresponding
`TensorLayout` instance. Typically, the key used to query
is the `variable$path` attribute, which is the identifier of the variable.
As a shortcut, a list of axis names is also allowed when inserting a value,
and it will be converted to `TensorLayout`.

The `LayoutMap` can also optionally contain a `DeviceMesh` to populate the
`TensorLayout$device_mesh` if it is not set. When retrieving a layout with a
key, and if there isn't an exact match, all existing keys in the layout map will
be treated as regex and matched against the input key again. If there are
multiple matches, a `ValueError` is raised. If no matches are found, `NULL` is
returned.
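
The lookup rules above can be sketched in base R as follows. This is a hypothetical stand-in written for illustration, not the `LayoutMap` implementation: `rules` and `lookup_layout()` are made-up names, and the sketch assumes that the regex is searched anywhere in the key rather than anchored at its start.

```r
# Hypothetical stand-in for a layout map: names are regex patterns, values
# are axis-name lists (NULL marks a dimension that is not sharded).
rules <- list(
  "d1/kernel" = list(NULL, "model"),
  "d1/bias"   = list("model"),
  "d2/output" = list("data", NULL)
)

lookup_layout <- function(rules, path) {
  # 1. An exact key match wins.
  if (path %in% names(rules)) return(rules[[path]])
  # 2. Otherwise treat every key as a regex and search the path for it.
  is_hit <- vapply(names(rules), grepl, logical(1), x = path)
  hits <- names(rules)[is_hit]
  # 3. More than one matching rule is ambiguous, so signal an error.
  if (length(hits) > 1) {
    stop("Path '", path, "' matches multiple layout rules: ",
         paste(hits, collapse = ", "))
  }
  # 4. No match at all: return NULL.
  if (length(hits) == 0) return(NULL)
  rules[[hits]]
}

lookup_layout(rules, "d1/kernel")        # exact match
lookup_layout(rules, "model/d1/kernel")  # regex match inside a longer path
lookup_layout(rules, "d2/kernel")        # no rule matches
```

Raising an error on multiple matches keeps the rules unambiguous: a variable can never silently pick up one of several competing layouts.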


``` r
mesh_2d <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = c("data", "model"),
  devices = devices
)
layout_map <- keras$distribution$LayoutMap(mesh_2d)

# The rules below mean that any weights whose path matches d1/kernel will be
# sharded along the "model" axis (4 devices); the same goes for d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] <- tuple(NULL, "model")
layout_map["d1/bias"] <- tuple("model")

# You can also set the layout for the layer output like
layout_map["d2/output"] <- tuple("data", NULL)

model_parallel <- keras$distribution$ModelParallel(
  layout_map = layout_map, batch_dim_name = "data"
)

keras$distribution$set_distribution(model_parallel)

inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE,
              activation = "relu", name = "d1") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10,
              activation = "softmax",
              name = "d2")

model <- keras_model(inputs = inputs, outputs = outputs)
```

We can visualize how individual weights will be sharded:

``` r
d1 <- get_layer(model, "d1")
d1$kernel$value |> jax$debug$visualize_array_sharding()
```

```
## ┌───────┬───────┬───────┬───────┐
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## └───────┴───────┴───────┴───────┘
```
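
To get a feel for what this column-wise split of the `d1` kernel means numerically, here is a small base R sketch (not Keras or JAX code). It splits a 784 x 200 matrix into 4 column blocks, multiplies each block separately, and checks that binding the partial outputs together recovers the full result. The contiguous, equal-sized blocks are an assumption made for illustration.

```r
set.seed(0)
x <- matrix(rnorm(2 * 784), nrow = 2)           # 2 flattened 28 x 28 inputs
kernel <- matrix(rnorm(784 * 200), nrow = 784)  # stand-in for the d1 kernel

# Split the 200 output columns into 4 contiguous blocks of 50, one per
# position on the "model" axis.
n_model <- 4
col_block <- rep(seq_len(n_model), each = ncol(kernel) / n_model)
kernel_shards <- lapply(seq_len(n_model), function(m) {
  kernel[, col_block == m, drop = FALSE]
})
dim(kernel_shards[[1]])

# Each shard yields a 2 x 50 slice of the output; binding the slices by
# column recovers the full 2 x 200 result.
partial_outputs <- lapply(kernel_shards, function(k) x %*% k)
all.equal(x %*% kernel, do.call(cbind, partial_outputs))
```

Each block holds a quarter of the kernel, which is the memory saving that motivates sharding weights along the "model" axis.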

``` r
d2 <- get_layer(model, "d2")
d2$kernel$value |> jax$debug$visualize_array_sharding()
```

```
## ┌───────────────────┐
## │                   │
## │                   │
## │                   │
## │                   │
## │CPU 0,1,2,3,4,5,6,7│
## │                   │
## │                   │
## │                   │
## │                   │
## └───────────────────┘
```

``` r
d2$bias$value |> jax$debug$visualize_array_sharding()
```

```
## ┌───────────────────┐
## │CPU 0,1,2,3,4,5,6,7│
## └───────────────────┘
```

``` r
x_batch <- dataset |>
  as_iterator() |> iter_next() |>
  _[[1]] |> op_convert_to_tensor()

output_array <- model(x_batch)
output_array |> jax$debug$visualize_array_sharding()
```

```
## ┌─────────────┐
## │             │
## │ CPU 0,1,2,3 │
## │             │
## │             │
## ├─────────────┤
## │             │
## │ CPU 4,5,6,7 │
## │             │
## │             │
## └─────────────┘
```




``` r
# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
```

```
## Epoch 1/3
## 8/8 - 0s - 46ms/step - loss: 1.1676
## Epoch 2/3
## 8/8 - 0s - 4ms/step - loss: 1.1134
## Epoch 3/3
## 8/8 - 0s - 5ms/step - loss: 1.1034
```

``` r
model |> evaluate(dataset)
```

```
## 8/8 - 0s - 8ms/step - loss: 1.0676
```

```
## $loss
## [1] 1.067567
```


It is also easy to change the mesh structure to shift the balance between
data parallelism and model parallelism. You only need to adjust the shape of
the mesh; no other code changes are needed.


``` r
full_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(8, 1),
  axis_names = list("data", "model"),
  devices = devices
)
more_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(4, 2),
  axis_names = list("data", "model"),
  devices = devices
)
more_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)
full_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(1, 8),
  axis_names = list("data", "model"),
  devices = devices
)
```
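
As a back-of-the-envelope comparison of these four meshes, the base R sketch below computes how many samples of a batch land on each "data" shard and how many `d1` output columns land on each "model" shard. It uses the batch size of 16 and the 200 units of `d1` from the examples above, and assumes even splits. The names `meshes` and `per_device` are made up for this illustration.

```r
# Mesh shapes from the code above, as (data, model) pairs.
meshes <- list(
  full_data_parallel  = c(data = 8, model = 1),
  more_data_parallel  = c(data = 4, model = 2),
  more_model_parallel = c(data = 2, model = 4),
  full_model_parallel = c(data = 1, model = 8)
)

batch_size <- 16  # batch size used by `dataset`
d1_units <- 200   # output units of the "d1" layer

# Assuming even splits: samples per "data" shard and kernel columns per
# "model" shard for each mesh.
per_device <- data.frame(
  mesh = names(meshes),
  samples_per_data_shard = vapply(
    meshes, function(m) batch_size / m[["data"]], numeric(1)
  ),
  d1_columns_per_model_shard = vapply(
    meshes, function(m) d1_units / m[["model"]], numeric(1)
  ),
  row.names = NULL
)
per_device
```

Moving devices from the "data" axis to the "model" axis shrinks the slice of each sharded weight a device holds, at the cost of a larger share of every batch per device.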


## Further reading

1. [JAX Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
2. [JAX sharding module](https://jax.readthedocs.io/en/latest/jax.sharding.html)
3. [TensorFlow Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial)
4. [TensorFlow DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview)
5. [Using DTensors with tf.keras](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial)