25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183 | def __init__(
self,
collection_name: str,
clss_to_weight: list = ["organism_ontology_term_id"],
organisms: list = ["NCBITaxon:9606"],
weight_scaler: int = 10,
train_oversampling_per_epoch: float = 0.1,
validation_split: float = 0.2,
test_split: float = 0,
gene_embeddings: str = "",
use_default_col: bool = True,
gene_position_tolerance: int = 10_000,
# this is for the mappedCollection
clss_to_predict: list = ["organism_ontology_term_id"],
hierarchical_clss: list = [],
# this is for the collator
how: str = "random expr",
organism_name: str = "organism_ontology_term_id",
max_len: int = 1000,
add_zero_genes: int = 100,
replacement: bool = True,
do_gene_pos: Union[bool, str] = True,
tp_name: Optional[str] = None, # "heat_diff"
assays_to_drop: list = [
# "EFO:0008853", #patch seq
# "EFO:0010961", # visium
"EFO:0030007", # ATACseq
# "EFO:0030062", # slide-seq
],
metacell_mode: float = 0.0,
get_knn_cells: bool = False,
modify_seed_on_requeue: bool = True,
**kwargs,
):
"""
DataModule a pytorch lighting datamodule directly from a lamin Collection.
it can work with bare pytorch too
It implements train / val / test dataloaders. the train is weighted random, val is random, test is one to many separated datasets.
This is where the mappedCollection, dataset, and collator are combined to create the dataloaders.
Args:
collection_name (str): The lamindb collection to be used.
organisms (list, optional): The organisms to include in the dataset. Defaults to ["NCBITaxon:9606"].
weight_scaler (int, optional): how much more you will see the most present vs less present category.
train_oversampling_per_epoch (float, optional): The proportion of the dataset to include in the training set for each epoch. Defaults to 0.1.
validation_split (float, optional): The proportion of the dataset to include in the validation split. Defaults to 0.2.
test_split (float, optional): The proportion of the dataset to include in the test split. Defaults to 0.
it will use a full dataset and will round to the nearest dataset's cell count.
gene_embeddings (str, optional): The path to the gene embeddings file. Defaults to "".
the file must have ensembl_gene_id as index.
This is used to subset the available genes further to the ones that have embeddings in your model.
use_default_col (bool, optional): Whether to use the default collator. Defaults to True.
gene_position_tolerance (int, optional): The tolerance for gene position. Defaults to 10_000.
any genes within this distance of each other will be considered at the same position.
clss_to_weight (list, optional): List of labels to weight in the trainer's weighted random sampler. Defaults to [].
assays_to_drop (list, optional): List of assays to drop from the dataset. Defaults to [].
do_gene_pos (Union[bool, str], optional): Whether to use gene positions. Defaults to True.
max_len (int, optional): The maximum length of the input tensor. Defaults to 1000.
add_zero_genes (int, optional): The number of zero genes to add to the input tensor. Defaults to 100.
how (str, optional): The method to use for the collator. Defaults to "random expr".
organism_name (str, optional): The name of the organism. Defaults to "organism_ontology_term_id".
tp_name (Optional[str], optional): The name of the timepoint. Defaults to None.
hierarchical_clss (list, optional): List of hierarchical classes. Defaults to [].
metacell_mode (float, optional): The probability of using metacell mode. Defaults to 0.0.
clss_to_predict (list, optional): List of classes to predict. Defaults to ["organism_ontology_term_id"].
modify_seed_on_requeue (bool, optional): Whether to modify the seed on requeue. Defaults to True.
get_knn_cells (bool, optional): Whether to get the k-nearest neighbors of each queried cells. Defaults to False.
**kwargs: Additional keyword arguments passed to the pytorch DataLoader.
see @file data.py and @file collator.py for more details about some of the parameters
"""
if collection_name is not None:
mdataset = Dataset(
ln.Collection.filter(name=collection_name).first(),
organisms=organisms,
clss_to_predict=clss_to_predict,
hierarchical_clss=hierarchical_clss,
metacell_mode=metacell_mode,
get_knn_cells=get_knn_cells,
)
# and location
self.metacell_mode = bool(metacell_mode)
self.gene_pos = None
self.collection_name = collection_name
if do_gene_pos:
if type(do_gene_pos) is str:
print("seeing a string: loading gene positions as biomart parquet file")
biomart = pd.read_parquet(do_gene_pos)
else:
# and annotations
if organisms != ["NCBITaxon:9606"]:
raise ValueError(
"need to provide your own table as this automated function only works for humans for now"
)
biomart = getBiomartTable(
attributes=["start_position", "chromosome_name"],
useCache=True,
).set_index("ensembl_gene_id")
biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
c = []
i = 0
prev_position = -100000
prev_chromosome = None
for _, r in biomart.iterrows():
if (
r["chromosome_name"] != prev_chromosome
or r["start_position"] - prev_position > gene_position_tolerance
):
i += 1
c.append(i)
prev_position = r["start_position"]
prev_chromosome = r["chromosome_name"]
print(f"reduced the size to {len(set(c)) / len(biomart)}")
biomart["pos"] = c
mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()
if gene_embeddings != "":
mdataset.genedf = mdataset.genedf.join(
pd.read_parquet(gene_embeddings), how="inner"
)
if do_gene_pos:
self.gene_pos = mdataset.genedf["pos"].tolist()
self.classes = {k: len(v) for k, v in mdataset.class_topred.items()}
# we might want not to order the genes by expression (or do it?)
# we might want to not introduce zeros and
if use_default_col:
kwargs["collate_fn"] = Collator(
organisms=organisms,
how=how,
valid_genes=mdataset.genedf.index.tolist(),
max_len=max_len,
add_zero_genes=add_zero_genes,
org_to_id=mdataset.encoder[organism_name],
tp_name=tp_name,
organism_name=organism_name,
class_names=clss_to_predict,
)
self.validation_split = validation_split
self.test_split = test_split
self.dataset = mdataset
self.replacement = replacement
self.kwargs = kwargs
if "sampler" in self.kwargs:
self.kwargs.pop("sampler")
self.assays_to_drop = assays_to_drop
self.n_samples = len(mdataset)
self.weight_scaler = weight_scaler
self.train_oversampling_per_epoch = train_oversampling_per_epoch
self.clss_to_weight = clss_to_weight
self.train_weights = None
self.train_labels = None
self.modify_seed_on_requeue = modify_seed_on_requeue
self.nnz = None
self.restart_num = 0
self.test_datasets = []
self.test_idx = []
super().__init__()
|