Smart Distributed Training on Amazon SageMaker with SMD: Part 3


Photo by Martin Jernberg on Unsplash

This is the final part of a three-part post on optimizing distributed training. In part one, we provided a brief survey of distributed training algorithms. We noted that all of these algorithms rely on high-speed communication between multiple GPUs. We surmised that a distributed algorithm that accounts for the underlying instance topology, particularly the differences in the communication links between GPU pairs, would perform better than one that does not. In part two we focused on data distribution and showed the benefits of Amazon SageMaker’s distributed data parallel (SDP) library.

In part three we will turn our attention to model distribution. We will demonstrate one way in which Amazon SageMaker’s distributed model parallel library allows you to configure your model distribution algorithm in a way that distinguishes between intra-node and inter-node GPU-to-GPU communication.

Model Distribution with SageMaker Distributed Model Parallel

Similar to SDP for data distribution, the Amazon SageMaker distributed model parallel (SMP) library aims to simplify and accelerate model distributed training. The library includes APIs for each of the model distribution techniques discussed above (though you should check the documentation for details on the support matrix). It allows for combining some of the methods, and includes controls for automating some of the configurations.

In this section we would like to focus on the sharded data parallelism support. Recall that the parameter sharding scheme dictated by the ZeRO algorithm (described above) does not distinguish between intra-node GPU links and inter-node GPU links. Given the relatively high communication volume of this algorithm (see details here) and the fact that it is usually applied to relatively large models, the inter-node bandwidth could become a bottleneck.

The SMP library enables you to address this concern by supporting a technique called MiCS (so named because it minimizes the communication scale), also known as ZeRO-2D. MiCS introduces the notion of a partition group. The full set of GPUs is divided into partition groups of equal size, each of which contains a single replica of the model. The parameters of each model replica are sharded (using the ZeRO algorithm) across the GPUs in the replica’s partition group. Within a partition group, parameters are communicated according to the ZeRO algorithm. Alignment between the model replicas is maintained via gradient sharing between corresponding GPUs in the different partition groups. This leads to a hierarchical communication strategy, as shown in the following illustration of two nodes with two GPUs each:

The intra-node communication (ZeRO algorithm) is shown in orange. The inter-node communication (gradient sharing) is shown in red. (From
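To make the grouping concrete, here is a minimal sketch of how the ranks in the illustrated setup could be split into partition groups (within which parameters are sharded) and groups of corresponding GPUs (between which gradients are shared). It is not the SMP implementation: the function name mics_groups is hypothetical, and we assume that global ranks are numbered contiguously within each node.

```python
from typing import List, Tuple


def mics_groups(world_size: int, partition_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (partition_groups, replication_groups) as lists of global ranks.

    Assumption: ranks are contiguous within a node, so a partition group of
    size equal to the GPUs per node stays inside a single node.
    """
    if partition_size <= 0 or world_size % partition_size != 0:
        raise ValueError("partition_size must evenly divide world_size")
    num_partitions = world_size // partition_size

    # Each partition group holds one full replica of the model,
    # sharded ZeRO-style across its members.
    partition_groups = [
        list(range(p * partition_size, (p + 1) * partition_size))
        for p in range(num_partitions)
    ]

    # GPUs holding the same shard in different partition groups
    # share gradients with one another.
    replication_groups = [
        [p * partition_size + i for p in range(num_partitions)]
        for i in range(partition_size)
    ]
    return partition_groups, replication_groups


if __name__ == "__main__":
    # Two nodes with two GPUs each, as in the illustration.
    parts, reps = mics_groups(world_size=4, partition_size=2)
    print("partition groups:", parts)
    print("gradient-sharing groups:", reps)
```

With the partition size equal to the number of GPUs per node, the heavy parameter traffic stays on the intra-node links and only the gradient sharing crosses nodes.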


Here we show an example of integrating sharded data parallelism into a PyTorch (1.12) training job. The model we chose is a vision transformer (ViT) model with roughly 630 million parameters. The model was built using the transformers Python package.

The code block below demonstrates how to instantiate a model distributed training job on eight g4dn.12xlarge instances (with 4 GPUs each) using the sharded data parallelism API. Note that model parallelism is required in this case as the ViT model we chose is too large to fit into a single GPU.

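from sagemaker.pytorch import PyTorch

distribution = {
    "mpi": {
        "enabled": True,
        "processes_per_host": 4  # one process per GPU on a g4dn.12xlarge
    },
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {
                "ddp": True,
                # 32 = all GPUs in the job (classic ZeRO sharding)
                "sharded_data_parallel_degree": 32,
                "delayed_parameter_initialization": True
            }
        }
    }
}

# (remaining estimator arguments are not shown)
pytorch = PyTorch(entry_point='',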

The script above is configured with the sharded_data_parallel_degree set to 32. This will run the classic ZeRO based parameter sharding algorithm. To use MiCS (ZeRO-2D), reconfigure this parameter to the desired size of the partition group.
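As a minimal sketch of how you might toggle between the two modes, the helper below builds the distribution dictionary for a given sharding degree and reports how many model replicas result. The helper names (make_distribution, replica_count) are our own and hypothetical; the dictionary keys are the ones used in the configuration above, and the replica arithmetic simply follows from each partition group holding one replica.

```python
def make_distribution(sharding_degree: int, gpus_per_host: int = 4) -> dict:
    """Build the estimator's distribution dict for a given sharding degree."""
    return {
        "mpi": {"enabled": True, "processes_per_host": gpus_per_host},
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "ddp": True,
                    "sharded_data_parallel_degree": sharding_degree,
                    "delayed_parameter_initialization": True,
                },
            }
        },
    }


def replica_count(num_hosts: int, gpus_per_host: int, sharding_degree: int) -> int:
    """Number of model replicas (partition groups) for the given setup."""
    total_gpus = num_hosts * gpus_per_host
    if sharding_degree <= 0 or total_gpus % sharding_degree != 0:
        raise ValueError("sharding degree must evenly divide the GPU count")
    return total_gpus // sharding_degree


if __name__ == "__main__":
    # Eight instances with four GPUs each, as in this post.
    for degree in (32, 4):
        cfg = make_distribution(degree)
        params = cfg["smdistributed"]["modelparallel"]["parameters"]
        print(params["sharded_data_parallel_degree"],
              "->", replica_count(8, 4, degree), "replica(s)")
```

A degree of 32 yields a single replica sharded across all GPUs, while a degree of 4 yields eight replicas, each sharded within one instance.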

The code block below contains the code of the associated training script.

# Lines marked "(restored)" were missing from the listing and are filled in from context.
import torch
from torch.utils.data import Dataset  # (restored) module path
import time
import argparse
from transformers import ViTForImageClassification, ViTConfig


class FakeDataset(Dataset):
    """Synthetic dataset of random images with cycling labels."""

    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=[index % 1000], dtype=torch.int64)
        return rand_image, label


def build_model():
    model_args = {
        "image_size": 256,
        "patch_size": 16,
        "hidden_size": 1024,
        "num_hidden_layers": 50,
        "num_attention_heads": 16,
        "intermediate_size": 4*1024,
        "num_labels": 1000
    }
    model = ViTForImageClassification(ViTConfig(**model_args))
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', default='/tmp', type=str)
    args, _ = parser.parse_known_args()

    # init
    import smdistributed.modelparallel.torch as smp
    smp.init()  # (restored)

    # disable a cross-rank consistency check in the ZeRO stage 3 code
    from deepspeed.runtime.zero import stage3  # (restored) module path
    stage3.assert_ints_same_as_other_ranks = lambda x: None

    torch.cuda.set_device(smp.local_rank())  # (restored)

    dataset = FakeDataset()
    data_loader = torch.utils.data.DataLoader(dataset,  # (restored) constructor
                                              batch_size=4, num_workers=12)

    model = build_model()
    model = smp.DistributedModel(model)

    # add checkpoint activation
    for m in model.get_module().vit.encoder.layer.children():
        smp.set_activation_checkpointing(m)  # (restored)

    optimizer = torch.optim.Adam(model.parameters())
    optimizer = smp.DistributedOptimizer(optimizer)
    loss_function = torch.nn.CrossEntropyLoss()

    model.train()  # (restored)
    t0 = time.perf_counter()
    summ = 0
    count = 0

    @smp.step  # (restored) required for loss_mb.reduce_mean() below
    def train_step(model, inputs, targets):
        outputs = model(inputs)
        loss = loss_function(outputs['logits'], targets)
        model.backward(loss)  # (restored)
        return loss

    for idx, (inputs, targets) in enumerate(data_loader, start=1):
        inputs = inputs.to(torch.device('cuda'))  # (restored)
        # (restored) drop the trailing label dimension expected by the loss
        targets = torch.squeeze(targets.to(torch.device('cuda')), -1)
        optimizer.zero_grad()  # (restored)
        loss_mb = train_step(model, inputs, targets)
        loss = loss_mb.reduce_mean()
        optimizer.step()  # (restored)
        if torch.distributed.get_rank() == 0:
            batch_time = time.perf_counter() - t0
            print(f'step: {idx}: step time is {batch_time}')
            if idx > 1:  # skip first step
                summ += batch_time
                count += 1
        t0 = time.perf_counter()  # (restored) restart the timer for the next step

    if torch.distributed.get_rank() == 0:
        print(f'average step time: {summ/count}')

For more details on the SMP API please see the SageMaker documentation and examples.
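The figure of roughly 630 million parameters quoted earlier can be sanity-checked from the values passed to ViTConfig in the training script. The sketch below is a back-of-the-envelope count, not a call into the transformers library: the function name vit_param_estimate is hypothetical, and we assume a standard ViT encoder layout (bias terms on every linear layer, two LayerNorms per block, a final LayerNorm, a linear classifier and no pooling layer).

```python
def vit_param_estimate(hidden: int, layers: int, intermediate: int,
                       image_size: int, patch_size: int, num_labels: int,
                       channels: int = 3) -> int:
    """Estimate the parameter count of a ViT classifier (assumed layout)."""
    per_layer = (
        4 * (hidden * hidden + hidden)            # query, key, value, attention output
        + (hidden * intermediate + intermediate)  # MLP up-projection
        + (intermediate * hidden + hidden)        # MLP down-projection
        + 2 * (2 * hidden)                        # two LayerNorms (weight + bias)
    )
    num_patches = (image_size // patch_size) ** 2
    embeddings = (
        channels * patch_size * patch_size * hidden + hidden  # patch projection
        + hidden                                              # class token
        + (num_patches + 1) * hidden                          # position embeddings
    )
    head = 2 * hidden + (hidden * num_labels + num_labels)    # final LayerNorm + classifier
    return layers * per_layer + embeddings + head


if __name__ == "__main__":
    total = vit_param_estimate(hidden=1024, layers=50, intermediate=4 * 1024,
                               image_size=256, patch_size=16, num_labels=1000)
    print(f"estimated parameters: {total / 1e6:.1f}M")
```

Almost all of the parameters sit in the 50 encoder layers (about 12 * hidden_size^2 each), which is why the embedding and classifier details barely move the total.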


In the table below we compare the step time results of the standard ZeRO based parameter sharding algorithm (sharded_data_parallel_degree = 32) to the results of ZeRO-2D (MiCS) with the partition group set to 4 (sharded_data_parallel_degree = 4).

Average Step Time (lower is better) — (by Author)

Once again, we can see that the algorithm that takes the details of the environment into account outperforms the base algorithm, in this case by roughly 14%. As before, we caution against drawing any conclusions regarding your own model, as the relative performance can vary based on the model details and environment setup.


Multi-instance training can be quite expensive and any opportunity to increase your training speed can have a meaningful impact on your overall cost. In this post we have shown how distributed training algorithms that are tailored to the topology of the underlying training environment can boost performance. However, the degree to which such algorithms will help you, and whether adopting them is the right choice for you, will greatly depend on the details of your project and training environment. In the examples we shared we tried to demonstrate how to program your code so that you can easily toggle between different choices. It’s good to have options.

