Documentation for DataModule

scdataloader.datamodule.DataModule

Bases: LightningDataModule

PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

This DataModule provides train/val/test dataloaders with configurable sampling strategies. It combines MappedCollection, Dataset, and Collator to create efficient data pipelines for training single-cell foundation models.

The training dataloader uses weighted random sampling based on class frequencies, validation uses random sampling, and test uses sequential sampling on held-out datasets.

Parameters:
  • collection_name (str) –

    Key of the LaminDB Collection to load.

  • clss_to_weight (List[str], default: ['organism_ontology_term_id'] ) –

    Label columns to use for weighted sampling in the training dataloader. Supports "nnz" for weighting by number of non-zero genes. Defaults to ["organism_ontology_term_id"].

  • weight_scaler (int, default: 10 ) –

    Controls balance between rare and common classes. Higher values lead to more uniform sampling across classes. Set to 0 to disable weighted sampling. Defaults to 10.

  • n_samples_per_epoch (int, default: 2000000 ) –

    Number of samples to draw per training epoch. Defaults to 2,000,000.

  • validation_split (float | int, default: 0.2 ) –

    Proportion (float) or absolute number (int) of samples for validation. Defaults to 0.2.

  • test_split (float | int, default: 0 ) –

    Proportion (float) or absolute number (int) of samples for testing. Uses entire datasets as test sets, rounding to the nearest dataset boundary. Defaults to 0.

  • use_default_col (bool, default: True ) –

    Whether to use the default Collator for batch preparation. If False, no collate_fn is applied. Defaults to True.

  • clss_to_predict (List[str], default: ['organism_ontology_term_id'] ) –

    Observation columns to encode as prediction targets. Must include "organism_ontology_term_id". Defaults to ["organism_ontology_term_id"].

  • hierarchical_clss (List[str], default: [] ) –

    Observation columns with hierarchical ontology structure to be processed. Defaults to [].

  • how (str, default: 'random expr' ) –

    Gene selection strategy passed to Collator. One of "most expr", "random expr", "all", "some". Defaults to "random expr".

  • organism_col (str, default: 'organism_ontology_term_id' ) –

    Column name for organism ontology term ID. Defaults to "organism_ontology_term_id".

  • max_len (int, default: 1000 ) –

    Maximum number of genes per sample passed to Collator. Defaults to 1000.

  • replacement (bool, default: True ) –

    Whether to sample with replacement in training. Defaults to True.

  • gene_subset (List[str], default: None ) –

    List of genes to restrict the dataset to. Useful when a model supports only specific genes. Defaults to None.

  • tp_name (str, default: None ) –

    Column name for time point or heat diffusion values. Defaults to None.

  • assays_to_drop (List[str], default: ['EFO:0030007'] ) –

    List of assay ontology term IDs to exclude from training. Defaults to ["EFO:0030007"] (ATAC-seq).

  • metacell_mode (float, default: 0.0 ) –

    Probability of using metacell aggregation mode. Cannot be used with get_knn_cells. Defaults to 0.0.

  • get_knn_cells (bool, default: False ) –

    Whether to include k-nearest neighbor cell expression data. Cannot be used with metacell_mode. Defaults to False.

  • store_location (str, default: None ) –

    Directory path to cache sampler indices and labels for faster subsequent loading. Defaults to None.

  • force_recompute_indices (bool, default: False ) –

    Force recomputation of cached indices even if they exist. Defaults to False.

  • sampler_workers (int, default: None ) –

    Number of parallel workers for building sampler indices. Auto-determined based on available CPUs if None. Defaults to None.

  • sampler_chunk_size (int, default: None ) –

    Chunk size for parallel sampler processing. Auto-determined based on available memory if None. Defaults to None.

  • organisms (List[str], default: None ) –

    List of organisms to include. If None, uses all organisms in the dataset. Defaults to None.

  • genedf (DataFrame, default: None ) –

    Gene information DataFrame. If None, loaded automatically. Defaults to None.

  • n_bins (int, default: 0 ) –

    Number of bins for expression discretization. 0 means no binning. Defaults to 0.

  • curiculum (int, default: 0 ) –

    Curriculum learning parameter. If > 0, gradually increases sampling weight balance over epochs. Defaults to 0.

  • start_at (int, default: 0 ) –

    Starting index for resuming inference. Requires the same number of GPUs as the previous run. Defaults to 0.

  • **kwargs

    Additional arguments passed to PyTorch DataLoader (e.g., batch_size, num_workers, pin_memory).

Attributes:
  • dataset (Dataset) –

    The underlying Dataset instance.

  • classes (dict[str, int]) –

    Mapping from class names to number of categories.

  • train_labels (ndarray) –

    Label array for weighted sampling.

  • idx_full (ndarray) –

    Indices for training samples.

  • valid_idx (ndarray) –

    Indices for validation samples.

  • test_idx (ndarray) –

    Indices for test samples.

  • test_datasets (List[str]) –

    Paths to datasets used for testing.

Raises:
  • ValueError

    If "organism_ontology_term_id" not in clss_to_predict.

  • ValueError

    If both metacell_mode > 0 and get_knn_cells are True.

Example

dm = DataModule(
    collection_name="my_collection",
    batch_size=32,
    num_workers=4,
    max_len=2000,
)
dm.setup()
train_loader = dm.train_dataloader()
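
The minimal call above relies on the defaults. A fuller, hedged configuration might enable weighted sampling on extra label columns, hold out whole datasets for testing, and cache sampler indices; the collection key, the "cell_type_ontology_term_id" column, and the cache path below are placeholders, not values required by the library.

from scdataloader.datamodule import DataModule

dm = DataModule(
    collection_name="my_collection",          # key of the LaminDB Collection
    clss_to_predict=[
        "organism_ontology_term_id",          # required
        "cell_type_ontology_term_id",         # placeholder extra target
    ],
    clss_to_weight=["cell_type_ontology_term_id", "nnz"],  # class- and nnz-weighted sampling
    weight_scaler=10,                          # > 0 enables weighted sampling
    validation_split=0.2,
    test_split=0.05,                           # rounded to whole held-out datasets
    store_location="/tmp/scdataloader_cache",  # placeholder cache directory
    max_len=2000,
    batch_size=32,                             # forwarded to the PyTorch DataLoader
    num_workers=4,                             # forwarded to the PyTorch DataLoader
)
test_paths = dm.setup()                        # returns paths of the held-out test datasets
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()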

Methods:

Name Description
predict_dataloader

Create a DataLoader for prediction over all training data.

setup

Prepare data splits for training, validation, and testing.

test_dataloader

Create the test DataLoader with sequential sampling.

train_dataloader

Create the training DataLoader with weighted random sampling.

val_dataloader

Create the validation DataLoader.

Attributes:
  • decoders

    Decoders for any labels that were encoded.

  • genes

    The genes used in this datamodule.

  • labels_hierarchy

    The label hierarchy for any class that has one.

Source code in scdataloader/datamodule.py
def __init__(
    self,
    collection_name: str,
    clss_to_weight: List[str] = ["organism_ontology_term_id"],
    weight_scaler: int = 10,
    n_samples_per_epoch: int = 2_000_000,
    validation_split: float = 0.2,
    test_split: float = 0,
    use_default_col: bool = True,
    # this is for the mappedCollection
    clss_to_predict: List[str] = ["organism_ontology_term_id"],
    hierarchical_clss: List[str] = [],
    # this is for the collator
    how: str = "random expr",
    organism_col: str = "organism_ontology_term_id",
    max_len: int = 1000,
    replacement: bool = True,
    gene_subset: Optional[list[str]] = None,
    tp_name: Optional[str] = None,  # "heat_diff"
    assays_to_drop: List[str] = [
        # "EFO:0008853", #patch seq
        # "EFO:0010961", # visium
        "EFO:0030007",  # ATACseq
        # "EFO:0030062", # slide-seq
    ],
    metacell_mode: float = 0.0,
    get_knn_cells: bool = False,
    store_location: str = None,
    force_recompute_indices: bool = False,
    sampler_workers: int = None,
    sampler_chunk_size: int = None,
    organisms: Optional[str] = None,
    genedf: Optional[pd.DataFrame] = None,
    n_bins: int = 0,
    curiculum: int = 0,
    start_at: int = 0,
    **kwargs,
):
    """
    PyTorch Lightning DataModule for loading single-cell data from a LaminDB Collection.

    This DataModule provides train/val/test dataloaders with configurable sampling strategies.
    It combines MappedCollection, Dataset, and Collator to create efficient data pipelines
    for training single-cell foundation models.

    The training dataloader uses weighted random sampling based on class frequencies,
    validation uses random sampling, and test uses sequential sampling on held-out datasets.

    Args:
        collection_name (str): Key of the LaminDB Collection to load.
        clss_to_weight (List[str], optional): Label columns to use for weighted sampling
            in the training dataloader. Supports "nnz" for weighting by number of
            non-zero genes. Defaults to ["organism_ontology_term_id"].
        weight_scaler (int, optional): Controls balance between rare and common classes.
            Higher values lead to more uniform sampling across classes. Set to 0 to
            disable weighted sampling. Defaults to 10.
        n_samples_per_epoch (int, optional): Number of samples to draw per training epoch.
            Defaults to 2,000,000.
        validation_split (float | int, optional): Proportion (float) or absolute number (int)
            of samples for validation. Defaults to 0.2.
        test_split (float | int, optional): Proportion (float) or absolute number (int)
            of samples for testing. Uses entire datasets as test sets, rounding to
            nearest dataset boundary. Defaults to 0.
        use_default_col (bool, optional): Whether to use the default Collator for batch
            preparation. If False, no collate_fn is applied. Defaults to True.
        clss_to_predict (List[str], optional): Observation columns to encode as prediction
            targets. Must include "organism_ontology_term_id". Defaults to
            ["organism_ontology_term_id"].
        hierarchical_clss (List[str], optional): Observation columns with hierarchical
            ontology structure to be processed. Defaults to [].
        how (str, optional): Gene selection strategy passed to Collator. One of
            "most expr", "random expr", "all", "some". Defaults to "random expr".
        organism_col (str, optional): Column name for organism ontology term ID.
            Defaults to "organism_ontology_term_id".
        max_len (int, optional): Maximum number of genes per sample passed to Collator.
            Defaults to 1000.
        replacement (bool, optional): Whether to sample with replacement in training.
            Defaults to True.
        gene_subset (List[str], optional): List of genes to restrict the dataset to.
            Useful when model only supports specific genes. Defaults to None.
        tp_name (str, optional): Column name for time point or heat diffusion values.
            Defaults to None.
        assays_to_drop (List[str], optional): List of assay ontology term IDs to exclude
            from training. Defaults to ["EFO:0030007"] (ATAC-seq).
        metacell_mode (float, optional): Probability of using metacell aggregation mode.
            Cannot be used with get_knn_cells. Defaults to 0.0.
        get_knn_cells (bool, optional): Whether to include k-nearest neighbor cell
            expression data. Cannot be used with metacell_mode. Defaults to False.
        store_location (str, optional): Directory path to cache sampler indices and
            labels for faster subsequent loading. Defaults to None.
        force_recompute_indices (bool, optional): Force recomputation of cached indices
            even if they exist. Defaults to False.
        sampler_workers (int, optional): Number of parallel workers for building sampler
            indices. Auto-determined based on available CPUs if None. Defaults to None.
        sampler_chunk_size (int, optional): Chunk size for parallel sampler processing.
            Auto-determined based on available memory if None. Defaults to None.
        organisms (List[str], optional): List of organisms to include. If None, uses
            all organisms in the dataset. Defaults to None.
        genedf (pd.DataFrame, optional): Gene information DataFrame. If None, loaded
            automatically. Defaults to None.
        n_bins (int, optional): Number of bins for expression discretization. 0 means
            no binning. Defaults to 0.
        curiculum (int, optional): Curriculum learning parameter. If > 0, gradually
            increases sampling weight balance over epochs. Defaults to 0.
        start_at (int, optional): Starting index for resuming inference. Requires same
            number of GPUs as previous run. Defaults to 0.
        **kwargs: Additional arguments passed to PyTorch DataLoader (e.g., batch_size,
            num_workers, pin_memory).

    Attributes:
        dataset (Dataset): The underlying Dataset instance.
        classes (dict[str, int]): Mapping from class names to number of categories.
        train_labels (np.ndarray): Label array for weighted sampling.
        idx_full (np.ndarray): Indices for training samples.
        valid_idx (np.ndarray): Indices for validation samples.
        test_idx (np.ndarray): Indices for test samples.
        test_datasets (List[str]): Paths to datasets used for testing.

    Raises:
        ValueError: If "organism_ontology_term_id" not in clss_to_predict.
        ValueError: If both metacell_mode > 0 and get_knn_cells are True.

    Example:
        >>> dm = DataModule(
        ...     collection_name="my_collection",
        ...     batch_size=32,
        ...     num_workers=4,
        ...     max_len=2000,
        ... )
        >>> dm.setup()
        >>> train_loader = dm.train_dataloader()
    """
    if "organism_ontology_term_id" not in clss_to_predict:
        raise ValueError(
            "need 'organism_ontology_term_id' in the set of classes at least"
        )
    if metacell_mode > 0 and get_knn_cells:
        raise ValueError(
            "cannot use metacell mode and get_knn_cells at the same time"
        )
    mdataset = Dataset(
        ln.Collection.filter(key=collection_name, is_latest=True).first(),
        clss_to_predict=clss_to_predict,
        hierarchical_clss=hierarchical_clss,
        metacell_mode=metacell_mode,
        get_knn_cells=get_knn_cells,
        store_location=store_location,
        force_recompute_indices=force_recompute_indices,
        genedf=genedf,
    )
    # and location
    self.metacell_mode = bool(metacell_mode)
    self.gene_pos = None
    self.collection_name = collection_name
    if gene_subset is not None:
        tokeep = set(mdataset.genedf.index.tolist())
        gene_subset = [u for u in gene_subset if u in tokeep]
    self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
    # we might want not to order the genes by expression (or do it?)
    # we might want to not introduce zeros and
    if use_default_col:
        kwargs["collate_fn"] = Collator(
            organisms=mdataset.organisms if organisms is None else organisms,
            how=how,
            valid_genes=gene_subset,
            max_len=max_len,
            org_to_id=mdataset.encoder[organism_col],
            tp_name=tp_name,
            organism_name=organism_col,
            class_names=list(self.classes.keys()),
            genedf=genedf,
            n_bins=n_bins,
        )
    self.n_bins = n_bins
    self.validation_split = validation_split
    self.test_split = test_split
    self.dataset = mdataset
    self.replacement = replacement
    self.kwargs = kwargs
    if "sampler" in self.kwargs:
        self.kwargs.pop("sampler")
    self.assays_to_drop = assays_to_drop
    self.n_samples = len(mdataset)
    self.weight_scaler = weight_scaler
    self.n_samples_per_epoch = n_samples_per_epoch
    self.clss_to_weight = clss_to_weight
    self.train_weights = None
    self.train_labels = None
    self.sampler_workers = sampler_workers
    self.sampler_chunk_size = sampler_chunk_size
    self.store_location = store_location
    self.nnz = None
    self.start_at = start_at
    self.idx_full = None
    self.max_len = max_len
    self.test_datasets = []
    self.force_recompute_indices = force_recompute_indices
    self.curiculum = curiculum
    self.valid_idx = []
    self.test_idx = []
    super().__init__()
    print("finished init")

decoders property

Decoders for any labels that were encoded.

Returns:
  • dict[str, dict[int, str]]

genes property

The genes used in this datamodule.

Returns:
  • list

labels_hierarchy property

The label hierarchy for any class that has one.

Returns:
  • dict[str, dict[str, str]]
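
A short sketch of how these properties could be inspected after construction; since "organism_ontology_term_id" must be in clss_to_predict, its decoder is assumed to be present, and the integer id 0 is purely illustrative.

# map an encoded integer back to its original label string
organism_decoder = dm.decoders["organism_ontology_term_id"]
print(organism_decoder[0])            # illustrative id -> ontology term

print(dm.genes[:5])                   # first few gene identifiers used by this datamodule
print(list(dm.labels_hierarchy))      # classes that carry an ontology hierarchy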

predict_dataloader

Create a DataLoader for prediction over all training data.

Uses RankShardSampler for distributed inference.

Returns:
  • DataLoader

    Prediction DataLoader instance.
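
A minimal sketch of driving the prediction loader directly (with a Lightning Trainer it would instead be passed to Trainer.predict via its dataloaders argument); setup() must have been called so that idx_full is populated.

dm.setup()                                 # populates idx_full used by the prediction sampler
predict_loader = dm.predict_dataloader()
for batch in predict_loader:
    # each batch is assembled by the Collator when use_default_col=True
    pass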

Source code in scdataloader/datamodule.py
def predict_dataloader(self):
    """
    Create a DataLoader for prediction over all training data.

    Uses RankShardSampler for distributed inference.

    Returns:
        DataLoader: Prediction DataLoader instance.
    """
    subset = Subset(self.dataset, self.idx_full)
    return DataLoader(
        self.dataset,
        sampler=RankShardSampler(len(subset), start_at=self.start_at),
        **self.kwargs,
    )

setup

Prepare data splits for training, validation, and testing.

This method shuffles the data, computes sample weights for weighted sampling, removes samples from dropped assays, and creates train/val/test splits. Test splits use entire datasets to ensure evaluation on unseen data sources.

Results can be cached to store_location for faster subsequent runs.

Parameters:
  • stage (str, default: None ) –

    Training stage ('fit', 'test', or None for both). Currently not used but kept for Lightning compatibility. Defaults to None.

Returns:
  • List[str]: List of paths to test datasets.

Note

Must be called before using dataloaders. The train/val/test split is deterministic when loading from cache.
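
As a sketch of the caching behavior, assuming the DataModule was constructed with store_location set: the first call writes the computed arrays (file names match those used in the source below), and later runs with the same store_location reload them unless force_recompute_indices=True.

import os

test_paths = dm.setup()
print(os.path.exists(os.path.join(dm.store_location, "train_labels.npy")))  # True after the first run
print(len(test_paths), "held-out test datasets")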

Source code in scdataloader/datamodule.py
def setup(self, stage=None):
    """
    Prepare data splits for training, validation, and testing.

    This method shuffles the data, computes sample weights for weighted sampling,
    removes samples from dropped assays, and creates train/val/test splits.
    Test splits use entire datasets to ensure evaluation on unseen data sources.

    Results can be cached to `store_location` for faster subsequent runs.

    Args:
        stage (str, optional): Training stage ('fit', 'test', or None for both).
            Currently not used but kept for Lightning compatibility. Defaults to None.

    Returns:
        List[str]: List of paths to test datasets.

    Note:
        Must be called before using dataloaders. The train/val/test split is
        deterministic when loading from cache.
    """
    print("setting up the datamodule")
    start_time = time.time()
    if (
        self.store_location is None
        or not os.path.exists(os.path.join(self.store_location, "train_labels.npy"))
        or self.force_recompute_indices
    ):
        if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
            self.nnz = self.dataset.mapped_dataset.get_merged_labels(
                "nnz", is_cat=False
            )
            self.clss_to_weight.remove("nnz")
            # Sigmoid scaling with 2 parameters
            midpoint = 2000
            steepness = 0.003
            # Apply sigmoid transformation
            # sigmoid(x) = 1 / (1 + exp(-steepness * (x - midpoint)))
            # Then scale to [1, NNZ_SCALE] range
            sigmoid_values = 1 / (1 + np.exp(-steepness * (self.nnz - midpoint)))
            self.nnz = 1 + ((NNZ_SCALE - 1) * sigmoid_values)
        if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
            labels = self.dataset.get_label_cats(
                self.clss_to_weight,
            )
        else:
            labels = np.zeros(self.n_samples, dtype=int)
        if isinstance(self.validation_split, int):
            len_valid = self.validation_split
        else:
            len_valid = int(self.n_samples * self.validation_split)
        if isinstance(self.test_split, int):
            len_test = self.test_split
        else:
            len_test = int(self.n_samples * self.test_split)
        assert (
            len_test + len_valid < self.n_samples
        ), "test set + valid set size is configured to be larger than entire dataset."

        idx_full = []
        if len(self.assays_to_drop) > 0:
            badloc = np.isin(
                self.dataset.mapped_dataset.get_merged_labels(
                    "assay_ontology_term_id"
                ),
                self.assays_to_drop,
            )
            idx_full = np.arange(len(labels))[~badloc]
        else:
            idx_full = np.arange(self.n_samples)
        if len_test > 0:
            # this way we work on some never seen datasets
            # keeping at least one
            len_test = (
                len_test
                if len_test > self.dataset.mapped_dataset.n_obs_list[0]
                else self.dataset.mapped_dataset.n_obs_list[0]
            )
            cs = 0
            d_size = list(enumerate(self.dataset.mapped_dataset.n_obs_list))
            random.Random(42).shuffle(d_size)  # always same order
            for i, c in d_size:
                if cs + c > len_test:
                    break
                else:
                    self.test_datasets.append(
                        self.dataset.mapped_dataset.path_list[i].path
                    )
                    cs += c
            len_test = cs
            self.test_idx = idx_full[:len_test]
            idx_full = idx_full[len_test:]
        else:
            self.test_idx = None

        np.random.shuffle(idx_full)
        if len_valid > 0:
            self.valid_idx = idx_full[:len_valid].copy()
            # store it for later
            idx_full = idx_full[len_valid:]
        else:
            self.valid_idx = None
        labels[~np.isin(np.arange(self.n_samples), idx_full)] = labels.max() + 1
        # some labels will now not exist anymore as replaced by len(weights) - 1.
        # this means that the associated weights should be 0.
        # by doing np.bincount(labels)*weights this will be taken into account
        self.train_labels = labels
        self.idx_full = idx_full
    if self.store_location is not None:
        if (
            not os.path.exists(
                os.path.join(self.store_location, "train_labels.npy")
            )
            or self.force_recompute_indices
        ):
            os.makedirs(self.store_location, exist_ok=True)
            if self.nnz is not None:
                np.save(os.path.join(self.store_location, "nnz.npy"), self.nnz)
            np.save(
                os.path.join(self.store_location, "train_labels.npy"),
                self.train_labels,
            )
            np.save(
                os.path.join(self.store_location, "idx_full.npy"), self.idx_full
            )
            if self.test_idx is not None:
                np.save(
                    os.path.join(self.store_location, "test_idx.npy"), self.test_idx
                )
            if self.valid_idx is not None:
                np.save(
                    os.path.join(self.store_location, "valid_idx.npy"),
                    self.valid_idx,
                )
            listToFile(
                self.test_datasets,
                os.path.join(self.store_location, "test_datasets.txt"),
            )
        else:
            self.nnz = (
                np.load(os.path.join(self.store_location, "nnz.npy"), mmap_mode="r")
                if os.path.exists(os.path.join(self.store_location, "nnz.npy"))
                else None
            )
            self.train_labels = np.load(
                os.path.join(self.store_location, "train_labels.npy")
            )
            self.idx_full = np.load(
                os.path.join(self.store_location, "idx_full.npy"), mmap_mode="r"
            )
            self.test_idx = (
                np.load(os.path.join(self.store_location, "test_idx.npy"))
                if os.path.exists(os.path.join(self.store_location, "test_idx.npy"))
                else None
            )
            self.valid_idx = (
                np.load(os.path.join(self.store_location, "valid_idx.npy"))
                if os.path.exists(
                    os.path.join(self.store_location, "valid_idx.npy")
                )
                else None
            )
            self.test_datasets = fileToList(
                os.path.join(self.store_location, "test_datasets.txt")
            )
            print("loaded from store")
    print(f"done setup, took {time.time() - start_time:.2f} seconds")
    return self.test_datasets

test_dataloader

Create the test DataLoader with sequential sampling.

Returns:
  • DataLoader | List: Test DataLoader, or empty list if no test split.

Source code in scdataloader/datamodule.py
def test_dataloader(self):
    """
    Create the test DataLoader with sequential sampling.

    Returns:
        DataLoader | List: Test DataLoader, or empty list if no test split.
    """
    return (
        DataLoader(
            self.dataset, sampler=SequentialSampler(self.test_idx), **self.kwargs
        )
        if self.test_idx is not None
        else []
    )

train_dataloader

Create the training DataLoader with weighted random sampling.

Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0 and clss_to_weight is specified. Otherwise uses RankShardSampler for distributed training without weighting.

Parameters:
  • **kwargs

    Additional arguments passed to DataLoader, overriding defaults.

Returns:
  • DataLoader

    Training DataLoader instance.

Raises:
  • ValueError

    If setup() has not been called.

Source code in scdataloader/datamodule.py
def train_dataloader(self, **kwargs):
    """
    Create the training DataLoader with weighted random sampling.

    Uses LabelWeightedSampler for class-balanced sampling when weight_scaler > 0
    and clss_to_weight is specified. Otherwise uses RankShardSampler for
    distributed training without weighting.

    Args:
        **kwargs: Additional arguments passed to DataLoader, overriding defaults.

    Returns:
        DataLoader: Training DataLoader instance.

    Raises:
        ValueError: If setup() has not been called.
    """
    if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
        try:
            print("Setting up the parallel train sampler...")
            # Create the optimized parallel sampler
            print(f"Using {self.sampler_workers} workers for class indexing")
            train_sampler = LabelWeightedSampler(
                labels=self.train_labels,
                weight_scaler=self.weight_scaler,
                num_samples=int(self.n_samples_per_epoch),
                element_weights=self.nnz,
                replacement=self.replacement,
                n_workers=self.sampler_workers,
                chunk_size=self.sampler_chunk_size,
                store_location=self.store_location,
                force_recompute_indices=self.force_recompute_indices,
                curiculum=self.curiculum,
            )
        except ValueError as e:
            raise ValueError(str(e) + " Have you run `datamodule.setup()`?")
        dataset = None
    else:
        dataset = Subset(self.dataset, self.idx_full)
        train_sampler = RankShardSampler(len(dataset), start_at=self.start_at)
    current_loader_kwargs = kwargs.copy()
    current_loader_kwargs.update(self.kwargs)
    return DataLoader(
        self.dataset if dataset is None else dataset,
        sampler=train_sampler,
        **current_loader_kwargs,
    )
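
Per-call DataLoader options can be supplied directly; in the merge shown above, any option already fixed when the DataModule was constructed keeps its constructed value, so only new keys take effect here (persistent_workers is just a placeholder option).

train_loader = dm.train_dataloader(persistent_workers=True)
batch = next(iter(train_loader))  # fetch a single batch as a smoke test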

val_dataloader

Create the validation DataLoader.

Returns:
  • DataLoader | List: Validation DataLoader, or empty list if no validation split.

Source code in scdataloader/datamodule.py
def val_dataloader(self):
    """
    Create the validation DataLoader.

    Returns:
        DataLoader | List: Validation DataLoader, or empty list if no validation split.
    """
    return (
        DataLoader(
            Subset(self.dataset, self.valid_idx),
            **self.kwargs,
        )
        if self.valid_idx is not None
        else []
    )