scPRINT use case on BPH
In this use case, also presented in Figure 5 of our manuscript, we perform an extensive analysis of a multi-study dataset of benign prostatic hyperplasia (BPH).
Our biological question is whether there exist pre-cancerous cells that already exhibit behaviors of mature cancer cells at this early stage of the disease.
In those cells, we want to know which genes might be implicated in cell state changes, and to explore potentially novel targets for the treatment of prostate cancer and BPH.
We will start with a fresh dataset coming from the cellXgene database and representing 2 studies of BPH.
We will first explore this dataset to understand:
- what are the cell types that are present in the data
- what are the cell distributions
- what sequencers were used, etc.
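The three questions above come down to tabulating a few per-cell metadata columns. Here is a minimal sketch that uses toy records in place of `prostate_adata.obs`. The column names `cell_type`, `disease` and `assay` are assumptions based on the usual cellxgene metadata, and all values are invented for illustration.

```python
from collections import Counter

# Toy stand-in for per-cell metadata (hypothetical values, not from the dataset).
cells = [
    {"cell_type": "basal cell", "disease": "normal", "assay": "10x 3' v3"},
    {"cell_type": "basal cell", "disease": "benign prostatic hyperplasia", "assay": "10x 3' v3"},
    {"cell_type": "luminal cell", "disease": "benign prostatic hyperplasia", "assay": "10x 3' v2"},
    {"cell_type": "fibroblast", "disease": "normal", "assay": "10x 3' v2"},
]


def summarize(records, column):
    """Count how many cells carry each value of one metadata column."""
    return Counter(record[column] for record in records)


def crosstab(records, row, col):
    """Count cells for each (row value, column value) pair."""
    return Counter((record[row], record[col]) for record in records)


cell_type_counts = summarize(cells, "cell_type")
assay_counts = summarize(cells, "assay")
type_by_disease = crosstab(cells, "cell_type", "disease")

# Most frequent cell types first.
for name, count in cell_type_counts.most_common():
    print(f"{name}: {count}")
```

On the real object, the pandas equivalent would likely be `prostate_adata.obs["cell_type"].value_counts()`, provided that column exists in the downloaded dataset.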
We also want to confirm existing targets in prostate cancer through the analysis of precancerous lesions, and to find potentially novel ones that could serve as less invasive BPH treatments than the current ones.
Finally, we want to know how these targets interact and which biological pathways they are involved in.
We now showcase how to use scPRINT across its different functionalities to answer some of these questions.
Table of contents:
- Downloading and preprocessing
- Embedding and annotations
- Annotation cleanup
- Clustering and differential expression
- Denoising and differential expression
- Gene network inference
In the notebook cancer_usecase_part2.ipynb, you will see how to analyse cell-type-specific gene regulatory networks.
from scprint import scPrint
from scdataloader import Preprocessor, utils
from scprint.tasks import GNInfer, Embedder, Denoiser, withknn
from bengrn import BenGRN, get_sroy_gt, compute_genie3
from bengrn.base import train_classifier
from grnndata import utils as grnutils
from grnndata import read_h5ad
from anndata.utils import make_index_unique
from anndata import concat
import scanpy as sc
from matplotlib import pyplot as plt
import numpy as np
import lamindb as ln
%load_ext autoreload
%autoreload 2
import torch
torch.set_float32_matmul_precision('medium')
💡 connected lamindb: jkobject/scprint
Downloading and preprocessing
We now use lamindb to easily access cellxgene and download a dataset of normal and benign prostatic hyperplasia tissues.
The data is available at https://cellxgene.cziscience.com/e/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.cxg/ .
We then use scDataloader's preprocessing method. This method is quite extensive and performs several steps; find out more about it in its documentation.
On our end, we use the preprocessor to make sure that the gene expression values we have are raw counts and that we have enough information to use scPRINT (i.e., enough genes expressed and enough counts per cell across the dataset).
Finally, the preprocessor also expands the expression matrix to a fixed set of genes defined by the latest version of Ensembl.
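To make the three roles of the preprocessor described above more concrete, here is a minimal pure-Python sketch on a toy matrix. It is not scDataloader's implementation: the thresholds, the placeholder gene identifiers and the helper names are all hypothetical and only illustrate the idea of checking for raw counts, filtering low-information cells, and expanding onto a fixed reference gene list.

```python
# Hypothetical thresholds and reference gene list, for illustration only.
MIN_COUNTS_PER_CELL = 5
MIN_GENES_PER_CELL = 2
REFERENCE_GENES = ["ENSG_A", "ENSG_B", "ENSG_C", "ENSG_D"]


def is_raw_counts(matrix):
    """Raw counts are non-negative and integer-valued."""
    return all(v >= 0 and float(v).is_integer() for row in matrix for v in row)


def keep_cell(row):
    """Keep a cell only if it has enough total counts and enough expressed genes."""
    enough_counts = sum(row) >= MIN_COUNTS_PER_CELL
    enough_genes = sum(v > 0 for v in row) >= MIN_GENES_PER_CELL
    return enough_counts and enough_genes


def expand_to_reference(row, genes, reference):
    """Reorder a cell onto the reference gene list, filling missing genes with 0."""
    lookup = dict(zip(genes, row))
    return [lookup.get(gene, 0) for gene in reference]


genes = ["ENSG_C", "ENSG_A"]  # genes measured in the toy dataset
counts = [
    [3, 4],  # passes both filters
    [1, 0],  # too few counts and too few expressed genes
    [0, 9],  # enough counts, but only one expressed gene
]

assert is_raw_counts(counts)
filtered = [row for row in counts if keep_cell(row)]
expanded = [expand_to_reference(row, genes, REFERENCE_GENES) for row in filtered]
print(expanded)
```

Only the first toy cell survives the filters, and it ends up with one column per reference gene, with zeros for the genes that were not measured.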
cx_dataset = ln.Collection.using(instance="laminlabs/cellxgene").filter(name="cellxgene-census", version='2023-12-15').one()
prostate_adata = cx_dataset.artifacts.filter(key='cell-census/2023-12-15/h5ads/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.h5ad').one().load()
sc.pl.umap(prostate_adata)
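The artifact key used in the download cell embeds the same dataset id as the cellxgene explorer URL given earlier. The sketch below derives one from the other. The key pattern is inferred from this single example, so treat it as an assumption for any other dataset or census version; `dataset_id_from_url` and `census_key` are hypothetical helpers, not part of lamindb.

```python
from pathlib import PurePosixPath
from urllib.parse import urlparse

CENSUS_VERSION = "2023-12-15"  # census version used in this notebook


def dataset_id_from_url(url):
    """Extract the dataset id from an explorer URL ending in '<id>.cxg/'."""
    last_part = PurePosixPath(urlparse(url).path).name  # '<id>.cxg'
    if not last_part.endswith(".cxg"):
        raise ValueError(f"unexpected explorer URL: {url}")
    return last_part[: -len(".cxg")]


def census_key(dataset_id, version=CENSUS_VERSION):
    """Build an artifact key following the pattern seen in this notebook (assumed)."""
    return f"cell-census/{version}/h5ads/{dataset_id}.h5ad"


url = "https://cellxgene.cziscience.com/e/574e9f9e-f8b4-41ef-bf19-89a9964fd9c7.cxg/"
key = census_key(dataset_id_from_url(url))
print(key)
```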
# preprocessing using scDataloader
prostate_adata.obs.drop(columns="is_primary_data", inplace=True)
preprocessor = Preprocessor(do_postp=False)
prostate_adata = preprocessor(prostate_adata)
Dropping layers: KeysView(Layers with keys: )
checking raw counts
removed 0 non primary cells, 83451 renamining
filtered out 0 cells, 83451 renamining
Removed 76 genes.
validating
startin QC Seeing 13047 outliers (15.63% of total dataset): done
Embedding and annotations
We now load a large version of scPRINT from a specific checkpoint. Please download the checkpoints following the instructions in the README.
We will then use our Embedder class to embed the data and annotate the cells. Classes such as Embedder are how we parametrize and access the different functions of scPRINT. Find out more about their parameters in our documentation.
model = scPrint.load_from_checkpoint(
    '../data/temp/o2uniqsx/epoch=18-step=133000.ckpt',
    # make sure to check whether you have a GPU with flashattention or not (see README)
    # precpt_gene_emb=None is needed; otherwise the model will look for the file
    # with the embeddings instead of getting them from the weights
    precpt_gene_emb=None,
)
RuntimeError caught: scPrint is not attached to a `Trainer`.
embedder = Embedder(
    # can work on random genes or most variable genes, etc.
    how="random expr",
    # number of genes to use
    max_len=4000,
    add_zero_genes=0,
    # for the dataloading
    num_workers=8,
    # we will only use the cell type embedding here.
    pred_embedding=["cell_type_ontology_term_id"],  # , "disease_ontology_term_id"]
)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
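To give some intuition for `how="random expr"` together with `max_len`, the sketch below samples at most `max_len` genes at random among those expressed in a cell. This is an assumption about what the option name suggests, not scPRINT's actual implementation, and `sample_expressed_genes` is a hypothetical helper written only for this illustration.

```python
import random


def sample_expressed_genes(expression, max_len, seed=0):
    """Pick up to max_len gene indices at random among genes with non-zero expression.

    Hypothetical illustration of a "random expr" style selection;
    not taken from the scPRINT code base.
    """
    expressed = [i for i, value in enumerate(expression) if value > 0]
    if len(expressed) <= max_len:
        # Fewer expressed genes than the budget: keep them all.
        return sorted(expressed)
    rng = random.Random(seed)  # local generator, so the draw is reproducible
    return sorted(rng.sample(expressed, max_len))


cell = [0, 5, 0, 2, 1, 0, 7, 3]  # toy expression vector for one cell
selected = sample_expressed_genes(cell, max_len=3)
print(selected)
```

With a budget larger than the number of expressed genes, the helper simply returns every expressed gene, which is why a generous `max_len` mostly matters for cells with many detected genes.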
# create the embedding
prostate_adata, metrics = embedder(model, prostate_adata, cache=False, output_expression="none")
100%|██████████| 1560/1560 [28:20<00:00, 1.09s/it]
AnnData object with n_obs × n_vars = 99826 × 424
obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_pred_disease_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'conv_pred_self_reported_ethnicity_ontology_term_id'