Documentation for base module

bengrn.base

bengrn base module.

Classes:

Name Description
BenGRN

Functions:

Name Description
compute_epr

compute_epr computes the Expected Precision Recall (EPR) metric for the given classifier, test data, and true labels.

compute_genie3

This function computes the GENIE3 algorithm on the given data.

compute_pr

compute_pr computes the precision and recall metrics for the given GRN and true matrix.

download_perturb_gt

download_perturb_gt downloads the genome wide perturb seq ground truth data.

get_GT_db

get_GT_db loads a prior GRN from a list of available networks.

get_perturb_gt

get_perturb_gt retrieves the genome wide perturb seq ground truth data.

get_scenicplus

This function retrieves a SCENIC+ loom file from a given file path and loads it as a GRNAnnData object.

get_sroy_gt

This function retrieves the ground truth data from McCall et al.'s paper.

load_genes

load_genes loads the genes for the given organisms.

precision_recall

Calculate precision and recall from the true and predicted connections.

train_classifier

train_classifier trains a classifier to generate a GRN that maps to the ground truth.

BenGRN

Initializes the BenGRN class.

Parameters:
  • grn (GRNAnnData) –

    The Gene Regulatory Network data.

  • full_dataset (Optional[AnnData], default: None ) –

    The full dataset, defaults to None.

  • doplot (bool, default: True ) –

    Whether to plot the results, defaults to True.

  • do_auc (bool, default: True ) –

    Whether to calculate the Area Under the Precision-Recall Curve, defaults to True.

Methods:

Name Description
compare_to

compare_to compares the GRN to another GRN.

scprint_benchmark

scprint_benchmark runs the full benchmark of the GRN as in the scPRINT paper.

Source code in bengrn/base.py
def __init__(
    self,
    grn: GRNAnnData,
    full_dataset: Optional[AnnData] = None,
    doplot: bool = True,
    do_auc: bool = True,
):
    """
    Initializes the BenGRN class.

    Args:
        grn (GRNAnnData): The Gene Regulatory Network data.
        full_dataset (Optional[AnnData]): The full dataset, defaults to None.
        doplot (bool): Whether to plot the results, defaults to True.
        do_auc (bool): Whether to calculate the Area Under the Precision-Recall Curve, defaults to True.
    """
    self.grn = grn
    self.full_dataset = full_dataset
    self.doplot = doplot
    self.do_auc = do_auc

compare_to

compare_to compares the GRN to another GRN.

Parameters:
  • other (Optional[GRNAnnData], default: None ) –

    The other GRN to compare to. Defaults to None. If not given, a default GRN selected by the 'to' argument is used.

  • to (str, default: 'collectri' ) –

    The name of the other GRN to compare to. Defaults to "collectri". If 'other' is given, this argument is ignored.

  • organism (str, default: 'human' ) –

    The organism of the GRN to compare to. Defaults to "human".

Returns:
  • dict

    The metrics of the comparison.

Source code in bengrn/base.py
def compare_to(
    self,
    other: Optional[GRNAnnData] = None,
    to: str = "collectri",
    organism: str = "human",
):
    """
    compare_to compares the GRN to another GRN.

    Args:
        other (Optional[GRNAnnData], optional): The other GRN to compare to. Defaults to None.
            If not given can use a default GRN from the 'to' argument.
        to (str, optional): The name of the other GRN to compare to. Defaults to "collectri".
            If 'other' is given, this argument is ignored.
        organism (str, optional): The organism of the GRN to compare to. Defaults to "human".

    Returns:
        dict: The metrics of the comparison.
    """
    if other is None:
        if self.doplot:
            print("loading GT, ", to)
        gt = get_GT_db(name=to, organism=organism)
        # gt = gt[gt.type != "post_translational"]
        varnames = set(gt.iloc[:, :2].values.flatten())
        intersection = varnames & set(self.grn.var["symbol"].tolist())
        loc = self.grn.var["symbol"].isin(intersection)
        adj = self.grn.varp["GRN"][:, loc][loc, :]
        genes = self.grn.var.loc[loc, "symbol"].tolist()

        da = np.zeros(adj.shape, dtype=float)
        for i, j in gt.iloc[:, :2].values:
            if i in genes and j in genes:
                da[genes.index(i), genes.index(j)] = 1
        if self.doplot:
            print("intersection of {} genes".format(len(intersection)))
            print("intersection pct:", len(intersection) / len(self.grn.grn.index))
    else:
        elems = other.var[other.grn.sum(1) != 0].index.tolist()
        da = other.get(self.grn.var.index.tolist()).get(elems).targets
        if da.shape[1] < 5:
            print("da is very small: ", da.shape[1])
        # da = da.iloc[6:]
        adj = self.grn.grn.loc[da.index.values, da.columns.values].values
        da = da.values

    return compute_pr(
        adj,
        da,
        doplot=self.doplot,
        do_auc=self.do_auc,
    )
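The first branch above rasterizes a two-column edge list into a binary ground-truth matrix `da` aligned on a shared gene list. A minimal standalone sketch of that step, using made-up gene names and edges for illustration:

```python
import numpy as np

# Toy gene universe and edge list (made up); mirrors how compare_to
# builds `da` from the first two columns of the ground-truth table.
genes = ["TP53", "MYC", "GATA1", "SPI1"]
edges = [("TP53", "MYC"), ("GATA1", "SPI1"), ("TP53", "FOXA2")]

da = np.zeros((len(genes), len(genes)), dtype=float)
for i, j in edges:
    # edges whose source or target is outside the intersection are dropped
    if i in genes and j in genes:
        da[genes.index(i), genes.index(j)] = 1
```

Only the two edges whose endpoints are both in `genes` survive; the `("TP53", "FOXA2")` edge is silently skipped, which is why the printed intersection size matters when interpreting the metrics.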

scprint_benchmark

scprint_benchmark runs the full benchmark of the GRN as in the scPRINT paper.

It first applies an enrichment analysis over the [elems] of the GRN, looking for TF enrichment and cell type marker gene enrichment. It then applies an enrichment over each TF in the GRN for their targets in ENCODE. Finally, it compares the GRN to the OmniPath database GRN using precision-recall metrics.

Parameters:
  • elems (list, default: ['Central', 'Regulators', 'Targets'] ) –

    The genes in the GRN, to benchmark. Defaults to ["Central", "Regulators", "Targets"]. It corresponds to different views of the network.

Returns:
  • dict

    The metrics of the benchmark.

Source code in bengrn/base.py
def scprint_benchmark(self, elems=["Central", "Regulators", "Targets"]):
    """
    scprint_benchmark runs the full benchmark of the GRN as in the scPRINT paper.

    It first applies an enrichment analysis over the [elems] of the GRN, looking for TF enrichment and cell type marker gene enrichment.
    It then applies an enrichment over each TF in the GRN for their targets in ENCODE.
    Finally, it compares the GRN to the OmniPath database GRN using precision-recall metrics.

    Args:
        elems (list, optional): The genes in the GRN, to benchmark. Defaults to ["Central", "Regulators", "Targets"].
            It corresponds to different views of the network.

    Returns:
        dict: The metrics of the benchmark.
    """
    print("base enrichment")
    metrics = {}
    for elem in elems:
        if elem == "Central" and (self.grn.varp["GRN"] != 0).sum() > 100_000_000:
            print("too many genes for central computation")
            continue
        res = utils.enrichment(
            self.grn,
            of=elem,
            gene_sets=[
                {"TFs": [i.split(".")[0] for i in utils.TF]},
                utils.file_dir + "/celltype.gmt",
            ],
            doplot=False,
            maxsize=2000,
            top_k=10,
        )
        if res is None:
            continue
        if (
            len(res.res2d[(res.res2d["FDR q-val"] < 0.1) & (res.res2d["NES"] > 1)])
            > 0
        ):
            metrics.update(
                {
                    "enriched_terms_" + elem: res.res2d[
                        (res.res2d["FDR q-val"] < 0.1) & (res.res2d["NES"] > 1)
                    ].Term.tolist()
                }
            )
        if self.doplot:
            try:
                _ = res.plot(terms="0__TFs")
                plt.show()
            except KeyError:
                pass
        istrue = metrics.get("TF_enr", False)
        if len(res.res2d.loc[res.res2d.Term == "0__TFs"]) > 0:
            istrue = istrue or (
                res.res2d.loc[res.res2d.Term == "0__TFs", "FDR q-val"].iloc[0] < 0.1
            )
        metrics.update({"TF_enr": istrue})
    if self.doplot:
        print("_________________________________________")
        print("TF specific enrichment")
    with open(FILEDIR + "/../data/tfchip_data.json", "r") as file:
        tfchip = json.load(file)
    TFinchip = {i: i.split(" ")[0] for i in tfchip.keys()}
    res = {}
    i, j = 0, 0
    previous_level = logging.root.manager.disable
    logging.disable(logging.WARNING)
    for k, v in TFinchip.items():
        if v not in self.grn.grn.columns:
            continue
        j += 1
        test = self.grn.grn.T.loc[[v]].sort_values(by=v, ascending=False).T
        if len(set(test.index) & set(tfchip[k])) == 0:
            continue
        if test.iloc[:, 0].sum() == 0:
            continue
        try:
            pre_res = gp.prerank(
                rnk=test,
                gene_sets=[{v: tfchip[k]}],
                background=self.grn.var.index.tolist(),
                min_size=1,
                max_size=4000,
                permutation_num=1000,
            )
        except IndexError:
            continue
        val = (
            pre_res.res2d[
                (pre_res.res2d["FDR q-val"] < 0.05) & (pre_res.res2d["NES"] > 1)
            ]
            .sort_values(by=["NES"], ascending=False)
            .drop(columns=["Name"])
        )
        if len(val.Term.tolist()) > 0:
            i += 1
        else:
            pass
        res[k] = pre_res.res2d
    logging.disable(previous_level)
    j = j if j != 0 else 1
    if self.doplot:
        print("found some significant results for ", i * 100 / j, "% TFs\n")
        print("_________________________________________")
    metrics.update({"significant_enriched_TFtargets": i * 100 / j})
    metrics.update(self.compare_to(to="omnipath", organism="human"))
    return metrics
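The significance filter applied to each enrichment result above (FDR q-val < 0.1 and NES > 1) can be sketched on a toy table; the terms and values below are made up, and `res2d` stands in for the gseapy result frame used in the source:

```python
import pandas as pd

# Toy stand-in for gseapy's res2d output (values made up).
res2d = pd.DataFrame({
    "Term": ["0__TFs", "B cell", "T cell"],
    "FDR q-val": [0.01, 0.5, 0.05],
    "NES": [1.8, 1.2, 0.9],
})

# Keep only terms that are both significant and positively enriched,
# exactly the mask used to populate "enriched_terms_<elem>".
sig = res2d[(res2d["FDR q-val"] < 0.1) & (res2d["NES"] > 1)]
print(sig.Term.tolist())  # only terms passing both thresholds
```

Note that both conditions must hold: "T cell" is significant but has NES < 1, so it is excluded.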

compute_epr

compute_epr computes the Expected Precision Recall (EPR) metric for the given classifier, test data, and true labels.

Parameters:
  • clf (ClassifierMixin) –

    The classifier to evaluate.

  • X_test (ndarray) –

    The test data features.

  • y_test (ndarray) –

    The true labels for the test data.

Returns:
  • float –

    The computed Expected Precision Recall (EPR) metric.

Source code in bengrn/base.py
def compute_epr(clf, X_test: np.ndarray, y_test: np.ndarray) -> float:
    """
    compute_epr computes the Expected Precision Recall (EPR) metric for the given classifier, test data, and true labels.

    Args:
        clf (sklearn.base.ClassifierMixin): The classifier to evaluate.
        X_test (numpy.ndarray): The test data features.
        y_test (numpy.ndarray): The true labels for the test data.

    Returns:
        float: The computed Expected Precision Recall (EPR) metric.
    """
    prb = clf.predict_proba(X_test)[:, 1]

    K = sum(y_test)
    # get only the top-K elems from prb
    pred = np.zeros(prb.shape)
    pred[np.argsort(prb)[-int(K) :]] = 1

    true_positive = np.sum(pred[y_test == 1] == 1)
    false_positive = np.sum(pred[y_test == 0] == 1)
    false_negative = np.sum(pred[y_test == 1] == 0)
    true_negative = np.sum(pred[y_test == 0] == 0)
    odds_ratio = (true_positive * true_negative) / (false_positive * false_negative)
    return odds_ratio
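The top-K thresholding and odds-ratio computation above can be reproduced standalone; the probabilities and labels below are made-up toy values rather than real classifier output:

```python
import numpy as np

# Made-up predicted probabilities and true labels.
prb = np.array([0.9, 0.6, 0.8, 0.3, 0.2, 0.7])
y_test = np.array([1, 0, 1, 1, 0, 0])

K = y_test.sum()                      # predict as many positives as exist
pred = np.zeros(prb.shape)
pred[np.argsort(prb)[-int(K):]] = 1   # call the top-K scores positive

tp = np.sum(pred[y_test == 1] == 1)
fp = np.sum(pred[y_test == 0] == 1)
fn = np.sum(pred[y_test == 1] == 0)
tn = np.sum(pred[y_test == 0] == 0)
# Note: this divides by zero when the top-K predictions are perfect
# (fp == fn == 0), which the source function does not guard against.
odds_ratio = (tp * tn) / (fp * fn)
```

Here two of the three top-scoring entries are true positives, giving an odds ratio of (2 x 2) / (1 x 1) = 4.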

compute_genie3

This function computes the GENIE3 algorithm on the given data.

Parameters:
  • adata (AnnData) –

    The annotated data matrix of shape n_obs x n_vars. Rows correspond to cells and columns to genes.

  • nthreads (int, default: 30 ) –

    The number of threads to use for computation. Defaults to 30.

  • ntrees (int, default: 100 ) –

    The number of trees to use for the Random Forests. Defaults to 100.

  • **kwargs

    Additional arguments to pass to the GENIE3 function.

Returns:
  • GRNAnnData –

    The Gene Regulatory Network data computed using the GENIE3 algorithm.

Source code in bengrn/base.py
def compute_genie3(
    adata: AnnData, nthreads: int = 30, ntrees: int = 100, **kwargs
) -> GRNAnnData:
    """
    This function computes the GENIE3 algorithm on the given data.

    Args:
        adata (AnnData): The annotated data matrix of shape n_obs x n_vars. Rows correspond to cells and columns to genes.
        nthreads (int, optional): The number of threads to use for computation. Defaults to 30.
        ntrees (int, optional): The number of trees to use for the Random Forests. Defaults to 100.
        **kwargs: Additional arguments to pass to the GENIE3 function.

    Returns:
        GRNAnnData: The Gene Regulatory Network data computed using the GENIE3 algorithm.
    """
    mat = np.asarray(adata.X.toarray() if issparse(adata.X) else adata.X)
    names = adata.var_names.tolist()  # [mat.sum(0) > 0].tolist()
    var = adata.var  # [mat.sum(0) > 0]
    # mat = mat[:, mat.sum(0) > 0]
    VIM = GENIE3(mat, gene_names=names, nthreads=nthreads, ntrees=ntrees, **kwargs)
    grn = GRNAnnData(grn=VIM, X=mat, var=var, obs=adata.obs)
    grn.var_names = grn.var["symbol"]
    grn.var["TFs"] = [True if i in utils.TF else False for i in grn.var_names]
    return grn

compute_pr

compute_pr computes the precision and recall metrics for the given GRN and true matrix.

Parameters:
  • grn (array) –

    The Gene Regulatory Network matrix, where each element represents the strength of the regulatory relationship between genes.

  • true (array) –

    The ground truth matrix, where each element indicates the presence (1) or absence (0) of a regulatory relationship.

  • do_auc (bool, default: True ) –

    Whether to compute the Area Under the Precision-Recall Curve (AUPRC). Defaults to True.

  • doplot (bool, default: True ) –

    Whether to plot the precision and recall metrics. Defaults to True.

Raises:
  • ValueError

    If the shape of the GRN and the true matrix do not match.

Returns:
  • dict

    A dictionary containing precision, recall, and random precision metrics.

Source code in bengrn/base.py
def compute_pr(
    grn: np.array,
    true: np.array,
    do_auc: bool = True,
    doplot: bool = True,
):
    """
    compute_pr computes the precision and recall metrics for the given GRN and true matrix.

    Args:
        grn (np.array): The Gene Regulatory Network matrix, where each element represents the strength of the regulatory relationship between genes.
        true (np.array): The ground truth matrix, where each element indicates the presence (1) or absence (0) of a regulatory relationship.
        do_auc (bool, optional): Whether to compute the Area Under the Precision-Recall Curve (AUPRC). Defaults to True.
        doplot (bool, optional): Whether to plot the precision and recall metrics. Defaults to True.

    Raises:
        ValueError: If the shape of the GRN and the true matrix do not match.

    Returns:
        dict: A dictionary containing precision, recall, and random precision metrics.
    """
    if grn.shape != true.shape:
        raise ValueError("The shape of the GRN and the true matrix do not match.")
    metrics = {}
    if isinstance(grn, (csr_matrix, csc_matrix)):
        grn = grn.toarray()
    if isinstance(true, (csr_matrix, csc_matrix)):
        true = true.toarray()
    true = true.astype(bool)
    tot = (grn.shape[0] * grn.shape[1]) - grn.shape[0]
    precision = (grn[true] != 0).sum() / (grn != 0).sum()
    recall = (grn[true] != 0).sum() / true.sum()
    rand_prec = true.sum() / tot

    if doplot:
        print(
            "precision: ",
            precision,
            "\nrecall: ",
            recall,
            "\nrandom precision:",
            rand_prec,
        )
    metrics.update(
        {
            "precision": precision,
            "recall": recall,
            "rand_precision": rand_prec,
        }
    )
    # Initialize lists to store precision and recall values
    precision_list = [precision]
    recall_list = [recall]
    # Define the thresholds to vary
    thresholds = np.append(
        np.linspace(0, 1, 101)[:-2], np.log10(np.logspace(0.99, 1, 30))
    )
    thresholds = np.quantile(grn, thresholds)
    # Calculate precision and recall for each threshold
    if do_auc:
        for threshold in tqdm.tqdm(thresholds[1:]):
            precision = (grn[true] > threshold).sum() / (grn > threshold).sum()
            recall = (grn[true] > threshold).sum() / true.sum()
            precision_list.append(precision)
            recall_list.append(recall)
        # Calculate AUPRC by integrating the precision-recall curve
        if 1.0 not in recall_list:
            precision_list.insert(0, rand_prec)
            recall_list.insert(0, recall_list[0])
            precision_list.insert(0, rand_prec)
            recall_list.insert(0, 1.0)
        precision_list = np.nan_to_num(np.array(precision_list))
        recall_list = np.nan_to_num(np.array(recall_list))
        auprc = -np.trapz(precision_list, recall_list)
        metrics["auprc"] = auprc

        # Compute Average Precision (AP) manually
        sorted_indices = np.argsort(-grn.flatten())
        sorted_true = true.flatten()[sorted_indices]

        tp_cumsum = np.cumsum(sorted_true)
        fp_cumsum = np.cumsum(~sorted_true)

        precision_at_k = tp_cumsum / (tp_cumsum + fp_cumsum)
        recall_at_k = tp_cumsum / true.sum()

        ap = np.sum(precision_at_k[1:] * np.diff(recall_at_k))
        metrics["ap"] = ap
        if doplot:
            print("Average Precision (AP): ", ap)
        if doplot:
            print("Area Under Precision-Recall Curve (AUPRC): ", auprc)

    # compute EPR
    # get the indices of the topK highest values in "grn"
    if isinstance(grn, csr_matrix):
        grn = grn.toarray()
    if isinstance(grn, csc_matrix):
        grn = grn.toarray()
    indices = np.argpartition(grn.flatten(), -int(true.sum()))[-int(true.sum()) :]
    # Compute the odds ratio
    true_positive = true[np.unravel_index(indices, true.shape)].sum()
    false_positive = true.sum() - true_positive
    # this is normal as we compute on the same number of pred_pos as true_pos
    false_negative = true.sum() - true_positive
    true_negative = tot - true_positive - false_positive - false_negative
    # Avoid division by zero
    # this is a debugger line
    if true_negative == 0 or false_positive == 0:
        odds_ratio = float("inf")
    else:
        odds_ratio = (true_positive * true_negative) / (false_positive * false_negative)

    metrics.update({"epr": odds_ratio})
    if doplot:
        print("EPR:", odds_ratio)
        plt.figure(figsize=(10, 8))
        plt.plot(
            recall_list,
            precision_list,
            marker=".",
            linestyle="-",
            color="b",
            label="p-r",
        )
        plt.plot(
            [recall_list[0], recall_list[-1]],
            [rand_prec, rand_prec],
            linestyle="--",
            color="r",
            label="Random Precision",
        )
        plt.legend(loc="lower left")
        plt.title("Precision-Recall Curve")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.xscale("log")
        plt.grid(True)
        plt.show()
    return metrics
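The headline precision, recall, and random-precision computations at the top of the function can be sketched on toy 3x3 matrices (values made up); the diagonal is excluded from the total because self-loops are not counted:

```python
import numpy as np

# Toy predicted GRN (edge weights) and boolean ground truth (made up).
grn = np.array([[0.0, 0.8, 0.0],
                [0.2, 0.0, 0.0],
                [0.0, 0.5, 0.0]])
true = np.array([[0, 1, 0],
                 [0, 0, 1],
                 [0, 1, 0]]).astype(bool)

tot = grn.shape[0] * grn.shape[1] - grn.shape[0]   # off-diagonal entries
precision = (grn[true] != 0).sum() / (grn != 0).sum()  # predicted edges that are true
recall = (grn[true] != 0).sum() / true.sum()           # true edges that are predicted
rand_prec = true.sum() / tot                           # density of the ground truth
```

Two of the three predicted edges hit true edges, so precision and recall are both 2/3, against a random-precision baseline of 3/6 = 0.5.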

download_perturb_gt

download_perturb_gt downloads the genome wide perturb seq ground truth data.

Parameters:
  • filename_bh (str, default: FILEDIR + '/../data/BH-corrected.csv.gz' ) –

    The local filename to save the BH-corrected data. Defaults to FILEDIR + "/../data/BH-corrected.csv.gz".

  • filename_adata (str, default: FILEDIR + '/../data/ess_perturb_sc.h5ad' ) –

    The local filename to save the single-cell perturbation data. Defaults to FILEDIR + "/../data/ess_perturb_sc.h5ad".

  • url_bh (str, default: 'https://plus.figshare.com/ndownloader/files/38349308' ) –

    The URL to download the BH-corrected data. Defaults to "https://plus.figshare.com/ndownloader/files/38349308".

  • url_adata (str, default: 'https://plus.figshare.com/ndownloader/files/35773219' ) –

    The URL to download the single-cell perturbation data. Defaults to "https://plus.figshare.com/ndownloader/files/35773219".

Source code in bengrn/base.py
def download_perturb_gt(
    filename_bh: str = FILEDIR + "/../data/BH-corrected.csv.gz",
    filename_adata: str = FILEDIR + "/../data/ess_perturb_sc.h5ad",
    url_bh: str = "https://plus.figshare.com/ndownloader/files/38349308",
    url_adata: str = "https://plus.figshare.com/ndownloader/files/35773219",
):
    """
    download_perturb_gt downloads the genome wide perturb seq ground truth data.

    Args:
        filename_bh (str, optional): The local filename to save the BH-corrected data. Defaults to FILEDIR + "/../data/BH-corrected.csv.gz".
        filename_adata (str, optional): The local filename to save the single-cell perturbation data. Defaults to FILEDIR + "/../data/ess_perturb_sc.h5ad".
        url_bh (str, optional): The URL to download the BH-corrected data. Defaults to "https://plus.figshare.com/ndownloader/files/38349308".
        url_adata (str, optional): The URL to download the single-cell perturbation data. Defaults to "https://plus.figshare.com/ndownloader/files/35773219".
    """
    os.makedirs(os.path.dirname(filename_bh), exist_ok=True)
    urllib.request.urlretrieve(url_bh, filename_bh)
    sc.read(
        filename_adata,
        backup_url=url_adata,
    )

get_GT_db

get_GT_db loads a prior GRN from a list of available networks.

Parameters:
  • name (str, default: 'collectri' ) –

    name of the network to load. Defaults to "collectri".

  • organism (str, default: 'human' ) –

    organism to load the network for. Defaults to "human".

  • split_complexes (bool, default: True ) –

    whether to split complexes into individual genes. Defaults to True.

Returns:
  • DataFrame

    pd.DataFrame: The prior GRN as a pandas DataFrame.

Raises:
  • ValueError

    if the provided name is not amongst the available names.

Source code in bengrn/base.py
def get_GT_db(
    name: str = "collectri", organism: str = "human", split_complexes: bool = True
) -> pd.DataFrame:
    """
    get_GT_db loads a prior GRN from a list of available networks.

    Args:
        name (str, optional): name of the network to load. Defaults to "collectri".
        organism (str, optional): organism to load the network for. Defaults to "human".
        split_complexes (bool, optional): whether to split complexes into individual genes. Defaults to True.

    Returns:
        pd.DataFrame: The prior GRN as a pandas DataFrame.

    Raises:
        ValueError: if the provided name is not amongst the available names.
    """
    if name == "tflink":
        TFLINK = "https://cdn.netbiol.org/tflink/download_files/TFLink_Homo_sapiens_interactions_All_simpleFormat_v1.0.tsv.gz"
        net = pd_load_cached(TFLINK)
        net = net.rename(columns={"Name.TF": "source", "Name.Target": "target"})
    elif name == "htftarget":
        HTFTARGET = "http://bioinfo.life.hust.edu.cn/static/hTFtarget/file_download/tf-target-infomation.txt"
        net = pd_load_cached(HTFTARGET)
        net = net.rename(columns={"TF": "source"})
    elif name == "collectri":
        import decoupler as dc

        net = dc.get_collectri(organism=organism, split_complexes=split_complexes).drop(
            columns=["PMID"]
        )
    elif name == "dorothea":
        import decoupler as dc

        net = dc.get_dorothea(organism=organism)
    elif name == "omnipath":
        if not os.path.exists(FILEDIR + "/../data/omnipath.parquet"):
            os.makedirs(
                os.path.dirname(FILEDIR + "/../data/omnipath.parquet"), exist_ok=True
            )
            from omnipath.interactions import AllInteractions
            from omnipath.requests import Annotations

            interactions = AllInteractions()
            net = interactions.get(exclude=["small_molecule", "lncrna_mrna"])
            hgnc = Annotations.get(resources="HGNC")
            rename = {v.uniprot: v.genesymbol for k, v in hgnc.iterrows()}
            net.source = net.source.replace(rename)
            net.target = net.target.replace(rename)
            net.to_parquet(FILEDIR + "/../data/omnipath.parquet")
        else:
            net = pd.read_parquet(FILEDIR + "/../data/omnipath.parquet")
    else:
        raise ValueError(f"provided name: '{name}' is not amongst the available names.")
    # varnames = list(set(net.iloc[:, :2].values.flatten()))
    # adata = AnnData(var=varnames)
    # adata.var_names = varnames
    # grn = from_adata_and_longform(adata, net, has_weight=True)
    # return grn
    return net
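Each branch above normalizes the downloaded network so it exposes `source`/`target` columns. A sketch of that step with a toy tflink-style frame (rows made up):

```python
import pandas as pd

# Toy frame with tflink's original column names (rows made up).
net = pd.DataFrame({
    "Name.TF": ["TP53", "MYC"],
    "Name.Target": ["CDKN1A", "TERT"],
})

# Same rename get_GT_db applies so downstream code can treat all
# networks uniformly as source/target edge lists.
net = net.rename(columns={"Name.TF": "source", "Name.Target": "target"})
print(net.columns.tolist())  # ['source', 'target']
```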

get_perturb_gt

get_perturb_gt retrieves the genome wide perturb seq ground truth data.

Parameters:
  • filename_bh (str, default: FILEDIR + '/../data/BH-corrected.csv.gz' ) –

    The local filename to save the BH-corrected data. Defaults to FILEDIR + "/../data/BH-corrected.csv.gz".

  • url_adata (str, default: 'https://plus.figshare.com/ndownloader/files/35773219' ) –

    The URL to download the single-cell perturbation data. Defaults to "https://plus.figshare.com/ndownloader/files/35773219".

  • filename_adata (str, default: FILEDIR + '/../data/ess_perturb_sc.h5ad' ) –

    The local filename to save the single-cell perturbation data. Defaults to FILEDIR + "/../data/ess_perturb_sc.h5ad".

Returns:
  • GRNAnnData

    The Gene Regulatory Network data as a GRNAnnData object.

Source code in bengrn/base.py
def get_perturb_gt(
    filename_adata: str = FILEDIR + "/../data/ess_perturb_sc.h5ad",
    filename_bh: str = FILEDIR + "/../data/BH-corrected.csv.gz",
    url_adata: str = "https://plus.figshare.com/ndownloader/files/35773219",
):
    """
    get_perturb_gt retrieves the genome wide perturb seq ground truth data.

    Args:
        filename_bh (str, optional): The local filename to save the BH-corrected data. Defaults to FILEDIR + "/../data/BH-corrected.csv.gz".
        url_adata (str, optional): The URL to download the single-cell perturbation data. Defaults to "https://plus.figshare.com/ndownloader/files/35773219".
        filename_adata (str, optional): The local filename to save the single-cell perturbation data. Defaults to FILEDIR + "/../data/ess_perturb_sc.h5ad".

    Returns:
        GRNAnnData: The Gene Regulatory Network data as a GRNAnnData object.
    """
    if not os.path.exists(filename_bh):
        download_perturb_gt(filename_bh=filename_bh)
    pert = pd.read_csv(filename_bh)
    pert = pert.set_index("Unnamed: 0").T
    pert.index = [i.split("_")[-1] for i in pert.index]
    pert = pert[~pert.index.duplicated(keep="first")].T
    pert = pert < 0.05
    adata_sc = sc.read(
        filename_adata,
        backup_url=url_adata,
    )
    adata_sc = adata_sc[adata_sc.obs.gene_id == "non-targeting"]
    # subset to genes present in the perturbation ground truth
    # (the subsetting result must be assigned back to take effect)
    adata_sc = adata_sc[
        :, adata_sc.var.index.isin(set(pert.index) | set(pert.columns))
    ]
    adata_sc.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
    adata_sc = adata_sc[:, adata_sc.var.sort_index().index]

    pert = pert.loc[pert.index.isin(adata_sc.var.index)].loc[
        :, pert.columns.isin(adata_sc.var.index)
    ]

    missing_indices = list(set(adata_sc.var.index) - set(pert.index))
    pert = pert.reindex(pert.index.union(missing_indices), fill_value=False)
    missing_indices = list(set(adata_sc.var.index) - set(pert.columns))
    pert = pert.reindex(columns=pert.columns.union(missing_indices), fill_value=False)

    pert = pert.loc[adata_sc.var.index].loc[:, adata_sc.var.index]
    return GRNAnnData(
        X=csr_matrix(adata_sc.X.toarray()),
        var=adata_sc.var,
        obs=adata_sc.obs,
        grn=csr_matrix(pert.values),
    )
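The reindexing above aligns the boolean perturbation matrix with the AnnData var index, padding missing genes with False so the final matrix is square and ordered like `adata_sc.var`. A sketch with toy gene names (made up):

```python
import pandas as pd

# Toy var index and a 1x1 perturbation table (made up).
var_index = pd.Index(["A", "B", "C"])
pert = pd.DataFrame([[True]], index=["A"], columns=["B"])

# Pad rows, then columns, with False for genes absent from `pert`.
missing_rows = list(set(var_index) - set(pert.index))
pert = pert.reindex(pert.index.union(missing_rows), fill_value=False)
missing_cols = list(set(var_index) - set(pert.columns))
pert = pert.reindex(columns=pert.columns.union(missing_cols), fill_value=False)

# Reorder both axes to match the var index exactly.
pert = pert.loc[var_index].loc[:, var_index]
```

The result is a 3x3 boolean matrix with a single True at ("A", "B"), ready to be stored as the `grn` layer of a GRNAnnData object.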

get_scenicplus

This function retrieves a SCENIC+ loom file from a given file path and loads it as a GRNAnnData object.

Parameters:
  • filepath

    str, optional The path to the scenicplus data file. Default is FILEDIR + "/../data/10xPBMC_homo_scenicplus_genebased_scope.loom".

Raises:
  • FileNotFoundError

    If the file at the given path does not exist.

Returns:
  • GRNAnnData –

    The SCENIC+ data from the given file as a GRNAnnData object.

Source code in bengrn/base.py
def get_scenicplus(
    filepath=FILEDIR + "/../data/10xPBMC_homo_scenicplus_genebased_scope.loom",
):
    """
    This function retrieves a SCENIC+ loom file from a given file path and loads it as a GRNAnnData object.

    Args:
        filepath : str, optional
            The path to the scenicplus data file.
            Default is FILEDIR + "/../data/10xPBMC_homo_scenicplus_genebased_scope.loom".

    Raises:
        FileNotFoundError: If the file at the given path does not exist.

    Returns:
        GrnAnnData: The scenicplus data from the given file as a grnndata object
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(
            f"The file {filepath} does not exist. You likely need to download \
                this or another loomxfile from the scope website"
        )

    return from_scope_loomfile(filepath)

get_sroy_gt

This function retrieves the ground truth data from McCall et al.'s paper.

Parameters:
  • get (str, default: 'mine' ) –

    The specific dataset to retrieve. Options include "mine", "liu", and "chen".

  • join (str, default: 'outer' ) –

    The type of join to be performed when concatenating the data. Default is "outer".

  • species (str, default: 'human' ) –

    The species of the dataset. Default is "human".

  • gt (str, default: 'full' ) –

    The type of ground truth data to retrieve. Options include "full", "chip", and "ko". Default is "full".

Returns:
  • GRNAnnData –

    The ground truth data as a grnndata object

Source code in bengrn/base.py
def get_sroy_gt(
    get: str = "mine", join: str = "outer", species: str = "human", gt: str = "full"
) -> GRNAnnData:
    """
    This function retrieves the ground truth data from McCall et al.'s paper.

    Args:
        get (str): The specific dataset to retrieve. For human: "mine", "liu", "chen", "han"; for mouse: "duren", "semrau", "tran", "zhao". Defaults to "mine".
        join (str, optional): The type of join to be performed when concatenating the data. Default is "outer".
        species (str, optional): The species of the dataset. Default is "human".
        gt (str, optional): The type of ground truth data to retrieve. Options include "full", "chip", and "ko". Default is "full".

    Returns:
        GRNAnnData: The ground truth data as a GRNAnnData object
    """
    # Download and store the ground truth data file

    if not os.path.exists(
        os.path.join(FILEDIR, "..", "data", "GroundTruth", "stone_and_sroy")
    ):
        download_sroy_gt(
            gt_file_path=os.path.join(FILEDIR, "..", "data", "GroundTruth.tar.gz")
        )
    if species == "human":
        if gt == "full":
            df = pd.read_csv(
                FILEDIR + "/../data/GroundTruth/stone_and_sroy/hESC_ground_truth.tsv",
                sep="\t",
                header=None,
            )
        elif gt == "chip":
            df = pd.read_csv(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/gold_standards/hESC/hESC_chipunion.txt",
                sep="\t",
                header=None,
            )
        elif gt == "ko":
            df = pd.read_csv(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/gold_standards/hESC/hESC_KDUnion.txt",
                sep="\t",
                header=None,
            )
        if get == "liu":
            adata = AnnData(
                (
                    2
                    ** pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/liu_rna_filtered_log2.tsv.gz",
                        sep="\t",
                    )
                )
                - 1
            ).T
        elif get == "chen":
            adata = AnnData(
                (
                    2
                    ** pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/chen_rna_filtered_log2.tsv.gz",
                        sep="\t",
                    )
                )
                - 1
            ).T
        elif get == "han":
            adata = AnnData(
                unnormalize(
                    pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/human_han_GSE107552.csv.gz",
                    )
                    .set_index("Cell")
                    .T,
                    is_root=True,
                )
            )
        elif get == "mine":
            # https://www.ebi.ac.uk/gxa/sc/experiments/E-GEOD-36552/downloads
            adata = sc.read_mtx(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/scRNA/E-GEOD-36552.aggregated_filtered_counts.mtx"
            ).T
            col = pd.read_csv(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/scRNA/E-GEOD-36552.aggregated_filtered_counts.mtx_rows",
                header=None,
                sep="\t",
            )
            adata.var.index = col[0]
            genesdf = load_genes()
            intersect_genes = set(adata.var.index).intersection(set(genesdf.index))
            adata = adata[:, list(intersect_genes)]
            genesdf = genesdf.loc[adata.var.index]
            adata.var["ensembl_id"] = adata.var.index
            adata.var.index = make_index_unique(genesdf["symbol"].astype(str))
        else:
            raise ValueError("get must be one of 'liu', 'chen', 'han', or 'mine'")
        adata.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
    elif species == "mouse":
        if gt == "full":
            df = pd.read_csv(
                FILEDIR + "/../data/GroundTruth/stone_and_sroy/mESC_ground_truth.tsv",
                sep="\t",
                header=None,
            )
        elif gt == "chip":
            df = pd.read_csv(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/gold_standards/mESC/mESC_chipunion.txt",
                sep="\t",
                header=None,
            )
        elif gt == "ko":
            df = pd.read_csv(
                FILEDIR
                + "/../data/GroundTruth/stone_and_sroy/gold_standards/mESC/mESC_KDUnion.txt",
                sep="\t",
                header=None,
            )
        if get == "duren":
            adata = AnnData(
                (
                    2
                    ** pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/duren_rna_filtered_log2.tsv.gz",
                        sep="\t",
                    )
                )
                - 1
            ).T
        elif get == "semrau":
            adata = AnnData(
                (
                    2
                    ** pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/semrau_rna_filtered_log2.tsv.gz",
                        sep="\t",
                    )
                )
                - 1
            ).T
        elif get == "tran":
            adata = AnnData(
                unnormalize(
                    pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/mouse_tran_A2S.csv.gz",
                    )
                    .set_index("Cell")
                    .T,
                    is_root=True,
                )
            )
        elif get == "zhao":
            adata = AnnData(
                unnormalize(
                    pd.read_csv(
                        FILEDIR
                        + "/../data/GroundTruth/stone_and_sroy/scRNA/mouse_zhao_GSE114952.csv.gz",
                    )
                    .set_index("Cell")
                    .T,
                    is_root=True,
                )
            )
        else:
            raise ValueError("get must be one of 'duren', 'semrau', 'tran', or 'zhao'")
        adata.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
    return from_adata_and_longform(adata, df)
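The expression matrices suffixed `_log2` above are stored as log2(x + 1) values; the loader converts them back to the linear scale with `2 ** df - 1` and transposes to cells x genes before wrapping them in an AnnData object. A minimal sketch of that step on a toy matrix (the gene and cell names are made up):

```python
import numpy as np
import pandas as pd

# Toy log2(x + 1)-transformed expression matrix (genes x cells),
# mimicking the layout of the liu/chen/duren/semrau TSV files.
log2_expr = pd.DataFrame(
    np.log2(np.array([[0.0, 3.0], [7.0, 15.0]]) + 1),
    index=["GENE1", "GENE2"],
    columns=["cell1", "cell2"],
)

# Undo the transform and transpose to cells x genes, as get_sroy_gt
# does before building the AnnData object.
linear_expr = (2 ** log2_expr - 1).T
print(linear_expr.round(6))
```

The same round trip applies to any base-2 log-plus-one transform; only the file paths and index handling differ in the real loader.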

load_genes

load_genes loads the genes for the given organisms.

Parameters:
  • organisms (Union[str, list], default: 'NCBITaxon:9606' ) –

    The organism(s) to load genes for. Can be a single organism string or a list of organism strings. Defaults to "NCBITaxon:9606".

Returns:
  • DataFrame –

    A DataFrame containing gene information for the specified organisms, including columns for gene symbols, mitochondrial genes, ribosomal genes, hemoglobin genes, and organism.

Source code in bengrn/base.py
def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"):  # "NCBITaxon:10090",
    """
    load_genes loads the genes for the given organisms.

    Args:
        organisms (Union[str, list], optional): The organism(s) to load genes for. Can be a single organism string or a list of organism strings. Defaults to "NCBITaxon:9606".

    Returns:
        pd.DataFrame: A DataFrame containing gene information for the specified organisms, including columns for gene symbols, mitochondrial genes, ribosomal genes, hemoglobin genes, and organism.
    """
    try:
        import bionty as bt
    except ImportError:
        raise ImportError(
            "bionty is not installed. Please install it with `pip install bionty`. "
            "You will also need to populate its genes; have a look at the "
            "jkobject/scdataloader package and its populate_ontology function"
        )
    organismdf = []
    if type(organisms) is str:
        organisms = [organisms]
    for organism in organisms:
        genesdf = bt.Gene.filter(
            organism_id=bt.Organism.filter(ontology_id=organism).first().id
        ).df()
        genesdf = genesdf.drop_duplicates(subset="ensembl_gene_id")
        genesdf = genesdf.set_index("ensembl_gene_id").sort_index()
        # mitochondrial genes
        genesdf["mt"] = genesdf.symbol.astype(str).str.startswith("MT-")
        # ribosomal genes
        genesdf["ribo"] = genesdf.symbol.astype(str).str.startswith(("RPS", "RPL"))
        # hemoglobin genes.
        genesdf["hb"] = genesdf.symbol.astype(str).str.contains("^HB[^(P)]")
        genesdf["organism"] = organism
        organismdf.append(genesdf)
    organismdf = pd.concat(organismdf)
    for col in ["source_id", "run_id", "created_by_id", "updated_at", "stable_id"]:
        if col in organismdf.columns:
            organismdf.drop(columns=[col], inplace=True)
    return organismdf
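The mt/ribo/hb flags rely on simple symbol-prefix and regex rules; the hemoglobin pattern `^HB[^(P)]` keeps HBB-style symbols while excluding HBP-prefixed ones. A self-contained sketch of those rules on a toy gene table (the symbols and Ensembl IDs below are illustrative, not a real bionty dump):

```python
import pandas as pd

# Toy gene table mimicking the bionty output consumed by load_genes.
genesdf = pd.DataFrame(
    {"symbol": ["MT-CO1", "RPS6", "RPL13", "HBB", "HBP1", "TP53"]},
    index=[f"ENSG{i:011d}" for i in range(6)],
)

# Same flagging rules as load_genes:
genesdf["mt"] = genesdf.symbol.astype(str).str.startswith("MT-")
genesdf["ribo"] = genesdf.symbol.astype(str).str.startswith(("RPS", "RPL"))
# hemoglobin genes: starts with HB, third character neither "(" nor "P",
# so HBB is flagged but HBP1 is not
genesdf["hb"] = genesdf.symbol.astype(str).str.contains("^HB[^(P)]")
print(genesdf[["symbol", "mt", "ribo", "hb"]])
```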

precision_recall

Calculate precision and recall from the true and predicted connections.

Parameters:
  • true_con (set) –

    The true connections, as a set of (regulator, target) pairs.

  • grn_con (set) –

    The predicted connections, as a set of (regulator, target) pairs.

Returns:
  • float

    The precision value.

  • float

    The recall value.

Source code in bengrn/base.py
def precision_recall(true_con, grn_con):
    """
    Calculate precision and recall from the true and predicted connections.

    Args:
        true_con (set): The true connections, as a set of (regulator, target) pairs.
        grn_con (set): The predicted connections, as a set of (regulator, target) pairs.

    Returns:
        float: The precision value.
        float: The recall value.
    """
    tp = len(true_con & grn_con)
    fp = len(grn_con - true_con)
    fn = len(true_con - grn_con)

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    return precision, recall
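Because the function uses set arithmetic (`&` for true positives, `-` for false positives and false negatives), both inputs must be Python sets of hashable edges. A small worked example with made-up (regulator, target) pairs:

```python
# Edges as (regulator, target) pairs; the TF and gene names are invented.
true_con = {("TF1", "G1"), ("TF1", "G2"), ("TF2", "G3")}
grn_con = {("TF1", "G1"), ("TF2", "G3"), ("TF2", "G4")}

tp = len(true_con & grn_con)  # predicted and real
fp = len(grn_con - true_con)  # predicted but not real
fn = len(true_con - grn_con)  # real but missed

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print(precision, recall)
```

Here two of the three predicted edges are real and two of the three true edges are recovered, so both precision and recall come out to 2/3.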

train_classifier

train_classifier trains a classifier to generate a GRN that maps to the ground truth.

Uses a RidgeClassifier to select the best combination of networks to predict the ground truth. It is used for the head classification part in the scPRINT paper.

Parameters:
  • grn (GRNAnnData) –

    The Gene Regulatory Network data.

  • gt (str, default: 'omnipath' ) –

    The name of the ground truth database to use. Defaults to "omnipath".

  • other (GRNAnnData, default: None ) –

    Another GRN to compare against. Defaults to None.

  • use_col (str, default: 'symbol' ) –

    The column name to use for gene symbols. Defaults to "symbol".

  • train_size (float, default: 0.2 ) –

    The proportion of the dataset to include in the train split. Defaults to 0.2.

  • doplot (bool, default: True ) –

    Whether to plot the results. Defaults to True.

  • class_weight (dict, default: {1: 200, 0: 1} ) –

    Weights associated with classes in the form {class_label: weight}. Defaults to {1: 200, 0: 1}.

  • max_iter (int, default: 1000 ) –

    Maximum number of iterations for the classifier. Defaults to 1_000.

  • C (float, default: 1.0 ) –

    Regularization strength; must be a positive float. Defaults to 1.0.

  • return_full (bool, default: True ) –

    Whether to return the full classifier object. Defaults to True.

  • shuffle (bool, default: False ) –

    Whether to shuffle the data before splitting. Defaults to False.

Returns:
  • (GRNAnnData, dict, RidgeClassifier)

    The Gene Regulatory Network data, the metrics of the classifier, and the classifier object.

Source code in bengrn/base.py
def train_classifier(
    grn: GRNAnnData,
    gt: str = "omnipath",
    other: Optional[GRNAnnData] = None,
    use_col: str = "symbol",
    train_size: float = 0.2,
    doplot: bool = True,
    class_weight: dict = {1: 200, 0: 1},
    max_iter: int = 1_000,
    C: float = 1.0,
    return_full: bool = True,
    shuffle: bool = False,
    **kwargs,
):
    """
    train_classifier trains a classifier to generate a GRN that maps to the ground truth.

    Uses a RidgeClassifier to select the best combination of networks to predict the ground truth.
    It is used for the head classification part in the scPRINT paper.

    Args:
        grn (GRNAnnData): The Gene Regulatory Network data.
        gt (str, optional): The name of the ground truth database to use. Defaults to "omnipath".
        other (GRNAnnData, optional): Another GRN to compare against. Defaults to None.
        use_col (str, optional): The column name to use for gene symbols. Defaults to "symbol".
        train_size (float, optional): The proportion of the dataset to include in the train split. Defaults to 0.2.
        doplot (bool, optional): Whether to plot the results. Defaults to True.
        class_weight (dict, optional): Weights associated with classes in the form {class_label: weight}. Defaults to {1: 200, 0: 1}.
        max_iter (int, optional): Maximum number of iterations for the classifier. Defaults to 1_000.
        C (float, optional): Regularization strength; must be a positive float. Defaults to 1.0.
        return_full (bool, optional): Whether to return the full classifier object. Defaults to True.
        shuffle (bool, optional): Whether to shuffle the data before splitting. Defaults to False.

    Returns:
        (GRNAnnData, dict, RidgeClassifier): The Gene Regulatory Network data, the metrics of the classifier, and the classifier object.
    """
    if other is not None:
        elems = other.var[other.grn.sum(1) != 0].index.tolist()
        sub = other.get(grn.var[use_col].tolist()).get(elems).targets
        if sub.shape[1] < 5:
            print("sub is very small: ", sub.shape[1])
        genes = grn.var[use_col].tolist()
        args = np.argsort(genes)
        genes = np.array(genes)[args]
        adj = grn.varp["GRN"][args, :, :][:, args, :][np.isin(genes, sub.index.values)][
            :, np.isin(genes, sub.columns.values)
        ]
        print("pred shape", adj.shape)
        da = sub.values
    else:
        gt = get_GT_db(name=gt)
        varnames = set(gt.iloc[:, :2].values.flatten())
        intersection = varnames & set(grn.var[use_col].tolist())
        loc = grn.var[use_col].isin(intersection)
        adj = grn.varp["GRN"][:, loc, :][loc, :, :]
        genes = grn.var.loc[loc][use_col].tolist()

        da = np.zeros((len(genes), len(genes)), dtype=float)
        for i, j in gt.iloc[:, :2].values:
            if i in genes and j in genes:
                da[genes.index(i), genes.index(j)] = 1

    print("true elem", int(da.sum()), "...")
    da = da.flatten()
    adj = adj.reshape(-1, adj.shape[-1])

    X_train, X_test, y_train, y_test = train_test_split(
        adj, da, random_state=0, train_size=train_size, shuffle=shuffle
    )
    print("doing classification....")

    clf = RidgeClassifier(
        alpha=C,
        fit_intercept=False,
        class_weight=class_weight,
        # solver="saga",
        max_iter=max_iter,
        positive=True,
    )
    # clf = LogisticRegression(
    #    penalty="l1",
    #    C=C,
    #    solver="saga",
    #    class_weight=class_weight,
    #    max_iter=max_iter,
    #    n_jobs=8,
    #    fit_intercept=False,
    #    verbose=10,
    #    **kwargs,
    # )
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)
    # epr = compute_epr(clf, X_test, y_test)
    metrics = {
        "used_heads": (clf.coef_ != 0).sum(),
        "precision": (pred[y_test == 1] == 1).sum() / (pred == 1).sum(),
        "random_precision": y_test.sum() / len(y_test),
        "recall": (pred[y_test == 1] == 1).sum() / y_test.sum(),
        "predicted_true": pred.sum(),
        "number_of_true": y_test.sum(),
        # "epr": epr,
    }
    if doplot:
        print("metrics", metrics)
        PrecisionRecallDisplay.from_estimator(
            clf, X_test, y_test, plot_chance_level=True
        )
        plt.show()
    if return_full:
        adj = grn.varp["GRN"]
        grn.varp["classified"] = clf.predict(adj.reshape(-1, adj.shape[-1])).reshape(
            len(grn.var), len(grn.var)
        )
    return grn, metrics, clf