    AI in Multiple GPUs: ZeRO & FSDP

    By Awais | March 5, 2026

    This post is part of a series about distributed AI across multiple GPUs.

    Introduction

    In the previous post, we saw how Distributed Data Parallelism (DDP) speeds up training by splitting batches across GPUs. DDP solves the throughput problem, but it introduces a new challenge: memory redundancy.

    In vanilla DDP, every GPU holds a complete copy of the model parameters, gradients, and optimizer states. For large models like GPT-3 (175B parameters), this redundancy becomes a big waste of precious VRAM.

    Image by author: Model, gradients and optimizer are redundant across GPUs in regular DDP

    ZeRO (Zero Redundancy Optimizer) solves this. There are three levels:

    • ZeRO-1 partitions only optimizer states
    • ZeRO-2 partitions optimizer states + gradients
    • ZeRO-3 partitions optimizer states + gradients + model parameters

    ZeRO isn’t a parallelism technique because all GPUs still run the same forward and backward passes. It’s a memory optimization strategy that eliminates redundancy across GPUs, letting you train larger models on the same hardware.

    The Memory Problem in DDP

    Let’s break down what actually consumes memory during training. For a model with Ψ parameters:

    • Model Parameters: Ψ values (the weights of your neural network)
    • Gradients: Ψ values (one gradient per parameter)
    • Optimizer States (Adam): 2Ψ values (first moment m and second moment v for each parameter)
    • Activations: Intermediate outputs stored during forward pass for use in backward pass

    The first three scale with model size and are redundant across GPUs in DDP. Activations scale with batch size, sequence length, and the model’s hidden size, and are unique per GPU since each GPU processes different data. ZeRO doesn’t touch activation memory.

    Let’s calculate the memory usage for a 7B-parameter model using Adam and FP32:

    • Parameters: 7 billion * 4 bytes = 28 GB
    • Gradients: 7 billion * 4 bytes = 28 GB
    • Optimizer states: 7 billion * 2 * 4 bytes = 56 GB
    • Memory per GPU in DDP: 28 + 28 + 56 = 112 GB

    Activations add significant memory on top of this, but since they’re unique per GPU, ZeRO can’t partition them. Techniques like activation checkpointing can help: they discard some activations during the forward pass and recompute them as needed during the backward pass. But that’s outside the scope of this article.

    Let’s understand how ZeRO works by implementing it from the ground up, starting with ZeRO-1 and working our way to ZeRO-3.

    ZeRO-1: Optimizer State Partitioning

    In ZeRO-1, only the optimizer states are partitioned. Each GPU:

    • Still holds the full model parameters and gradients
    • Stores only 1/N of the optimizer states (N = number of GPUs)
    • Updates only the corresponding 1/N of the parameters

    This is the sequence of actions taken during training:

    1. Forward pass: each GPU processes its own micro-batch
    2. Backward pass: compute gradients
    3. all-reduce gradients: every GPU gets all the gradients
    4. Optimizer step: Each GPU updates its parameter partition
    5. all-gather parameters: sync the updated model across GPUs
    Image by author: Zero 1 animation

    Here’s a simplified implementation:

    import torch
    import torch.distributed as dist
    
    
    class ZeRO_1:
        def __init__(self, model, optimizer_cls):
            self.model = model
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
    
            self.param_shards = list()  # each rank holds only its shard of the optimizer states
            self.param_metadata = list()  # metadata to reconstruct shards
    
            for param in self.model.parameters():
                original_shape = param.data.shape
                flat = param.data.view(-1)
                numel = flat.numel()
    
                remainder = numel % self.world_size
                pad_size = (self.world_size - remainder) % self.world_size
                padded_numel = numel + pad_size
                shard_size = padded_numel // self.world_size
    
                shard_start = self.rank * shard_size
                shard_end = shard_start + shard_size
    
                self.param_metadata.append(
                    {
                        "original_shape": original_shape,
                        "numel": numel,
                        "padded_numel": padded_numel,
                        "shard_size": shard_size,
                        "shard_start": shard_start,
                        "shard_end": shard_end,
                    }
                )
    
                if pad_size > 0:
                    flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
                else:
                    flat_padded = flat
    
                shard = flat_padded[shard_start:shard_end].clone()
                shard.requires_grad_(True)
                self.param_shards.append(shard)
    
            self.optimizer = optimizer_cls(self.param_shards)
    
        def training_step(self, inputs, targets, loss_fn):
            output = self.model(inputs) # forward
            loss = loss_fn(output, targets) # compute loss
            loss.backward() # backward
    
            self._sync_gradients()  # all-reduce gradients across GPUs
            self.optimizer.step() # update local shard of parameters
            self._sync_params() # all gather model params
    
            # clear gradients for the next step
            for param in self.model.parameters():
                param.grad = None
    
        def _sync_gradients(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]

                # average the full gradient across all ranks (same as DDP)
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= self.world_size

                # keep only this rank's slice, zero-padded so it always matches the shard size
                flat_grad = param.grad.view(-1)
                pad_size = meta["padded_numel"] - meta["numel"]
                if pad_size > 0:
                    flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
                self.param_shards[idx].grad = flat_grad[meta["shard_start"]:meta["shard_end"]]
    
        def _sync_params(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]
    
                full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
                dist.all_gather_into_tensor(
                    output_tensor=full_flat,
                    input_tensor=self.param_shards[idx].data,
                )
                
                reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
                param.data.copy_(reconstructed)
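
    Here’s a hypothetical usage sketch for this class (the toy model, dimensions, and hyperparameters are made up; it assumes a launch like torchrun --nproc_per_node=8 train.py so the process group can be initialized):

    import torch.nn as nn


    def main():
        dist.init_process_group("nccl")
        torch.cuda.set_device(dist.get_rank())

        model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
        zero1 = ZeRO_1(model, lambda shards: torch.optim.Adam(shards, lr=1e-4))
        loss_fn = nn.MSELoss()

        for step in range(100):
            # in a real job, each rank would load its own micro-batch from a dataloader
            inputs = torch.randn(32, 1024, device="cuda")
            targets = torch.randn(32, 1024, device="cuda")
            zero1.training_step(inputs, targets, loss_fn)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()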

    Notice that the all-reduce syncs all gradients even though each GPU only uses the gradients for its own parameter partition, so ZeRO-1 overcommunicates. ZeRO-2 fixes this by sharding the gradients too.

    In practice, you’d never use ZeRO-1, since ZeRO-2 gives you better memory savings at essentially the same communication cost. But it’s still worth covering for learning purposes.

    Memory with ZeRO-1, 7B model, 8 GPUs:

    • Parameters: 28 GB (fully replicated)
    • Gradients: 28 GB (fully replicated)
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 63 GB (down from 112 GB)

    ZeRO-2: Gradient Partitioning

    ZeRO-2 partitions both optimizer states and gradients. Since each GPU only updates a partition of parameters, it only needs the corresponding gradients.

    ZeRO-1 uses all-reduce, which gives every GPU all the gradients. ZeRO-2 replaces this with reduce-scatter, so each GPU receives only the gradients it actually needs. This saves both memory and communication bandwidth.

    Training steps:

    1. Forward pass: each GPU processes its own micro-batch
    2. Backward pass: compute gradients
    3. reduce-scatter gradients: each GPU gets only its partition
    4. Optimizer step: Each GPU updates its parameter partition
    5. all-gather parameters: sync the updated model across GPUs
    Image by author: Zero 2 animation

    The implementation is very similar to ZeRO-1, but the gradient synchronization step uses reduce-scatter instead of all-reduce.
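
    Here’s a minimal sketch of what that class might look like. It reuses the ZeRO_1 sharding metadata; zero-padding the gradient so it splits evenly across ranks is an assumption of this sketch:

    class ZeRO_2(ZeRO_1):
        """
        ZeRO-2: shard optimizer states (stage 1) + gradients (stage 2).
        Only the gradient synchronization changes: reduce-scatter instead of all-reduce.
        """

        def _sync_gradients(self):
            for idx, param in enumerate(self.model.parameters()):
                meta = self.param_metadata[idx]

                # flatten and zero-pad the local gradient so it splits evenly across ranks
                flat_grad = param.grad.view(-1)
                pad_size = meta["padded_numel"] - meta["numel"]
                if pad_size > 0:
                    flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])

                # each rank receives only the (summed) slice of the gradient it will actually use
                grad_shard = torch.empty(meta["shard_size"], device=param.device, dtype=param.dtype)
                dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)
                grad_shard /= self.world_size

                self.param_shards[idx].grad = grad_shard
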
    But wait, if every GPU computes all gradients during backprop, how does this actually save VRAM? Here’s how:

    • As parameter gradients are computed layer by layer during backprop, they’re immediately reduce-scattered and the full local copy is freed (our simplified implementation doesn’t do this; see the sketch after this list).
    • During backprop, you only need the gradient of a layer’s output activations to compute that layer’s parameter gradients, i.e., you don’t need to keep the entire gradient graph around.
    • That way you can free the memory for full gradients as you move backwards, keeping only the assigned partition on each GPU.
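
    As an illustration of the first bullet, here’s a hypothetical sketch of how each gradient could be reduce-scattered and freed as soon as backprop produces it, using PyTorch’s post-accumulate-grad hooks (this helper is not part of the simplified classes above):

    def install_grad_shard_hooks(zero):
        """Reduce-scatter and free each parameter's full gradient right after it is accumulated."""
        for idx, param in enumerate(zero.model.parameters()):
            def hook(p, idx=idx):
                meta = zero.param_metadata[idx]
                flat = p.grad.view(-1)
                pad_size = meta["padded_numel"] - meta["numel"]
                if pad_size > 0:
                    flat = torch.cat([flat, flat.new_zeros(pad_size)])
                shard = torch.empty(meta["shard_size"], device=p.device, dtype=p.dtype)
                dist.reduce_scatter_tensor(shard, flat, op=dist.ReduceOp.SUM)
                zero.param_shards[idx].grad = shard / zero.world_size
                p.grad = None  # drop the full gradient immediately instead of keeping it around
            param.register_post_accumulate_grad_hook(hook)

    With hooks like these installed, the explicit gradient sync after backward would no longer be needed.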

    Memory with ZeRO-2, 7B model, 8 GPUs:

    • Parameters: 28 GB (fully replicated)
    • Gradients: 28 GB / 8 = 3.5 GB
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 38.5 GB (down from 112 GB)

    ZeRO-3: Parameter Partitioning

    ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the entire model state.

    During forward and backward passes, each layer needs its full parameters, but each GPU only stores a fraction. So we all-gather parameters just-in-time, use them, then discard immediately after.

    Training steps:

    • Forward pass:
      • All-gather the layer’s parameters from all GPUs
      • Run the layer’s forward pass using previous layer’s activations as input
      • Discard the gathered parameters (keep only the local partition)
      • Repeat these steps until all layers are done
    • Backward pass (per layer, in reverse):
      • All-gather the layer’s parameters again
      • Compute gradients for current layer using activation gradients from next layer
      • Reduce-scatter the gradients (each GPU keeps its shard)
      • Discard the gathered parameters (keep only the local partition)
      • Repeat these steps until all layers are done
    • Each GPU runs an optimizer step on its partition
    • No final all-gather is needed, since parameters are gathered layer-by-layer again during the next forward pass
    Image by author: Zero 3 animation

    Here’s a simplified implementation:

    class ZeRO_3(ZeRO_2):
        """
        ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).
    
        At rest, each rank holds only param_shards[idx] — a 1/world_size slice
        of each parameter. Full parameters are materialised temporarily during
        the forward and backward passes via all_gather, then immediately freed.
        """
    
        def __init__(self, model, optimizer_cls):
            self.model = model
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
    
            self.param_metadata = []
            shard_list = []
    
            self._param_to_idx = {}
    
            for idx, param in enumerate(self.model.parameters()):
                original_shape = param.data.shape
                flat = param.data.view(-1)
                numel = flat.numel()
    
                remainder = numel % self.world_size
                pad_size = (self.world_size - remainder) % self.world_size
                padded_numel = numel + pad_size
                shard_size = padded_numel // self.world_size
    
                shard_start = self.rank * shard_size
                shard_end = shard_start + shard_size
    
                self.param_metadata.append(
                    {
                        "original_shape": original_shape,
                        "numel": numel,
                        "padded_numel": padded_numel,
                        "shard_size": shard_size,
                        "shard_start": shard_start,
                        "shard_end": shard_end,
                    }
                )
    
                if pad_size > 0:
                    flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
                else:
                    flat_padded = flat
    
                shard = flat_padded[shard_start:shard_end].clone()
                shard_list.append(shard)
    
                # Replace the full tensor with only this rank's shard.
                # The model's param.data now points to a tiny slice; the full
                # weight will be reconstructed on demand during forward/backward.
                param.data = shard.detach()
                self._param_to_idx[param] = idx
    
            self.param_shards = [s.requires_grad_(True) for s in shard_list]
            self.optimizer = optimizer_cls(self.param_shards)
    
            self._register_hooks()
    
        def _gather_param(self, idx, device, dtype):
            """All-gather the full parameter tensor for parameter `idx`."""
            meta = self.param_metadata[idx]
            full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )
            return full_flat[: meta["numel"]].view(meta["original_shape"])
    
        def _gather_module_params(self, module):
            """Gather full params for every parameter that belongs to this module only (not children)."""
            for param in module.parameters(recurse=False):
                idx = self._param_to_idx[param]
                param.data = self._gather_param(idx, param.device, param.dtype)
    
        def _reshard_module_params(self, module):
            """Reshard params back to local shard for every direct param of this module."""
            for param in module.parameters(recurse=False):
                idx = self._param_to_idx[param]
                param.data = self.param_shards[idx].data
    
        def _register_hooks(self):
            self._hooks = []
            for module in self.model.modules():
                # Skip container modules that have no direct parameters
                if not list(module.parameters(recurse=False)):
                    continue
    
                # Forward: gather -> run -> reshard
                h1 = module.register_forward_pre_hook(
                    lambda mod, _inputs: self._gather_module_params(mod)
                )
                h2 = module.register_forward_hook(
                    lambda mod, _inputs, _output: self._reshard_module_params(mod)
                )
    
                # Backward: gather before grad computation → reshard after
                h3 = module.register_full_backward_pre_hook(
                    lambda mod, _grad_output: self._gather_module_params(mod)
                )
                h4 = module.register_full_backward_hook(
                    lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
                )
    
                self._hooks.extend([h1, h2, h3, h4])
    
        def training_step(self, inputs, targets, loss_fn):
            # Hooks handle all gather/reshard around each module automatically
            output = self.model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
    
            self._sync_gradients()
    
            # Each rank updates only its local shard
            self.optimizer.step()
    
            for param in self.model.parameters():
                param.grad = None

    Each layer’s parameters are gathered right before they’re needed and freed immediately after. This keeps peak memory minimal at the cost of more communication. In practice, implementations overlap the all-gather for layer N+1 with the forward of layer N to hide latency.
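
    As a standalone illustration of that overlap (not wired into the ZeRO_3 class above, and assuming one sharded weight per layer plus a list of layer_fns callables that take the full weight and the input):

    def gather_async(shard, meta):
        # kick off a non-blocking all-gather of one sharded weight
        full = torch.empty(meta["padded_numel"], device=shard.device, dtype=shard.dtype)
        work = dist.all_gather_into_tensor(full, shard, async_op=True)
        return work, full


    def sharded_forward_with_prefetch(weight_shards, metas, layer_fns, x):
        work, full = gather_async(weight_shards[0], metas[0])
        for i, layer_fn in enumerate(layer_fns):
            work.wait()  # layer i's full weight has arrived
            weight = full[: metas[i]["numel"]].view(metas[i]["original_shape"])
            if i + 1 < len(layer_fns):  # overlap: prefetch layer i+1 while layer i computes
                work, full = gather_async(weight_shards[i + 1], metas[i + 1])
            x = layer_fn(weight, x)
            del weight  # drop the gathered copy, keeping only the local shard
        return x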

    Memory with ZeRO-3, 7B model, 8 GPUs:

    • Parameters: 28 GB / 8 = 3.5 GB
    • Gradients: 28 GB / 8 = 3.5 GB
    • Optimizer states: 56 GB / 8 = 7 GB
    • Total per GPU: 14 GB (down from 112 GB)

    That’s an 8x reduction in memory usage, which is exactly what we’d expect from partitioning across 8 GPUs.
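
    Since these are all back-of-the-envelope numbers, here’s a small helper (hypothetical, assuming Adam, FP32, and ignoring activations) that reproduces the per-stage arithmetic:

    def per_gpu_memory_gb(n_params, n_gpus, zero_stage=0, bytes_per_value=4):
        """Rough per-GPU memory (GB) for parameters + gradients + Adam optimizer states."""
        params = n_params * bytes_per_value
        grads = n_params * bytes_per_value
        optim = n_params * 2 * bytes_per_value  # Adam keeps a first and a second moment
        if zero_stage >= 1:
            optim /= n_gpus   # ZeRO-1: shard optimizer states
        if zero_stage >= 2:
            grads /= n_gpus   # ZeRO-2: also shard gradients
        if zero_stage >= 3:
            params /= n_gpus  # ZeRO-3: also shard parameters
        return (params + grads + optim) / 1e9

    print(per_gpu_memory_gb(7e9, 8, zero_stage=0))  # 112.0
    print(per_gpu_memory_gb(7e9, 8, zero_stage=1))  # 63.0
    print(per_gpu_memory_gb(7e9, 8, zero_stage=2))  # 38.5
    print(per_gpu_memory_gb(7e9, 8, zero_stage=3))  # 14.0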

    Using ZeRO in PyTorch

    PyTorch ships with two implementations of ZeRO-3: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). Always use FSDP2.

    FSDP (Fully Sharded Data Parallel) handles parameter gathering, gradient scattering, communication overlap, and memory management automatically:

    from torch.distributed.fsdp import fully_shard

    model = Transformer()
    for layer in model.layers:
        fully_shard(layer)  # shard each transformer block so it can be gathered and freed independently
    fully_shard(model)      # shard the parameters left at the root (e.g. embeddings or the output head)

    You have to apply fully_shard layer-by-layer and then wrap the whole model.
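
    A slightly fuller, end-to-end sketch (the toy model and hyperparameters here are made up; it assumes a recent PyTorch where fully_shard is available, launched with torchrun):

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard


    class TinyTransformer(nn.Module):
        def __init__(self, dim=512, n_layers=4):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True) for _ in range(n_layers)
            )
            self.head = nn.Linear(dim, dim)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.head(x)


    def main():
        dist.init_process_group("nccl")
        torch.cuda.set_device(dist.get_rank())

        model = TinyTransformer().cuda()
        for layer in model.layers:
            fully_shard(layer)  # shard each block so it can be gathered and freed independently
        fully_shard(model)      # shard whatever is left at the root (here, the head)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        for step in range(10):
            x = torch.randn(8, 128, 512, device="cuda")  # each rank gets its own micro-batch
            loss = model(x).pow(2).mean()                # dummy loss just to drive backward
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()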

    Conclusion

    ZeRO trades extra communication for lower memory, so it’s not a free lunch. In general it’s not worth it for smaller models (e.g. BERT), but it’s a game changer for larger models.

    Congratulations on making it to the end! In this post, you learned about:

    • The memory redundancy problem in standard DDP
    • How ZeRO partitions optimizer states, gradients, and parameters across GPUs
    • The three stages of ZeRO and their memory/communication trade-offs
    • How to use ZeRO-3 via PyTorch’s FSDP

    In the next article, we’ll explore Tensor Parallelism, a model parallelism technique that speeds up each layer’s computation by distributing its work across GPUs.

    References

    1. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Original Paper)
    2. PyTorch FSDP Tutorial
    3. FSDP API Reference
    4. The Ultra-Scale Playbook by Hugging Face