    Train Your Large Model on Multiple GPUs with Tensor Parallelism

    By Awais | January 1, 2026
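    The listing below is one self-contained training script. It reads LOCAL_RANK from the environment and initializes an NCCL process group, so it is meant to be launched with torchrun on a multi-GPU host, for example torchrun --nproc_per_node=4 train_tp.py (the script name train_tp.py is just a placeholder; the tokenizer file bpe_50K.json referenced below is assumed to exist alongside it).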

    import dataclasses
    import datetime
    import os

    import datasets
    import tokenizers
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim.lr_scheduler as lr_scheduler
    import tqdm
    from torch import Tensor
    from torch.distributed.checkpoint import load, save
    from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
    from torch.distributed.fsdp import FSDPModule, fully_shard
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        PrepareModuleInput,
        RowwiseParallel,
        SequenceParallel,
        loss_parallel,
        parallelize_module,
    )
    from torch.utils.data.distributed import DistributedSampler

    # Set default dtype to bfloat16
    torch.set_default_dtype(torch.bfloat16)
    print("NCCL version:", torch.cuda.nccl.version())

     

    # Build the model
    @dataclasses.dataclass
    class LlamaConfig:
        """Define Llama model hyperparameters."""
        vocab_size: int = 50000  # Size of the tokenizer vocabulary
        max_position_embeddings: int = 2048  # Maximum sequence length
        hidden_size: int = 768  # Dimension of hidden layers
        intermediate_size: int = 4 * 768  # Dimension of the MLP's hidden layer
        num_hidden_layers: int = 12  # Number of transformer layers
        num_attention_heads: int = 12  # Number of attention heads
        num_key_value_heads: int = 3  # Number of key-value heads for GQA

     

     

    class RotaryPositionEncoding(nn.Module):
        """Rotary position encoding (RoPE)."""

        def __init__(self, dim: int, max_position_embeddings: int) -> None:
            """Initialize the RotaryPositionEncoding module.

            Args:
                dim: The hidden dimension of the input tensor to which RoPE is applied
                max_position_embeddings: The maximum sequence length of the input tensor
            """
            super().__init__()
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            # compute a matrix of n * theta_i
            N = 10_000.0
            inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
            inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
            position = torch.arange(max_position_embeddings)
            sinusoid_inp = torch.outer(position, inv_freq)
            # save cosine and sine matrices as buffers, not parameters
            self.register_buffer("cos", sinusoid_inp.cos())
            self.register_buffer("sin", sinusoid_inp.sin())

        def forward(self, x: Tensor) -> Tensor:
            """Apply RoPE to tensor x.

            Args:
                x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

            Returns:
                Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
            """
            batch_size, seq_len, num_heads, head_dim = x.shape
            device = x.device
            dtype = x.dtype
            # reshape the cosine and sine matrices to 4D tensors with the same dtype as x
            cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
            sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
            # apply RoPE to x ("rotate half" formulation)
            x1, x2 = x.chunk(2, dim=-1)
            rotated = torch.cat((-x2, x1), dim=-1)
            output = (x * cos) + (rotated * sin)
            return output
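    As a quick, standalone illustration (not part of the original script), RoPE leaves the tensor shape unchanged; a sketch like the following, using the class above with arbitrary toy sizes, checks that contract:

    # Illustrative sketch only: verify the RoPE shape contract on random data
    rope = RotaryPositionEncoding(dim=64, max_position_embeddings=128)
    x = torch.randn(2, 16, 12, 64)  # (batch, seq, heads, head_dim)
    print(rope(x).shape)  # torch.Size([2, 16, 12, 64])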

     

     

    class LlamaAttention(nn.Module):
        """Grouped-query attention with rotary embeddings."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.hidden_size = config.hidden_size
            self.num_heads = config.num_attention_heads
            self.head_dim = self.hidden_size // self.num_heads
            self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q

            # hidden_size must be divisible by num_heads
            assert (self.head_dim * self.num_heads) == self.hidden_size

            # Linear layers for Q, K, V projections
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
            self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
            self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
            self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
            bs, seq_len, dim = hidden_states.size()

            # Project inputs to Q, K, V
            query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
            key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
            value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

            # Apply rotary position embeddings
            query_states = rope(query_states)
            key_states = rope(key_states)

            # Transpose tensors from BSHD to BHSD layout for scaled_dot_product_attention
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            # Use PyTorch's optimized attention implementation.
            # Setting is_causal=True is incompatible with passing an explicit attention mask.
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attn_mask,
                dropout_p=0.0,
                enable_gqa=True,
            )

            # Transpose output from BHSD back to BSHD, reshape to 3D, then project the output
            attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
            attn_output = self.o_proj(attn_output)
            return attn_output

     

     

    class LlamaMLP(nn.Module):
        """Feed-forward network with SwiGLU activation."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            # Two parallel projections for SwiGLU
            self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
            self.act_fn = F.silu  # SwiGLU activation function
            # Project back to hidden size
            self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        def forward(self, x: Tensor) -> Tensor:
            # SwiGLU activation: multiply gate and up-projected inputs
            gate = self.act_fn(self.gate_proj(x))
            up = self.up_proj(x)
            return self.down_proj(gate * up)

     

     

    class LlamaDecoderLayer(nn.Module):
        """Single transformer layer for a Llama model."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
            self.self_attn = LlamaAttention(config)
            self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
            self.mlp = LlamaMLP(config)

        def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
            # First residual block: self-attention
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
            attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
            hidden_states = attn_outputs + residual

            # Second residual block: MLP
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states) + residual
            return hidden_states

     

     

    class LlamaModel(nn.Module):
        """The full Llama model without any pretraining heads."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.rotary_emb = RotaryPositionEncoding(
                config.hidden_size // config.num_attention_heads,
                config.max_position_embeddings,
            )

            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
            self.layers = nn.ModuleList([
                LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
            ])
            self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

        def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
            # Convert input token IDs to embeddings
            hidden_states = self.embed_tokens(input_ids)
            # Process through all transformer layers, then the final norm layer
            for layer in self.layers:
                hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
            hidden_states = self.norm(hidden_states)
            # Return the final hidden states
            return hidden_states

     

     

    class LlamaForPretraining(nn.Module):
        """Llama base model with a language-modeling head for pretraining."""

        def __init__(self, config: LlamaConfig) -> None:
            super().__init__()
            self.base_model = LlamaModel(config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
            hidden_states = self.base_model(input_ids, attn_mask)
            return self.lm_head(hidden_states)
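    For a quick sanity check before any parallelism is applied, the model can be instantiated on CPU and its parameters counted. This is an illustrative sketch, not part of the original script, and the parameter count noted in the comment is a rough estimate for the default LlamaConfig:

    # Illustrative sketch only: instantiate the model and count parameters
    cfg = LlamaConfig()
    toy_model = LlamaForPretraining(cfg)
    n_params = sum(p.numel() for p in toy_model.parameters())
    print(f"{n_params:,} parameters")  # roughly 180M with the default config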

     

     

    def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
        """Create a causal mask for self-attention.

        Args:
            batch: Batch of sequences, shape (batch_size, seq_len)
            dtype: Data type of the mask

        Returns:
            Causal mask of shape (seq_len, seq_len)
        """
        batch_size, seq_len = batch.shape
        mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
                    .triu(diagonal=1)
        return mask

     

     

    def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
        """Create a padding mask for a batch of sequences for self-attention.

        Args:
            batch: Batch of sequences, shape (batch_size, seq_len)
            padding_token_id: ID of the padding token
            dtype: Data type of the mask

        Returns:
            Padding mask of shape (batch_size, 1, seq_len, seq_len)
        """
        padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
                      .masked_fill(batch == padding_token_id, float("-inf"))
        mask = padded[:, :, None] + padded[:, None, :]
        return mask[:, None, :, :]
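    The two masks are meant to be added together, as the training loop further down does: entries above the diagonal (future tokens) and entries in rows or columns aligned with padding tokens all become -inf, so scaled_dot_product_attention gives them zero weight. A small illustrative check (not in the original) with a hypothetical padding id of 0:

    # Illustrative sketch only: combine causal and padding masks for a toy batch
    toy_batch = torch.tensor([[5, 7, 9, 0]])  # last position is padding (pad id 0 here)
    toy_mask = create_causal_mask(toy_batch) + create_padding_mask(toy_batch, padding_token_id=0)
    print(toy_mask.shape)                    # torch.Size([1, 1, 4, 4])
    print(torch.isinf(toy_mask).squeeze())   # True above the diagonal and in the last row/column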

     

     

    # Dataset that yields padded sequences of fixed length
    class PretrainingDataset(torch.utils.data.Dataset):
        def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer,
                     seq_length: int):
            self.dataset = dataset
            self.tokenizer = tokenizer
            self.seq_length = seq_length
            self.bot = tokenizer.token_to_id("[BOT]")
            self.eot = tokenizer.token_to_id("[EOT]")
            self.pad = tokenizer.token_to_id("[PAD]")

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
            """Get a sequence of token IDs from the dataset. [BOT] and [EOT] tokens
            are added, and the sequence is clipped and padded to the target length.
            """
            seq = self.dataset[index]["text"]
            tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
            # pad to target sequence length
            toklen = len(tokens)
            if toklen < self.seq_length + 1:
                pad_length = self.seq_length + 1 - toklen
                tokens += [self.pad] * pad_length
            # return the input sequence and its next-token targets
            x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
            y = torch.tensor(tokens[1:self.seq_length + 1], dtype=torch.int64)
            return x, y

     

     

    def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
        dist.barrier()
        load(
            {"model": model, "optimizer": optimizer},
            checkpoint_id="checkpoint-dist",
            planner=DefaultLoadPlanner(allow_partial_load=True),  # ignore keys for RoPE buffers
        )
        scheduler.load_state_dict(
            torch.load("checkpoint-dist/lrscheduler.pt", map_location=device),
        )
        dist.barrier()


    def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
        dist.barrier()
        save(
            {"model": model, "optimizer": optimizer},
            checkpoint_id="checkpoint-dist",
        )
        if dist.get_rank() == 0:
            torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
        dist.barrier()

     

     

    # Load the tokenizer and dataset
    tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
    dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

    # Initialize the distributed environment
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=60))
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")

    # Initialize the mesh for tensor parallelism
    n_tensor_parallel = 2
    assert world_size % n_tensor_parallel == 0, "Expect world size to be divisible by the number of tensor-parallel GPUs"
    mesh = dist.device_mesh.init_device_mesh(
        "cuda",
        (world_size // n_tensor_parallel, n_tensor_parallel),
        mesh_dim_names=("dp", "tp"),
    )
    print(f"({rank}) Mesh: {mesh}, DP size: {mesh['dp'].size()}, TP size: {mesh['tp'].size()}, DP local rank: {mesh['dp'].get_local_rank()}, TP local rank: {mesh['tp'].get_local_rank()}")

     

    # Create the pretraining model on the meta device, on all ranks
    with torch.device("meta"):
        model_config = LlamaConfig()
        model = LlamaForPretraining(model_config)

    # Set up tensor parallelism on each transformer block in the base model
    tp_plan = {
        "input_layernorm": SequenceParallel(),
        "self_attn": PrepareModuleInput(
            input_layouts=Shard(dim=1),  # only one positional arg will be used
            desired_input_layouts=Replicate(),
        ),
        # Q/K projection outputs will be used with RoPE, so they need to be replicated
        # Q/K/V outputs will be used with GQA, so they also need to be replicated
        "self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
        "self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
        "self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
        "self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
        "post_attention_layernorm": SequenceParallel(),
        "mlp": PrepareModuleInput(
            input_layouts=Shard(dim=1),
            desired_input_layouts=Replicate(),
        ),
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
    }
    for layer in model.base_model.layers:
        parallelize_module(layer, mesh["tp"], tp_plan)

    # Set up tensor parallelism on the embedding and output norm layers in the base model
    # and the prediction head in the top-level model
    tp_plan = {
        "base_model.embed_tokens": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "base_model.norm": SequenceParallel(),
        "lm_head": ColwiseParallel(
            input_layouts=Shard(1),
            # output_layouts=Replicate(),  # only if not using loss parallel
            use_local_output=False,  # keep DTensor output for loss parallel
        ),
    }
    parallelize_module(model, mesh["tp"], tp_plan)

    # Convert the tensor-parallelized model to FSDP2; every component must be sharded
    # across the "dp" dimension of the mesh
    for layer in model.base_model.layers:
        fully_shard(layer, mesh=mesh["dp"])
    fully_shard(model.base_model, mesh=mesh["dp"])
    fully_shard(model, mesh=mesh["dp"])

    def reset_all_weights(model: nn.Module) -> None:
        """Initialize all weights of the model after moving it off the meta device."""
        @torch.no_grad()
        def weight_reset(m: nn.Module):
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                m.reset_parameters()

        # Applies fn recursively to the model itself and all of model.children()
        model.apply(fn=weight_reset)

    torch.manual_seed(42)
    model.to_empty(device=device)
    reset_all_weights(model)
    assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"
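    At this point every weight should be a DTensor carrying its placements on the 2D mesh (tensor-parallel sharding from the plans above plus data-parallel sharding from fully_shard). A hedged sanity check (not in the original script) is to print a few parameters and their placements; the getattr guard is only there in case a parameter is not a DTensor:

    # Illustrative sketch only: inspect how parameters are placed on the 2D mesh
    for name, param in list(model.named_parameters())[:3]:
        print(name, tuple(param.shape), getattr(param, "placements", None))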

     

    # Training parameters
    epochs = 3
    learning_rate = 1e-3
    batch_size = 64 // mesh["dp"].size()
    seq_length = 512
    num_warmup_steps = 1000
    PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
    model.train()

    # DataLoader, optimizer, scheduler, and loss function
    # A sampler is needed to shard the dataset across the data-parallel group
    dataset = PretrainingDataset(dataset, tokenizer, seq_length)
    sampler = DistributedSampler(
        dataset, shuffle=False, drop_last=True,
        num_replicas=mesh["dp"].size(),
        rank=mesh["dp"].get_local_rank(),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,  # optional
        shuffle=False,
        num_workers=2,
        prefetch_factor=2,
    )
    num_training_steps = len(dataloader) * epochs

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
    )
    warmup_scheduler = lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
    )
    cosine_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_training_steps - num_warmup_steps,
        eta_min=0,
    )
    scheduler = lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[num_warmup_steps],
    )
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)
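    The learning-rate schedule is a linear warmup for the first num_warmup_steps steps, followed by cosine decay, stitched together by SequentialLR. A small standalone sketch (illustrative only, with a dummy optimizer and shortened step counts) traces the same shape:

    # Illustrative sketch only: trace a warmup + cosine schedule on a dummy optimizer
    dummy_opt = torch.optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-3)
    warmup = lr_scheduler.LinearLR(dummy_opt, start_factor=0.1, end_factor=1.0, total_iters=10)
    cosine = lr_scheduler.CosineAnnealingLR(dummy_opt, T_max=90, eta_min=0)
    sched = lr_scheduler.SequentialLR(dummy_opt, schedulers=[warmup, cosine], milestones=[10])
    lrs = []
    for _ in range(100):
        dummy_opt.step()
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    print(lrs[0], lrs[9], lrs[-1])  # ramps from ~1e-4 up to 1e-3, then decays toward 0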

     

    # If the checkpoint-dist directory exists, load the checkpoint into the model and optimizer
    if os.path.exists("checkpoint-dist"):
        load_checkpoint(model, optimizer, scheduler)

    # Start training
    print(f"({rank}) Starting training")
    for epoch in range(epochs):
        pbar = tqdm.tqdm(dataloader, desc=f"({rank}) Epoch {epoch+1}/{epochs}")
        for batch_id, batch in enumerate(pbar):
            if batch_id % 1000 == 0:
                save_checkpoint(model, optimizer, scheduler)
            # Explicit prefetching before sending any data to the model
            model.unshard()
            # Get batched data, move it from CPU to GPU
            input_ids, target_ids = batch
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            # Create the attention mask: causal mask + padding mask
            attn_mask = create_causal_mask(input_ids) + \
                        create_padding_mask(input_ids, PAD_TOKEN_ID)
            # Extract output from the model
            logits = model(input_ids, attn_mask)
            optimizer.zero_grad()
            with loss_parallel():
                # Compute loss: cross-entropy between logits and targets, ignoring padding tokens
                loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
                # Backward with the loss on a DTensor
                loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            pbar.set_postfix(loss=loss.item())
        pbar.close()

    # Save the model
    save_checkpoint(model, optimizer, scheduler)

    # Clean up the distributed environment
    dist.destroy_process_group()
