This post is a sequel to Optimizing Data Transfer in AI/ML Workloads, where we demonstrated the use of NVIDIA Nsight™ Systems (nsys) in studying and solving the common data-loading bottleneck, in which the GPU idles while it waits for input data from the CPU. In this post we focus our attention on data travelling in the opposite direction, from the GPU device to the CPU host. More specifically, we address AI/ML inference workloads in which the output returned by the model is relatively large. Common examples include: 1) running a scene segmentation (per-pixel labeling) model on batches of high-resolution images, and 2) capturing high-dimensional feature embeddings of input sequences using an encoder model (e.g., to create a vector database). Both examples involve executing a model on an input batch and then copying the output tensor from the GPU to the CPU for additional processing, storage, and/or over-the-network communication.
GPU-to-CPU memory copies of the model output typically receive much less attention in optimization tutorials than the CPU-to-GPU copies that feed the model (e.g., see here). But their potential impact on model efficiency and execution costs can be just as detrimental. Moreover, while optimizations to CPU-to-GPU data-loading are well documented and easy to implement, optimizing data copy in the opposite direction requires a bit more manual labor.
In this post we will apply the same strategy we used in our previous post: We will define a toy model and use nsys profiler to identify and solve performance bottlenecks. We will run our experiments on an Amazon EC2 g6e.2xlarge instance (with an NVIDIA L40S GPU) running an AWS Deep Learning (Ubuntu 24.04) AMI with PyTorch (2.8), nsys-cli profiler (version 2025.6.1), and the NVIDIA Tools Extension (NVTX) library.
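Before running the experiments, it can be helpful to verify the setup. The short snippet below is our own addition (not part of the scripts in this post); it simply confirms the PyTorch version and GPU visibility:
import torch
print(torch.__version__)              # expect a 2.8.x build
print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # expect an NVIDIA L40S on a g6e instance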
Disclaimers
The code we will share is intended for demonstrative purposes; please do not rely on its correctness or optimality. Please do not interpret our use of any library, tool, or platform as an endorsement of its use. The impact of the optimizations we will cover can vary greatly based on the details of the model and the runtime environment. Please be sure to assess their effect on your own use case before adopting them.
Many thanks to Yitzhak Levi and Gilad Wasserman for their contributions to this post.
A Toy PyTorch Model
We introduce a batched inference script that performs image segmentation on a synthetic dataset using a DeepLabV3 model with a ResNet-50 backbone. The model outputs are copied to the CPU for post processing and storage. We wrap the different portions of the inference step with color-coded nvtx annotations:
import time, torch, nvtx
from torch.utils.data import Dataset, DataLoader
from torch.cuda import profiler
from torchvision.models.segmentation import deeplabv3_resnet50

DEVICE = "cuda"
WARMUP_STEPS = 10
PROFILE_STEPS = 3
COOLDOWN_STEPS = 1
TOTAL_STEPS = WARMUP_STEPS + PROFILE_STEPS + COOLDOWN_STEPS
BATCH_SIZE = 64
TOTAL_SAMPLES = TOTAL_STEPS * BATCH_SIZE
IMG_SIZE = 512
N_CLASSES = 21
NUM_WORKERS = 8
ASYNC_DATALOAD = True

# A synthetic Dataset with random images
class FakeDataset(Dataset):
    def __len__(self):
        return TOTAL_SAMPLES

    def __getitem__(self, index):
        img = torch.randn((3, IMG_SIZE, IMG_SIZE))
        return img

# utility class for prefetching data to GPU
class DataPrefetcher:
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.next_batch = None
        self.preload()

    def preload(self):
        try:
            data = next(self.loader)
            with torch.cuda.stream(self.stream):
                next_data = data.to(DEVICE, non_blocking=ASYNC_DATALOAD)
                self.next_batch = next_data
        except StopIteration:  # no more batches
            self.next_batch = None

    def __iter__(self):
        return self

    def __next__(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        data = self.next_batch
        self.preload()
        return data

model = deeplabv3_resnet50(weights_backbone=None).to(DEVICE).eval()

data_loader = DataLoader(
    FakeDataset(),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=ASYNC_DATALOAD
)
data_iter = DataPrefetcher(data_loader)

def synchronize_all():
    torch.cuda.synchronize()

def to_cpu(output):
    return output.cpu()

def process_output(batch_id, logits):
    # do some post processing on output
    with open('/dev/null', 'wb') as f:
        f.write(logits.numpy().tobytes())

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            # start timing and profiling after the warm-up steps
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            # stop timing and profiling before the cool-down step
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu = to_cpu(output['out'])
            with nvtx.annotate("process output", color="cyan"):
                process_output(i, output_cpu)

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")

Note the inclusion of all of the CPU-to-GPU data-loading optimizations discussed in our previous post.
We run the following command to capture an nsys profile trace:
nsys profile \
    --capture-range=cudaProfilerApi \
    --trace=cuda,nvtx,osrt \
    --output=baseline \
    python batch_infer.py

This results in a baseline.nsys-rep trace file that we copy over to our development machine for analysis.
To measure the inference throughput, we increase the number of steps to 100. The average throughput of our baseline experiment is 0.45 steps-per-second. In the following sections we will use the nsys profile traces to incrementally improve this result.
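For reference, a minimal sketch of the longer measurement run, assuming the profiled-step count is the constant being increased:
PROFILE_STEPS = 100  # extend the measured region (assumed; adjust as needed)
TOTAL_STEPS = WARMUP_STEPS + PROFILE_STEPS + COOLDOWN_STEPS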
Baseline Performance Analysis
The image below shows the nsys profile trace of our baseline experiment:

In the GPU section we see the following recurring pattern:
- A block of kernel compute (in light blue) that runs for ~520 milliseconds.
- A small block of host-to-device memory copy (in green) that runs in parallel to the kernel compute. This concurrency was achieved using the optimizations discussed in our previous post.
- A block of device-to-host memory copy (in red) that runs for ~750 milliseconds.
- A long period (~940 milliseconds) of GPU idle time (white space) between every two steps.
Looking at the NVTX bar of the CPU section, we can see that the whitespace aligns perfectly with the “process output” block (in cyan). In our initial implementation, the model execution and the output storage function run sequentially in the same process. This leads to significant idle time on the GPU as the CPU waits for the storage function to return before feeding the GPU the next batch.
Optimization 1: Multi-Worker Output Processing
The first step we take is to run the output storage function in parallel worker processes. We took a similar step in our previous post when we moved the input batch preparation sequence to dedicated workers. However, whereas there we were able to automate multi-process data loading by simply setting the num_workers argument of the DataLoader class to a non-zero value, applying multi-worker output-processing requires a manual implementation. Here we choose a simple solution for demonstrative purposes. This should be customized per your needs and design preferences.
PyTorch Multiprocessing
We implement a producer-consumer strategy using PyTorch’s built-in multiprocessing package, torch.multiprocessing. We define a queue for storing output batches and multiple consumer workers that process the batches on the queue. We modify our inference loop to put the output tensors on the output queue. We also update the synchronize_all() utility to drain the queue, and we append a cleanup sequence at the end of the script.
The following block of code contains our initial implementation. As we will see in the next sections, this will require some tuning in order to reach maximum performance.
import torch.multiprocessing as mp

POSTPROC_WORKERS = 8  # tune for optimal throughput

output_queue = mp.JoinableQueue(maxsize=POSTPROC_WORKERS)

def output_worker(in_q):
    while True:
        item = in_q.get()
        if item is None: break  # signal to shut down
        batch_id, batch_preds = item
        process_output(batch_id, batch_preds)
        in_q.task_done()

processes = []
for _ in range(POSTPROC_WORKERS):
    p = mp.Process(target=output_worker, args=(output_queue,))
    p.start()
    processes.append(p)

def synchronize_all():
    torch.cuda.synchronize()
    output_queue.join()  # drain queue

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu = to_cpu(output['out'])
            with nvtx.annotate("queue output", color="cyan"):
                output_queue.put((i, output_cpu))

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")

# cleanup
for _ in range(POSTPROC_WORKERS):
    output_queue.put(None)

The multi-worker output processing optimization results in a throughput of 0.71 steps-per-second — a 58% increase over our baseline results.
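Note that the cleanup shown above only enqueues the shutdown sentinels. A minimal addition (our own, not part of the original script) is to also wait for the worker processes to exit before the main process terminates:
# wait for the workers to consume their sentinels and exit
for p in processes:
    p.join()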
Rerunning the nsys command results in the following profile trace:

We can see that the size of the block of whitespace has dropped considerably (from ~940 milliseconds to ~50). Were we to zoom in on the remaining whitespace, we would find it aligned to an “munmap” operation. In our previous post, the same finding informed our asynchronous data copy optimization. But this time we take an intermediate memory-optimization step in the form of a pre-allocated pool of buffers.
Optimization 2: Buffer Pool Pre-allocation
In order to reduce the overhead of allocating and managing a new CPU tensor on every iteration, we initialize a pool of tensors pre-allocated in shared memory and define a second queue to manage their use.
Our updated code appears below:
shape = (BATCH_SIZE, N_CLASSES, IMG_SIZE, IMG_SIZE)
buffer_pool = [torch.empty(shape).share_memory_()
               for _ in range(POSTPROC_WORKERS)]
buf_queue = mp.Queue()
for i in range(POSTPROC_WORKERS):
    buf_queue.put(i)

def output_worker(buffer_pool, in_q, buf_q):
    while True:
        item = in_q.get()
        if item is None: break  # signal to shut down
        batch_id, buf_id = item
        process_output(batch_id, buffer_pool[buf_id])
        buf_q.put(buf_id)
        in_q.task_done()

processes = []
for _ in range(POSTPROC_WORKERS):
    p = mp.Process(target=output_worker,
                   args=(buffer_pool, output_queue, buf_queue))
    p.start()
    processes.append(p)

def to_cpu(output):
    buf_id = buf_queue.get()
    output_cpu = buffer_pool[buf_id]
    output_cpu.copy_(output)
    return output_cpu, buf_id

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu, buf_id = to_cpu(output['out'])
            with nvtx.annotate("queue output", color="cyan"):
                output_queue.put((i, buf_id))

Following these changes, the inference throughput jumps to 1.51 — a more than 2X speed-up over our previous result.
The new profile trace appears below:

Not only has the whitespace all but disappeared, but the CUDA DtoH memory operation (in red) has dropped from ~750 milliseconds to ~110. Presumably, the large GPU-to-CPU data copy involved quite a bit of memory-management overhead that we have removed by implementing a dedicated buffer pool.
Despite the considerable improvement, if we zoom in we will find that roughly 0.5 milliseconds of whitespace remain, caused by the synchronous nature of the GPU-to-CPU copy command: as long as the copy has not completed, the CPU does not trigger the kernel computation of the next batch.
Optimization 3: Asynchronous Data Copy
Our third optimization is to change the device-to-host copy to be asynchronous. As before, we will find that implementing this change is more difficult than in the CPU-to-GPU direction.
The first step is to pass non_blocking=True to the GPU-to-CPU copy command.
def to_cpu(output):
    buf_id = buf_queue.get()
    output_cpu = buffer_pool[buf_id]
    output_cpu.copy_(output, non_blocking=True)
    return output_cpu, buf_id

However, as we saw in our previous post, this change will not have a meaningful impact unless we modify our tensors to use pinned memory:
shape = (BATCH_SIZE, N_CLASSES, IMG_SIZE, IMG_SIZE)
buffer_pool = [torch.empty(shape, pin_memory=True).share_memory_()
               for _ in range(POSTPROC_WORKERS)]

Crucially, if we apply only these two changes to our script, the throughput will increase but the output may be corrupted (e.g., see here). We need an event-based mechanism for identifying when each GPU-to-CPU copy has completed so that we can proceed with processing the output data. (Note that this was not required when making the CPU-to-GPU copy asynchronous: because a single GPU stream processes commands sequentially, the kernel computation only starts once the copy has completed. Synchronization was only required when introducing a second stream.)
To implement the notification mechanism, we define a pool of CUDA events and an additional queue for managing their use. We further define a listener thread for monitoring the state of events on the queue and populating the output queue once the copies are complete.
import threading, queue

event_pool = [torch.cuda.Event() for _ in range(POSTPROC_WORKERS)]
event_queue = queue.Queue()

def event_monitor(event_pool, event_queue, output_queue):
    while True:
        item = event_queue.get()
        if item is None: break
        batch_id, buf_idx = item
        event_pool[buf_idx].synchronize()
        output_queue.put((batch_id, buf_idx))
        event_queue.task_done()

monitor = threading.Thread(target=event_monitor,
                           args=(event_pool, event_queue, output_queue))
monitor.start()

The updated inference sequence consists of the following steps:
- Get an input batch that was prefetched to the GPU.
- Execute the model on the input batch to get an output tensor on the GPU.
- Request a vacant CPU buffer from the buffer queue and use it to trigger an asynchronous data copy. Configure an event to trigger when the copy is complete and push the event to the event-queue.
- The monitor thread waits for the event to trigger and then pushes the output tensor to the output queue for processing.
- A worker process pulls the output tensor from the queue and saves it to disk. It then releases the buffer back to the buffer queue.
The updated code appears below.
def synchronize_all():
    torch.cuda.synchronize()
    event_queue.join()
    output_queue.join()

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            with nvtx.annotate("copy to CPU", color="yellow"):
                output_cpu, buf_id = to_cpu(output['out'])
            with nvtx.annotate("queue CUDA event", color="cyan"):
                event_pool[buf_id].record()
                event_queue.put((i, buf_id))

total_time = end_time - start_time
throughput = PROFILE_STEPS / total_time
print(f"Throughput: {throughput:.2f} steps/sec")

# cleanup
event_queue.put(None)
for _ in range(POSTPROC_WORKERS):
    output_queue.put(None)

The resultant throughput is 1.55 steps-per-second.
The new profile trace appears below:

In the NVTX row of the CPU section we can see all of the operations in the inference loop bunched together on the left side, implying that they were all issued immediately and ran asynchronously. We also see the event synchronization calls (in light green) running on the dedicated monitor thread. In the GPU section we see that the kernel computation begins immediately after the device-to-host copy has completed.
Our final optimization will focus on improving the parallelization of the kernel and memory operations on the GPU.
Optimization 4: Pipelining Using CUDA Streams
As in our previous post, we wish to take advantage of the independent engines for memory copying (the DMA) and kernel compute (the SMs). We do this by assigning the memory copy to a dedicated CUDA stream:
egress_stream = torch.cuda.Stream()

with torch.inference_mode():
    for i in range(TOTAL_STEPS):
        if i == WARMUP_STEPS:
            synchronize_all()
            start_time = time.perf_counter()
            profiler.start()
        elif i == WARMUP_STEPS + PROFILE_STEPS:
            synchronize_all()
            profiler.stop()
            end_time = time.perf_counter()
        with nvtx.annotate(f"Batch {i}", color="blue"):
            with nvtx.annotate("get batch", color="red"):
                batch = next(data_iter)
            with nvtx.annotate("compute", color="green"):
                output = model(batch)
            # on separate stream
            with torch.cuda.stream(egress_stream):
                # wait for default stream to complete compute
                egress_stream.wait_stream(torch.cuda.default_stream())
                with nvtx.annotate("copy to CPU", color="yellow"):
                    output_cpu, buf_id = to_cpu(output['out'])
                with nvtx.annotate("queue CUDA event", color="cyan"):
                    event_pool[buf_id].record(egress_stream)
                    event_queue.put((i, buf_id))

This results in a throughput of 1.85 steps per second — an additional 19.3% improvement over our previous experiment.
The final profile trace appears below:

In the GPU section we see a continuous block of kernel compute (in light blue) with both the host-to-device copies (in light green) and the device-to-host copies (in purple) running in parallel. Our inference loop is now compute-bound, implying that we have exhausted all practical opportunities for data-transfer optimization.
Results
We summarize our results in the following table:
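| Optimization | Throughput (steps per second) | Speed-up over baseline |
| --- | --- | --- |
| Baseline | 0.45 | 1X |
| 1: Multi-worker output processing | 0.71 | 1.58X |
| 2: Buffer pool pre-allocation | 1.51 | 3.36X |
| 3: Asynchronous data copy | 1.55 | 3.44X |
| 4: Pipelining using CUDA streams | 1.85 | 4.11X |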

Through the use of the nsys profiler we were able to increase throughput by over 4X. Naturally, the impact of the optimizations we discussed will vary based on the details of the model and runtime environment.
Summary
This concludes the second part of our series of posts on the topic of optimizing data-transfer in AI/ML workloads. Part one focused on host-to-device copies and part two on device-to-host copies. When implemented naively, data-transfer in either direction can lead to significant performance bottlenecks resulting in GPU starvation and increased runtime costs. Using Nsight Systems profiler, we demonstrated how to identify and resolve these bottlenecks and increase runtime efficiency.
Although the optimization of both directions involved similar steps, the implementation details were very different. While optimizing CPU-to-GPU data-transfer is well-supported by PyTorch’s data-loading APIs and required relatively small changes to the execution loop, optimizing the GPU-to-CPU direction required a bit more software engineering. Importantly, the solutions we put forth in this post were chosen for demonstrative purposes. Your own solution may differ considerably based on your project needs and design preferences.
Having covered both CPU-to-GPU and GPU-to-CPU data copies, we turn our attention to GPU-to-GPU transactions: Stay tuned for a future post on the topic of optimizing data transfer between GPUs in distributed training workloads.


