Multi-GPU Showcase with persist#

Author: Severin Dicks
Copyright: scverse

Overview#

In this notebook, we showcase the multi-GPU computation capabilities of rapids-singlecell
using Dask to analyze 11 million cells in minutes: in the run recorded here, QC through PCA takes about 72 seconds.

This notebook is run on a DGX system with 8 NVIDIA H100 GPUs, demonstrating how Dask can efficiently distribute computations across multiple GPUs.

Key Advantages of Multi-GPU Computation#

By leveraging Dask and RAPIDS, we can:

  • Process massive single-cell datasets without exceeding memory limits.

  • Fully utilize all available GPUs, scaling performance across multiple devices.

  • Enable chunk-based execution, efficiently managing memory by loading only necessary data.

Combining Multi-GPU with Out-of-Core Processing#

  • Multi-GPU processing and out-of-core execution can be combined to analyze even larger datasets that exceed GPU memory.

  • However, in this notebook, we focus purely on multi-GPU scaling without out-of-core execution.

This approach significantly accelerates large-scale single-cell analysis,
making it feasible on high-performance hardware like DGX systems,
while also being adaptable to multi-GPU workstations.

import dask
import time
import gc

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

Initializing a Multi-GPU Dask Cluster for RAPIDS#

To fully utilize all 8 NVIDIA H100 GPUs on the DGX system,
we initialize a multi-GPU Dask cluster and configure RAPIDS Memory Manager (RMM)
for efficient memory handling across GPUs.

Setting Up Memory Management with RMM#

The RAPIDS Memory Manager (RMM) reduces allocation overhead by serving GPU allocations from a per-worker memory pool,
which improves memory efficiency when working with large-scale datasets.

Launching a Multi-GPU Dask Cluster#

We create a Dask CUDA cluster that utilizes all 8 GPUs for preprocessing and analysis.

Additional parameters for LocalCUDACluster#

  • CUDA_VISIBLE_DEVICES=preprocessing_gpus: selects GPUs to use (e.g., "0,1,2,3,4,5,6,7").

  • threads_per_worker=10: CPU threads per GPU worker; tune for your workload and I/O.

  • protocol="ucx": enables UCX for high-throughput GPU-aware communication (NVLink/InfiniBand/RDMA).

  • rmm_pool_size="10GB": initial per-worker RAPIDS Memory Manager (RMM) pool; reduces allocation overhead.

  • rmm_maximum_pool_size="110GB": maximum pool growth per worker; allows RMM to expand up to this cap.

  • rmm_allocator_external_lib_list="cupy": integrates CuPy with RMM so CuPy allocations come from the pool.

  • Client(cluster): attaches the Dask client to the cluster (dashboard link available when running).
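
When adapting these settings to a workstation with fewer GPUs, the keyword arguments can be assembled programmatically. The sketch below is illustrative only: `build_cluster_kwargs` is a hypothetical helper, not part of dask_cuda, and the pool sizes are simply the values used in this notebook, which may need to be lowered on GPUs with less memory.

```python
def build_cluster_kwargs(n_gpus, threads_per_worker=10):
    """Hypothetical helper: keyword arguments for LocalCUDACluster on n_gpus devices."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return {
        # Device IDs 0..n_gpus-1 as a comma-separated string
        "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in range(n_gpus)),
        "threads_per_worker": threads_per_worker,
        "protocol": "ucx",
        # Pool sizes copied from this notebook; an assumption for other hardware
        "rmm_pool_size": "10GB",
        "rmm_maximum_pool_size": "110GB",
        "rmm_allocator_external_lib_list": "cupy",
    }


kwargs = build_cluster_kwargs(4)
print(kwargs["CUDA_VISIBLE_DEVICES"])  # prints 0,1,2,3
```

The resulting dictionary could then be unpacked as LocalCUDACluster(**kwargs).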

%%time
preprocessing_gpus="0,1,2,3,4,5,6,7"
cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus,
                           threads_per_worker=10,
                           protocol="ucx",
                           rmm_pool_size= "10GB",
                           rmm_maximum_pool_size = "110GB",
                           rmm_allocator_external_lib_list= "cupy",
                          )

client = Client(cluster)

client
CPU times: user 13.1 s, sys: 6.41 s, total: 19.5 s
Wall time: 23.2 s

Client

Client-f1178a8c-bfbd-11f0-ae39-6cfe5490ac50

Connection method: Cluster object Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status

import rapids_singlecell as rsc
import anndata as ad

Loading Large Datasets into AnnData with Dask#

To efficiently handle large-scale single-cell datasets, we load data directly from an HDF5 (h5) or Zarr file into an AnnData object using Dask arrays. This enables lazy loading, allowing data to be processed in chunks without exceeding memory limits.

We achieve this using read_elem_as_dask (available as read_elem_lazy from anndata 0.12 onward), which loads the expression matrix (X) as a Dask array.

from packaging.version import parse as parse_version

if parse_version(ad.__version__) < parse_version("0.12.0rc1"):
    from anndata.experimental import read_elem_as_dask as read_dask
else:
    from anndata.experimental import read_elem_lazy as read_dask
import zarr

SPARSE_CHUNK_SIZE = 50_000
data_pth = "/home/scratch.sdicks_gpu/git/rapids_singlecell-notebooks/zarr/cell_atlas.zarr" #11Million Cells
#data_pth = "zarr/nvidia_1.3M.zarr" #1.3Million Cells

f = zarr.open(data_pth)
X = f["X"]
shape = X.attrs["shape"]
adata = ad.AnnData(
    X = read_dask(X, (SPARSE_CHUNK_SIZE, shape[1])),
    obs = ad.io.read_elem(f["obs"]),
    var = ad.io.read_elem(f["var"])
)
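
The number of Dask chunks follows directly from the row count and SPARSE_CHUNK_SIZE, because each chunk spans all columns. As a quick sanity check, the arithmetic can be reproduced with the standard library alone; the row count is taken from the array summary printed in this notebook.

```python
import math

N_CELLS = 11_441_407        # rows of X, from the array summary in this notebook
SPARSE_CHUNK_SIZE = 50_000  # rows per chunk, as set above

# Every chunk holds SPARSE_CHUNK_SIZE rows except possibly the last one
n_chunks = math.ceil(N_CELLS / SPARSE_CHUNK_SIZE)
last_chunk_rows = N_CELLS - (n_chunks - 1) * SPARSE_CHUNK_SIZE

print(n_chunks, last_chunk_rows)  # prints 229 41407
```

A smaller chunk size yields more, lighter tasks; a larger one yields fewer tasks that each need more GPU memory.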

Transferring AnnData to GPU and Persisting Data#

To leverage multi-GPU acceleration, we transfer the AnnData object to GPU memory
and persist its Dask-backed expression matrix for efficient computation.

Step-by-Step Breakdown:

  1. Move AnnData to GPU → rsc.get.anndata_to_GPU(adata)

    • Transfers all numerical data (.X) to GPU memory.

  2. Persist the Expression Matrix → adata.X = adata.X.persist()

    • Keeps adata.X in memory across Dask workers, avoiding redundant recomputation.

  3. Optimize Chunking → adata.X.compute_chunk_sizes()

    • Computes the exact chunk sizes for optimal Dask scheduling and memory usage.
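
The interplay of persist() and compute_chunk_sizes() can be seen on a small CPU-only Dask array. This is a sketch under simplifying assumptions: it uses a NumPy-backed dask.array rather than the CuPy sparse chunks in this notebook, and a boolean mask to produce chunks whose sizes Dask cannot know in advance.

```python
import math

import dask.array as da

x = da.arange(20, chunks=5)  # four chunks of five elements each
y = x[x % 3 == 0]            # boolean mask: resulting sizes depend on the data

# Chunk sizes are unknown (nan) until the data has been inspected
sizes_unknown = all(math.isnan(c) for c in y.chunks[0])

y = y.persist()              # materialize the chunks and keep them in memory
y.compute_chunk_sizes()      # fills in the concrete chunk sizes in place

print(sizes_unknown, y.chunks)
```

Persisting first means the size computation works on chunks that are already in memory instead of re-running the graph that produced them.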

%%time
rsc.get.anndata_to_GPU(adata)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()
CPU times: user 21 s, sys: 974 ms, total: 21.9 s
Wall time: 21.8 s
Array Chunk
Shape (11441407, 45854) (50000, 45854)
Dask graph 229 chunks in 1 graph layer
Data type float32 cupyx.scipy.sparse._csr.csr_matrix

Quality Control (QC) Metrics Calculation#

Before proceeding with further analysis, we compute quality control (QC) metrics
to assess dataset quality and filter out low-quality cells or genes.

We use rsc.pp.calculate_qc_metrics() to calculate key QC metrics.

Although we are working with Dask-backed AnnData, this operation requires a synchronization step. This means that Dask computations must be evaluated immediately, so the process is not completely lazy like other out-of-core operations.

t1 = time.time()
%%time
rsc.pp.calculate_qc_metrics(adata)
CPU times: user 6.97 s, sys: 1.03 s, total: 8 s
Wall time: 8.07 s
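
To make the metrics concrete, the following sketch computes three of them by hand on a tiny dense matrix with NumPy. It only illustrates what the columns mean; it is not how rapids-singlecell computes them on the GPU, and the counts are made up.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 4 genes (columns)
counts = np.array([
    [0, 3, 1, 0],
    [2, 0, 0, 0],
    [5, 1, 4, 2],
])

# Per cell: number of genes with at least one count, and total counts
n_genes_by_counts = (counts > 0).sum(axis=1)
total_counts = counts.sum(axis=1)

# Per gene: number of cells in which the gene is detected
n_cells_by_counts = (counts > 0).sum(axis=0)

print(n_genes_by_counts, total_counts, n_cells_by_counts)
```

Because every value is a reduction over the whole matrix, the results have to be materialized, which is why this step cannot stay lazy.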

Filtering Cells and Genes Without Additional Computation#

Instead of using sc.pp.filter_cells and sc.pp.filter_genes,
we apply filtering directly using boolean indexing to avoid extra computation.

Why Use Direct Indexing Instead of Built-in Functions?

  • More Efficient with Dask → Avoids triggering additional computations.

  • Preserves Lazy Execution → Filtering is applied without forcing full dataset evaluation.

  • Copy is Essential → Using .copy() prevents views, which may not work reliably with Dask-backed AnnData.

%%time
adata = adata[(adata.obs["n_genes_by_counts"]<=10000) 
            & (adata.obs["n_genes_by_counts"]>=200)].copy()
adata = adata[:,adata.var["n_cells_by_counts"]>=10].copy()
CPU times: user 44.9 s, sys: 6.16 s, total: 51.1 s
Wall time: 29.7 s
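
The filtering logic itself is plain boolean masking. A minimal NumPy sketch with the same thresholds and made-up QC values (the real masks come from adata.obs and adata.var):

```python
import numpy as np

# Made-up QC values for 5 cells and 3 genes
n_genes_by_counts = np.array([150, 200, 3500, 10_000, 12_000])
n_cells_by_counts = np.array([3, 10, 250])

# Same thresholds as in this notebook
keep_cells = (n_genes_by_counts >= 200) & (n_genes_by_counts <= 10_000)
keep_genes = n_cells_by_counts >= 10

X = np.arange(15).reshape(5, 3)          # stand-in for the expression matrix
X_filtered = X[keep_cells][:, keep_genes]  # rows first, then columns

print(keep_cells, keep_genes, X_filtered.shape)
```

Both thresholds are inclusive, so a cell with exactly 200 or exactly 10000 detected genes is kept.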

Persisting and Optimizing Chunk Sizes After QC and Subsetting#

After performing quality control (QC) and subsetting the dataset,
we persist the Dask-backed expression matrix and optimize its chunk sizes for efficient multi-GPU execution.
Persisting after filtering ensures that only high-quality, relevant data remains in memory.

%%time
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()
CPU times: user 13.9 s, sys: 592 ms, total: 14.4 s
Wall time: 15.3 s
Array Chunk
Shape (11441244, 41291) (49962, 41291)
Dask graph 229 chunks in 1 graph layer
Data type float32 cupyx.scipy.sparse._csr.csr_matrix

Log Normalization (Fully Lazy Execution)#

Next, we apply log normalization to scale gene expression values.
This step ensures that differences in sequencing depth across cells do not dominate downstream analysis.
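
Per cell, this amounts to rescaling the counts so they sum to a fixed target and then applying log(1 + x). A minimal pure-Python sketch with made-up counts; `normalize_log1p` is a hypothetical helper, and the library works on sparse GPU arrays rather than lists:

```python
import math


def normalize_log1p(cell_counts, target_sum=10_000):
    """Rescale one cell's counts to sum to target_sum, then apply log(1 + x)."""
    total = sum(cell_counts)
    if total == 0:
        return [0.0 for _ in cell_counts]  # empty cell: nothing to rescale
    return [math.log1p(c * target_sum / total) for c in cell_counts]


shallow = normalize_log1p([1, 3])    # 4 counts in total
deep = normalize_log1p([100, 300])   # same proportions, 100x the depth

print(shallow == deep)
```

Two cells with the same relative expression but different sequencing depth end up with identical normalized values, which is the point of the step.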

gc.collect()  
1503
%%time
rsc.pp.normalize_total(adata,target_sum = 10000)
rsc.pp.log1p(adata)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()
CPU times: user 414 ms, sys: 18 ms, total: 432 ms
Wall time: 438 ms
Array Chunk
Shape (11441244, 41291) (49962, 41291)
Dask graph 229 chunks in 1 graph layer
Data type float32 cupyx.scipy.sparse._csr.csr_matrix

Storing the Processed Data in Memory#

After log normalization, the cell above persists the updated expression matrix, keeping the new results in memory for efficient access.

Selecting Highly Variable Genes#

To focus on the most informative features, we identify highly variable genes (HVGs)
using the Cell Ranger method and subset the dataset accordingly.

  • Copy is Essential → Using .copy() prevents views, ensuring the operation works properly with Dask-backed AnnData.

%%time
rsc.pp.highly_variable_genes(adata,n_top_genes=5000, flavor="cell_ranger")
CPU times: user 986 ms, sys: 334 ms, total: 1.32 s
Wall time: 1.34 s
%%time
adata = adata[:,adata.var.highly_variable].copy()
CPU times: user 20.7 s, sys: 641 ms, total: 21.3 s
Wall time: 11.5 s

Rechunking the Expression Matrix for Multi-GPU Execution#

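To optimize performance across the GPUs, we rechunk the expression matrix (adata.X)
into a few large row blocks of nearly equal size. Note that the cell below uses ceiling division to split the rows into 7 blocks rather than 8, as the output confirms.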
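
The ceiling division used for the row blocks can be checked in isolation. In this sketch, `row_chunks` is a hypothetical helper, the row count is the post-filtering value reported in this notebook, and the block count of 7 mirrors the cell that follows.

```python
def row_chunks(n_rows, n_chunks):
    """Split n_rows into at most n_chunks row blocks using ceiling division."""
    rows_per_chunk = (n_rows + n_chunks - 1) // n_chunks
    sizes = []
    remaining = n_rows
    while remaining > 0:
        size = min(rows_per_chunk, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


sizes = row_chunks(11_441_244, 7)
print(len(sizes), sizes[0], sizes[-1])  # prints 7 1634464 1634460
```

Ceiling division guarantees that no rows are dropped: all blocks are full except the last, which absorbs the remainder.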

n_rows = adata.shape[0]
n_cols = adata.shape[1]
rows_per_worker = (n_rows+7-1)//7
adata.X = adata.X.rechunk((rows_per_worker, n_cols)).persist()

adata.X.compute_chunk_sizes()
Array Chunk
Shape (11441244, 5000) (1634464, 5000)
Dask graph 7 chunks in 1 graph layer
Data type float32 cupyx.scipy.sparse._csr.csr_matrix

Scaling Gene Expression (Requires Synchronization)#

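To standardize gene expression values, we apply feature scaling. With zero_center=False, the values are scaled without subtracting the mean, so the matrix stays sparse.
We also persist the results to keep them quickly accessible.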
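
Why skipping the centering step preserves sparsity can be seen on a single gene. The sketch below assumes the sample standard deviation (ddof=1) as the scale factor; the exact estimator and the handling of constant genes in rapids-singlecell may differ, and `scale_column` is a hypothetical helper.

```python
import statistics


def scale_column(values):
    """Divide by the sample standard deviation without subtracting the mean."""
    std = statistics.stdev(values)
    if std == 0:
        return list(values)  # constant gene: left unchanged in this sketch
    return [v / std for v in values]


scaled = scale_column([0.0, 2.0, 4.0])
print(scaled)
```

Zeros are only multiplied by a constant, so they remain zeros; subtracting the mean would turn every zero into a nonzero value and force a dense matrix.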

%%time
rsc.pp.scale(adata, zero_center= False)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()
CPU times: user 691 ms, sys: 237 ms, total: 927 ms
Wall time: 885 ms
Array Chunk
Shape (11441244, 5000) (1634464, 5000)
Dask graph 7 chunks in 1 graph layer
Data type float32 cupyx.scipy.sparse._csr.csr_matrix

Principal Component Analysis (PCA) on GPU#

To reduce dimensionality while preserving meaningful variation,
we perform Principal Component Analysis (PCA) using GPU acceleration.

Finalizing the Transformation with .compute():

  • After computing the principal components, the data remains lazy (Dask CuPy array).

  • Calling .compute() on adata.obsm["X_pca"] performs the final transformation, projecting the data onto the computed PCs and materializing the result as a fully computed CuPy array.
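
Materializing the embedding is affordable because a dense array of 100 components per cell is small. A back-of-the-envelope check, using the shapes reported in this notebook's output and float32 (4 bytes per value); `dense_nbytes` is a hypothetical helper:

```python
def dense_nbytes(n_rows, n_cols, itemsize=4):
    """Size in bytes of a dense array with the given shape and item size."""
    return n_rows * n_cols * itemsize


full = dense_nbytes(11_441_244, 100)   # whole X_pca array
chunk = dense_nbytes(1_634_464, 100)   # one full row block

print(f"{full / 2**30:.2f} GiB, {chunk / 2**20:.2f} MiB")  # prints 4.26 GiB, 623.50 MiB
```

These figures agree with the Bytes row of the Dask array summary for X_pca.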

%%time
rsc.pp.pca(adata, n_comps = 100,mask_var=None)
adata.obsm["X_pca"]=adata.obsm["X_pca"].persist()
adata.obsm["X_pca"].compute_chunk_sizes()
CPU times: user 1.23 s, sys: 458 ms, total: 1.69 s
Wall time: 1.44 s
Array Chunk
Bytes 4.26 GiB 623.50 MiB
Shape (11441244, 100) (1634464, 100)
Dask graph 7 chunks in 1 graph layer
Data type float32 cupy.ndarray
print("Total Time",time.time()-t1)
Total Time 72.30928540229797
%%time
adata.obsm["X_pca"]=adata.obsm["X_pca"].compute()
CPU times: user 6.4 s, sys: 385 ms, total: 6.78 s
Wall time: 6.82 s
%%time
rsc.pp.neighbors(adata, n_neighbors=15, n_pcs=50, algorithm="mg_ivfflat")
CPU times: user 1min 9s, sys: 13 s, total: 1min 22s
Wall time: 1min 26s
%%time
rsc.tl.umap(adata, min_dist=0.3)
CPU times: user 12 s, sys: 1.42 s, total: 13.4 s
Wall time: 11.4 s
%%time
rsc.tl.leiden(adata, resolution=1.0)
CPU times: user 19.3 s, sys: 4.23 s, total: 23.6 s
Wall time: 19.1 s