Documentation for DataModule

scdataloader.datamodule.DataModule

Bases: LightningDataModule

DataModule is a PyTorch Lightning datamodule built directly from a lamindb Collection. It can also be used with bare PyTorch.

It implements train / val / test dataloaders: the train loader uses weighted random sampling, the validation loader uses random sampling, and the test set is built from one or more held-out datasets. This is where the MappedCollection, the Dataset, and the Collator are combined to create the dataloaders.
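
To make that composition concrete, the sketch below shows roughly what the same wiring looks like with bare PyTorch. It is a minimal sketch, not the class's exact behaviour: the import paths for Dataset and Collator are assumed from the docstring's pointer to data.py and collator.py, and the collection name is hypothetical.

import lamindb as ln
from torch.utils.data import DataLoader

# Assumed import paths, based on the docstring's pointer to data.py and collator.py.
from scdataloader.data import Dataset
from scdataloader.collator import Collator

# "my-collection" is a hypothetical lamindb Collection of AnnData artifacts.
mdataset = Dataset(
    ln.Collection.filter(name="my-collection").first(),
    organisms=["NCBITaxon:9606"],
    clss_to_predict=["organism_ontology_term_id"],
    hierarchical_clss=[],
    metacell_mode=0.0,
    get_knn_cells=False,
)
collator = Collator(
    organisms=["NCBITaxon:9606"],
    how="random expr",
    valid_genes=mdataset.genedf.index.tolist(),
    max_len=1000,
    add_zero_genes=100,
    org_to_id=mdataset.encoder["organism_ontology_term_id"],
    tp_name=None,
    organism_name="organism_ontology_term_id",
    class_names=["organism_ontology_term_id"],
)
loader = DataLoader(mdataset, collate_fn=collator, batch_size=64, num_workers=4)

DataModule performs this wiring internally and additionally handles the train/val/test splitting and the weighted sampling described by the parameters below.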

Parameters:
  • collection_name (str) –

    The name of the lamindb Collection to be used.

  • organisms (list, default: ['NCBITaxon:9606'] ) –

    The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].

  • weight_scaler (int, default: 10 ) –

    How many times more often the most frequent category will be seen compared to the least frequent one. Defaults to 10.

  • train_oversampling_per_epoch (float, default: 0.1 ) –

    The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.

  • validation_split (float, default: 0.2 ) –

    The proportion of the dataset to include in the validation split. Defaults to 0.2.

  • test_split (float, default: 0 ) –

    The proportion of the dataset to include in the test split. Defaults to 0. The test set is built from one or more full datasets, so the split is rounded to the nearest dataset's cell count.

  • gene_embeddings (str, default: '' ) –

    The path to the gene embeddings file. Defaults to "". The file must have ensembl_gene_id as its index. This is used to further subset the available genes to the ones that have embeddings in your model.

  • use_default_col (bool, default: True ) –

    Whether to use the default collator. Defaults to True.

  • gene_position_tolerance (int, default: 10000 ) –

    The tolerance for gene positions. Defaults to 10_000. Any genes within this distance of each other are considered to be at the same position.

  • clss_to_weight (list, default: ['organism_ontology_term_id'] ) –

    List of labels to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].

  • assays_to_drop (list, default: ['EFO:0030007'] ) –

    List of assays to drop from the dataset. Defaults to ["EFO:0030007"] (ATAC-seq).

  • do_gene_pos (Union[bool, str], default: True ) –

    Whether to use gene positions. If a string is passed, it is used as the path to a biomart parquet file of gene positions. Defaults to True.

  • max_len (int, default: 1000 ) –

    The maximum length of the input tensor. Defaults to 1000.

  • add_zero_genes (int, default: 100 ) –

    The number of zero genes to add to the input tensor. Defaults to 100.

  • how (str, default: 'random expr' ) –

    The method to use for the collator. Defaults to "random expr".

  • organism_name (str, default: 'organism_ontology_term_id' ) –

    The name of the field containing the organism information. Defaults to "organism_ontology_term_id".

  • tp_name (Optional[str], default: None ) –

    The name of the timepoint field (e.g. "heat_diff"). Defaults to None.

  • hierarchical_clss (list, default: [] ) –

    List of hierarchical classes. Defaults to [].

  • metacell_mode (float, default: 0.0 ) –

    The probability of using metacell mode. Defaults to 0.0.

  • clss_to_predict (list, default: ['organism_ontology_term_id'] ) –

    List of classes to predict. Defaults to ["organism_ontology_term_id"].

  • modify_seed_on_requeue (bool, default: True ) –

    Whether to modify the seed on requeue. Defaults to True.

  • get_knn_cells (bool, default: False ) –

    Whether to also retrieve the k-nearest neighbors of each queried cell. Defaults to False.

  • **kwargs

    Additional keyword arguments passed to the PyTorch DataLoader (e.g. batch_size, num_workers).
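
For illustration, a minimal, hedged instantiation could look like the following. The collection name and the cell_type_ontology_term_id label are assumptions about your own lamindb instance and obs columns; batch_size and num_workers are simply forwarded to the PyTorch DataLoader through **kwargs.

from scdataloader.datamodule import DataModule

datamodule = DataModule(
    collection_name="my-collection",                  # hypothetical lamindb Collection name
    organisms=["NCBITaxon:9606"],                     # human only (the default)
    clss_to_weight=["cell_type_ontology_term_id"],    # assumes this obs column exists
    clss_to_predict=["cell_type_ontology_term_id"],   # labels returned alongside expression
    weight_scaler=10,
    validation_split=0.2,
    test_split=0.0,
    max_len=2000,                                     # genes kept per cell by the collator
    batch_size=64,                                    # forwarded to the PyTorch DataLoader
    num_workers=4,                                    # forwarded to the PyTorch DataLoader
)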

Methods:

  • setup –

    Prepares the data for the training, validation, and test sets.

Attributes:
  • decoders

    The decoders for any labels that have been encoded.

  • genes

    The genes used in this datamodule.

  • labels_hierarchy

    The hierarchy of labels for any class that has a hierarchy.

Source code in scdataloader/datamodule.py
def __init__(
    self,
    collection_name: str,
    clss_to_weight: list = ["organism_ontology_term_id"],
    organisms: list = ["NCBITaxon:9606"],
    weight_scaler: int = 10,
    train_oversampling_per_epoch: float = 0.1,
    validation_split: float = 0.2,
    test_split: float = 0,
    gene_embeddings: str = "",
    use_default_col: bool = True,
    gene_position_tolerance: int = 10_000,
    # this is for the mappedCollection
    clss_to_predict: list = ["organism_ontology_term_id"],
    hierarchical_clss: list = [],
    # this is for the collator
    how: str = "random expr",
    organism_name: str = "organism_ontology_term_id",
    max_len: int = 1000,
    add_zero_genes: int = 100,
    replacement: bool = True,
    do_gene_pos: Union[bool, str] = True,
    tp_name: Optional[str] = None,  # "heat_diff"
    assays_to_drop: list = [
        # "EFO:0008853", #patch seq
        # "EFO:0010961", # visium
        "EFO:0030007",  # ATACseq
        # "EFO:0030062", # slide-seq
    ],
    metacell_mode: float = 0.0,
    get_knn_cells: bool = False,
    modify_seed_on_requeue: bool = True,
    **kwargs,
):
    """
    DataModule is a PyTorch Lightning datamodule built directly from a lamindb Collection.
    It can also be used with bare PyTorch.

    It implements train / val / test dataloaders: the train loader uses weighted random sampling,
    the validation loader uses random sampling, and the test set is built from one or more held-out datasets.
    This is where the MappedCollection, the Dataset, and the Collator are combined to create the dataloaders.

    Args:
        collection_name (str): The lamindb collection to be used.
        organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
        weight_scaler (int, optional): how much more you will see the most present vs less present category.
        train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
        validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
        test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
            it will use a full dataset and will round to the nearest dataset's cell count.
        gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
            the file must have ensembl_gene_id as index.
            This is used to subset the available genes further to the ones that have embeddings in your model.
        use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
        gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
            any genes within this distance of each other will be considered at the same position.
        clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to ["organism_ontology_term_id"].
        assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to ["EFO:0030007"] (ATAC-seq).
        do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
        max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
        add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
        how (str, optional): The method to use for the collator. Defaults to "random expr".
        organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
        tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
        hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
        metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
        clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
        modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
        get_knn_cells (bool, optional): Whether to also retrieve the k-nearest neighbors of each queried cell. Defaults to False.
        **kwargs: Additional keyword arguments passed to the pytorch DataLoader.
        see @file data.py and @file collator.py for more details about some of the parameters
    """
    if collection_name is not None:
        mdataset = Dataset(
            ln.Collection.filter(name=collection_name).first(),
            organisms=organisms,
            clss_to_predict=clss_to_predict,
            hierarchical_clss=hierarchical_clss,
            metacell_mode=metacell_mode,
            get_knn_cells=get_knn_cells,
        )
    # and location
    self.metacell_mode = bool(metacell_mode)
    self.gene_pos = None
    self.collection_name = collection_name
    if do_gene_pos:
        if type(do_gene_pos) is str:
            print("seeing a string: loading gene positions as biomart parquet file")
            biomart = pd.read_parquet(do_gene_pos)
        else:
            # and annotations
            if organisms != ["NCBITaxon:9606"]:
                raise ValueError(
                    "need to provide your own table as this automated function only works for humans for now"
                )
            biomart = getBiomartTable(
                attributes=["start_position", "chromosome_name"],
                useCache=True,
            ).set_index("ensembl_gene_id")
            biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
            biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
            c = []
            i = 0
            prev_position = -100000
            prev_chromosome = None
            for _, r in biomart.iterrows():
                if (
                    r["chromosome_name"] != prev_chromosome
                    or r["start_position"] - prev_position > gene_position_tolerance
                ):
                    i += 1
                c.append(i)
                prev_position = r["start_position"]
                prev_chromosome = r["chromosome_name"]
            print(f"reduced the size to {len(set(c)) / len(biomart)}")
            biomart["pos"] = c
        mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
        self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()

    if gene_embeddings != "":
        mdataset.genedf = mdataset.genedf.join(
            pd.read_parquet(gene_embeddings), how="inner"
        )
        if do_gene_pos:
            self.gene_pos = mdataset.genedf["pos"].tolist()
    self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
    # we might want not to order the genes by expression (or do it?)
    # we might want to not introduce zeros and
    if use_default_col:
        kwargs["collate_fn"] = Collator(
            organisms=organisms,
            how=how,
            valid_genes=mdataset.genedf.index.tolist(),
            max_len=max_len,
            add_zero_genes=add_zero_genes,
            org_to_id=mdataset.encoder[organism_name],
            tp_name=tp_name,
            organism_name=organism_name,
            class_names=clss_to_predict,
        )
    self.validation_split = validation_split
    self.test_split = test_split
    self.dataset = mdataset
    self.replacement = replacement
    self.kwargs = kwargs
    if "sampler" in self.kwargs:
        self.kwargs.pop("sampler")
    self.assays_to_drop = assays_to_drop
    self.n_samples = len(mdataset)
    self.weight_scaler = weight_scaler
    self.train_oversampling_per_epoch = train_oversampling_per_epoch
    self.clss_to_weight = clss_to_weight
    self.train_weights = None
    self.train_labels = None
    self.modify_seed_on_requeue = modify_seed_on_requeue
    self.nnz = None
    self.restart_num = 0
    self.test_datasets = []
    self.test_idx = []
    super().__init__()

decoders property

The decoders for any labels that have been encoded.

Returns:
  • dict[str, dict[int, str]]

genes property

The genes used in this datamodule.

Returns:
  • list

labels_hierarchy property

The hierarchy of labels for any class that has a hierarchy.

Returns:
  • dict[str, dict[str, str]]
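
A small sketch of how these read-only properties can be inspected, assuming datamodule was built as in the example above:

print(len(datamodule.genes))              # genes used by this datamodule
print(list(datamodule.decoders))          # one decoder per label that was encoded
print(list(datamodule.labels_hierarchy))  # labels that come with an ontology hierarchy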

setup

The setup method prepares the data for the training, validation, and test sets. It shuffles the data, calculates sample weights, and creates the samplers for each set.

Parameters:
  • stage (str, default: None ) –

    The stage of the model training process. It can be either 'fit' or 'test'. Defaults to None.

Source code in scdataloader/datamodule.py
def setup(self, stage=None):
    """
    setup method is used to prepare the data for the training, validation, and test sets.
    It shuffles the data, calculates weights for each set, and creates samplers for each set.

    Args:
        stage (str, optional): The stage of the model training process.
        It can be either 'fit' or 'test'. Defaults to None.
    """
    SCALE = 10
    if "nnz" in self.clss_to_weight and self.weight_scaler > 0:
        self.nnz = self.dataset.mapped_dataset.get_merged_labels("nnz")
        self.clss_to_weight.remove("nnz")
        (
            (self.nnz.max() / SCALE)
            / ((1 + self.nnz - self.nnz.min()) + (self.nnz.max() / SCALE))
        ).min()
    if len(self.clss_to_weight) > 0 and self.weight_scaler > 0:
        weights, labels = self.dataset.get_label_weights(
            self.clss_to_weight,
            scaler=self.weight_scaler,
            return_categories=True,
        )
    else:
        weights = np.ones(1)
        labels = np.zeros(self.n_samples, dtype=int)
    if isinstance(self.validation_split, int):
        len_valid = self.validation_split
    else:
        len_valid = int(self.n_samples * self.validation_split)
    if isinstance(self.test_split, int):
        len_test = self.test_split
    else:
        len_test = int(self.n_samples * self.test_split)
    assert (
        len_test + len_valid < self.n_samples
    ), "test set + valid set size is configured to be larger than entire dataset."

    idx_full = []
    if len(self.assays_to_drop) > 0:
        badloc = np.isin(
            self.dataset.mapped_dataset.get_merged_labels("assay_ontology_term_id"),
            self.assays_to_drop,
        )
        idx_full = np.arange(len(labels))[~badloc]
    else:
        idx_full = np.arange(self.n_samples)
    if len_test > 0:
        # this way we work on some never seen datasets
        # keeping at least one
        len_test = (
            len_test
            if len_test > self.dataset.mapped_dataset.n_obs_list[0]
            else self.dataset.mapped_dataset.n_obs_list[0]
        )
        cs = 0
        for i, c in enumerate(self.dataset.mapped_dataset.n_obs_list):
            if cs + c > len_test:
                break
            else:
                self.test_datasets.append(
                    self.dataset.mapped_dataset.path_list[i].path
                )
                cs += c
        len_test = cs
        self.test_idx = idx_full[:len_test]
        idx_full = idx_full[len_test:]
    else:
        self.test_idx = None

    np.random.shuffle(idx_full)
    if len_valid > 0:
        self.valid_idx = idx_full[:len_valid].copy()
        # store it for later
        idx_full = idx_full[len_valid:]
    else:
        self.valid_idx = None
    weights = np.concatenate([weights, np.zeros(1)])
    labels[~np.isin(np.arange(self.n_samples), idx_full)] = len(weights) - 1
    # some labels will now not exist anymore as replaced by len(weights) - 1.
    # this means that the associated weights should be 0.
    # by doing np.bincount(labels)*weights this will be taken into account
    self.train_weights = weights
    self.train_labels = labels
    self.idx_full = idx_full
    return self.test_datasets
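
Outside a Lightning Trainer (which calls setup for you), a minimal sketch of driving the datamodule by hand could look like this; train_dataloader is the standard LightningDataModule hook that the introduction refers to.

datamodule.setup()                    # splits indices, computes weights, selects the test datasets

print(datamodule.classes)             # number of categories for each class to predict
print(datamodule.test_datasets)       # paths of the held-out test datasets (empty if test_split == 0)

# The train loader is built from datamodule.train_weights / datamodule.train_labels
# (weighted random sampling); the validation loader samples randomly.
for batch in datamodule.train_dataloader():
    # batch structure depends on the collator's `how`, `max_len` and `clss_to_predict` settings
    break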