def __init__(
    self,
    organisms: list[str],
    how: str = "all",
    org_to_id: Optional[dict[str, int]] = None,
    valid_genes: Optional[list[str]] = None,
    max_len: int = 2000,
    add_zero_genes: int = 0,
    logp1: bool = False,
    norm_to: Optional[float] = None,
    n_bins: int = 0,
    tp_name: Optional[str] = None,
    organism_name: str = "organism_ontology_term_id",
    class_names: Optional[list[str]] = None,
    genelist: Optional[list[str]] = None,
    downsample: Optional[float] = None,  # don't use it for training!
    save_output: Optional[str] = None,
):
    """
    Collate gene expression data for the scPRINT model.

    Handles the organization and preparation of gene expression data from
    different organisms, allowing for various configurations such as maximum
    gene list length, normalization, and the selection method for gene
    expression. This Collator should work with scVI's dataloader as well!

    Args:
        organisms (list): List of organisms to be considered for gene expression data.
            It will drop any other organism it sees (might lead to batches of different sizes!).
        how (str, optional): Method for selecting gene expression. Defaults to "all".
            One of ["most expr", "random expr", "all", "some"]:
            "most expr": selects the max_len most expressed genes;
                if fewer genes are expressed, random unexpressed genes are sampled.
            "random expr": uses a random set of max_len expressed genes;
                if fewer genes are expressed, random unexpressed genes are sampled.
            "all": uses all genes.
            "some": uses only the genes provided through the genelist param.
        org_to_id (dict, optional): Dictionary mapping organisms to their respective IDs.
            Defaults to None.
        valid_genes (list, optional): List of genes from the datasets to be considered.
            Defaults to None (treated as []). Any other genes are dropped from the input
            expression data (useful when your model only works on some genes).
        max_len (int, optional): Total number of genes to use (for "random expr" and
            "most expr"). Defaults to 2000.
        add_zero_genes (int, optional): Number of additional unexpressed genes to add to
            the input data. Defaults to 0.
        logp1 (bool, optional): If True, logp1 normalization is applied. Defaults to False.
        norm_to (float, optional): Rescaling value of the normalization to be applied.
            Defaults to None.
        n_bins (int, optional): Number of bins for binning the data. Defaults to 0,
            meaning no binning of expression.
        tp_name (str, optional): Name of the heat diff. Defaults to None.
        organism_name (str, optional): Name of the organism ontology term id.
            Defaults to "organism_ontology_term_id".
        class_names (list, optional): List of other classes to be considered.
            Defaults to None (treated as []).
        genelist (list, optional): List of genes to be considered. Defaults to None
            (treated as []). If empty, all genes will be considered. Required (non-empty)
            when ``how == "some"``.
        downsample (float, optional): Downsample the profile to a certain number of cells.
            Defaults to None. This is usually done by the scPRINT model during training,
            but this option allows you to do it directly from the collator.
        save_output (str, optional): If not None, saves the output to a file.
            Defaults to None. This is mainly for debugging purposes.

    Raises:
        ValueError: if ``how == "some"`` and ``genelist`` is empty or not given.
    """
    # Give each instance its own lists: a literal ``[]`` default is built once,
    # when the function is defined, and would be shared by every collator
    # created without these arguments (class_names is stored on ``self``).
    valid_genes = [] if valid_genes is None else valid_genes
    class_names = [] if class_names is None else class_names
    genelist = [] if genelist is None else genelist

    # Validate before load_genes() so a misconfiguration fails immediately.
    # An explicit raise (rather than ``assert``) keeps the check under ``python -O``.
    # len() instead of truthiness so array-like genelists keep working.
    if how == "some" and len(genelist) == 0:
        raise ValueError("if how is some, genelist must be provided")

    self.organisms = organisms
    self.genedf = load_genes(organisms)
    self.max_len = max_len
    self.n_bins = n_bins
    self.add_zero_genes = add_zero_genes
    self.logp1 = logp1
    self.norm_to = norm_to
    self.how = how
    self.organism_name = organism_name
    self.tp_name = tp_name
    self.class_names = class_names
    self.save_output = save_output
    # Per-organism lookup tables; populated by _setup() (defined elsewhere).
    self.start_idx = {}
    self.accepted_genes = {}
    self.downsample = downsample
    self.to_subset = {}
    self._setup(org_to_id, valid_genes, genelist)