DiSCERN (Deep Single Cell Expression ReconstructioN) can be used to reconstruct missing expression information using a reference data set.
In this tutorial, we'll apply DISCERN to two single-cell RNA-seq data sets to reconstruct the expression information of a query data set (called citeseq here) using a reference data set (called pbmc8k).
Both data sets consist of PBMCs from healthy donors sequenced using the 10x Chromium technology. The pbmc8k data set is available here and will be downloaded automatically by the functions below. The second data set, citeseq, was introduced in Stoeckius et al. (2017) and is available here. In addition to the transcriptomic information, this data set also contains protein abundance information measured by CITE-seq. However, the protein abundance is not used in this tutorial.
The tutorial consists of four parts:

1. Download the example expression data and cell type information from public websites.
2. Apply the preprocessing pipeline provided with DISCERN, including log transformation and scaling of the data.
3. Train DISCERN and apply it to the data sets: pbmc8k serves as the reference and the citeseq data set is reconstructed using information from pbmc8k.
4. Run basic downstream analysis steps, such as UMAP and optionally clustering, on the DISCERN-reconstructed data.
import json
import pathlib

import discern
import matplotlib.pyplot as plt
import scanpy as sc
import tensorflow as tf
import numpy as np

# Fix the random seeds for reproducible results.
tf.random.set_seed(42)
np.random.seed(42)
If discern is not available, try installing it with:
pip install -U discern-reconstruction
This part downloads the example data. If you have already downloaded the data or plan to use your own, you can skip it without any issues.
import json
import pathlib
import pickle
import tarfile
import tempfile
import anndata
import pandas as pd
import requests
def get_citeseq(counts, metadata):
    """Load the CITE-seq UMI counts, keep human genes and attach cell type labels."""
    counts = pd.read_csv(counts, index_col=0)
    # Keep only human genes and strip the species prefix.
    counts = counts.filter(like="HUMAN_", axis=0)
    counts.index = counts.index.str.replace("HUMAN_", "")
    counts = counts.T
    dataset = anndata.AnnData(X=counts)
    metadata = pd.read_csv(metadata).set_index("cellnames")
    # Map the numeric cluster labels to cell type names.
    clusterlist = {
        0: "CD4 T cells",
        1: "CD4 T cells",
        2: "NK cells",
        3: "CD14+ Monocytes",
        4: "B cells",
        5: "CD8 T cells",
        6: "FCGR3A+ Monocytes",
        7: "Other",
    }
    metadata["celltype"] = (
        metadata["cluster"].astype(int).replace(clusterlist).astype("category")
    )
    dataset.obs = pd.merge(
        dataset.obs, metadata, left_index=True, right_index=True, how="left"
    )
    # Drop cells without metadata and cells labeled "Other".
    dataset = dataset[dataset.obs.dropna().index]
    dataset = dataset[dataset.obs.celltype != "Other"].copy()
    return dataset
def get_8k(counts, metadata):
    """Load the 10x pbmc8k matrix and attach cell type labels from the metadata pickle."""
    with tempfile.TemporaryDirectory(dir=counts.parent) as tmp:
        with tarfile.open(counts, "r:gz") as file:
            file.extractall(tmp)
        matrixfile = list(pathlib.Path(tmp).rglob("matrix.mtx"))[0]
        dataset = sc.read_10x_mtx(matrixfile.parent)
    with pathlib.Path(metadata).open("rb") as file:
        metadata = pickle.load(file)
    df = pd.DataFrame(
        {"barcodes": metadata["barcodes"].index, "clusters": metadata["clusters"]}
    ).set_index("barcodes")
    df["celltype"] = pd.Categorical.from_codes(
        metadata["clusters"], metadata["list_clusters"]
    )
    dataset.obs = pd.merge(
        dataset.obs, df, left_index=True, right_index=True, how="left"
    )
    # Drop cells without metadata and cells labeled "Other".
    dataset = dataset[dataset.obs.dropna().index]
    dataset = dataset[dataset.obs.celltype != "Other"].copy()
    # Densify the sparse matrix for downstream processing.
    dataset.X = dataset.X.toarray()
    return dataset
def download(raw_folder="raw", data_path="data"):
    """Download the raw data and metadata and convert them to h5ad files."""
    raw_folder = pathlib.Path(raw_folder)
    raw_folder.mkdir(exist_ok=True)
    data_path = pathlib.Path(data_path)
    data_path.mkdir(exist_ok=True)
    citeseq_data = raw_folder.joinpath("citeseq_data.csv.gz")
    citeseq_metadata = raw_folder.joinpath("citeseq_labels.csv")
    pbmc8k_data = raw_folder.joinpath("pbmc8k_data.tar.gz")
    pbmc8k_metadata = raw_folder.joinpath("pbmc8k_labels.pickle")
    if not citeseq_data.exists():
        r = requests.get(
            "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE100866&format=file&file=GSE100866_PBMC_vs_flow_10X-RNA_umi.csv.gz",
            allow_redirects=True,
        )
        citeseq_data.write_bytes(r.content)
    if not citeseq_metadata.exists():
        r = requests.get(
            "https://raw.githubusercontent.com/YosefLab/scVI-data/master/cite.seurat.labels",
            allow_redirects=True,
        )
        citeseq_metadata.write_bytes(r.content)
    if not pbmc8k_data.exists():
        r = requests.get(
            "http://cf.10xgenomics.com/samples/cell-exp/2.1.0/pbmc8k/pbmc8k_filtered_gene_bc_matrices.tar.gz",
            allow_redirects=True,
        )
        pbmc8k_data.write_bytes(r.content)
    if not pbmc8k_metadata.exists():
        r = requests.get(
            "https://github.com/YosefLab/scVI-data/raw/master/pbmc_metadata.pickle",
            allow_redirects=True,
        )
        pbmc8k_metadata.write_bytes(r.content)
    # Convert the raw downloads to AnnData files if not done already.
    citeseq = data_path.joinpath("citeseq.h5ad")
    if not citeseq.exists():
        get_citeseq(citeseq_data, citeseq_metadata).write(citeseq)
    pbmc8k = data_path.joinpath("pbmc8k.h5ad")
    if not pbmc8k.exists():
        get_8k(pbmc8k_data, pbmc8k_metadata).write(pbmc8k)

download()
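After running download(), both converted files should be present in the data directory. A quick optional check:

for name in ("citeseq.h5ad", "pbmc8k.h5ad"):
    path = pathlib.Path("data").joinpath(name)
    print(path, "exists:", path.exists())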
DISCERN requires the training data in a format built upon anndata.AnnData, similar to scanpy.
Data set preparation requires custom preprocessing involving library-size normalization, log-scaling, and mean-variance centering. The preprocessing pipeline closely follows Zheng et al. (2017) and is implemented in discern.WAERecipe.
from discern import WAERecipe
Default parameters for this preprocessing and model training can be found in parameters.json. These parameters should work for most applications.
default_parameter = json.loads(pathlib.Path("parameters.json").read_bytes())[
    "experiments"
]["default_experiment"]
In this tutorial we use the example data downloaded by the download() function above.
The data should be located in the data directory, but custom data sets can also be used.
The data sets can easily be read using sc.read:
citeseq = sc.read("data/citeseq.h5ad")
pbmc8k = sc.read("data/pbmc8k.h5ad")
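As an optional sanity check, you can print both AnnData objects and the cell type annotation added during the download step:

print(citeseq)
print(pbmc8k)
print(citeseq.obs["celltype"].value_counts())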
discern.WAERecipe requires the parameters as found in parameters.json and two or more input data sets as a dictionary.
The dictionary keys will later be used as batch labels for reconstruction.
recipe = WAERecipe(
    params=default_parameter, inputs={"citeseq": citeseq, "pbmc8k": pbmc8k}
)
After initializing the WAERecipe, you can run the pipeline by calling it and access the processed data via the .sc_raw attribute.
This will be used as input for DISCERN model training.
preprocessed_results = recipe().sc_raw
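The processed object exposes AnnData-like access; among other things, it contains a batch column in .obs derived from the dictionary keys used above, which you can inspect:

print(preprocessed_results.obs["batch"].value_counts())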
The DISCERN model is implemented in the DISCERN class and can be initialized directly from the default parameters read from the parameters.json file above.
from discern import DISCERN
model = DISCERN.from_json(default_parameter)
As DISCERN is built on the TensorFlow backend, the deep learning model needs to be initialized.
This can be achieved using the build_model function. It requires three arguments to determine the input and output shapes:

- n_genes, which determines the input and output shape of the gene expression matrix.
- n_labels, which indicates the number of data sets.
- scale, which is a loss scaling factor based on the total number of training samples; it is available in preprocessed_results.config["total_train_count"].

model.build_model(
    n_genes=preprocessed_results.var_names.size,
    n_labels=preprocessed_results.obs.batch.nunique(),
    scale=preprocessed_results.config["total_train_count"],
)
If you're interested in the model details, you can uncomment the following lines to see a brief description of the model structure.
# model.wae_model.summary() # Overall structure summary
# model.encoder.summary() # Encoder structure summary
# model.decoder.summary() # Decoder structure summary
The DISCERN model learns the gene regulatory network from the gene expression data. This is called "training".
To do so, the data set is split into smaller parts, called mini-batches.
The size of such a part needs to be set via the batch_size attribute. Usually 192 is a good choice, but a higher or lower size can be chosen depending on the available computational power.
preprocessed_results.batch_size = 192
The training can be performed using the training method.
The number of training steps (max_steps) is data set dependent; a higher value increases the running time but, up to a point, also improves the reconstruction quality.
Usually something around 30-50 is a good choice. Here we use 20 steps for computational reasons.
The command-line API contains a more advanced method to determine the number of steps automatically (if defined in the early_stopping section of the parameters.json file).
losses = model.training(inputdata=preprocessed_results, max_steps=20)
Epoch 1/20
66/66 - 51s - loss: 8858.6342 - decoder_counts_loss: 0.4703 - decoder_dropouts_loss: 0.1634 - mmdpp_loss: 0.1129 - sigma_regularization_loss: 478.0991 - val_loss: 6895.9799 - val_decoder_counts_loss: 0.4349 - val_decoder_dropouts_loss: 0.0023 - val_mmdpp_loss: 0.2336 - val_sigma_regularization_loss: 476.2852
[... epochs 2-19 omitted; the training loss decreases steadily from ~6518 to ~6292 ...]
Epoch 20/20
66/66 - 47s - loss: 6287.3505 - decoder_counts_loss: 0.4186 - decoder_dropouts_loss: 1.9698e-05 - mmdpp_loss: 0.0059 - sigma_regularization_loss: 480.2509 - val_loss: 6804.0711 - val_decoder_counts_loss: 0.4198 - val_decoder_dropouts_loss: 1.6895e-05 - val_mmdpp_loss: 0.3378 - val_sigma_regularization_loss: 480.4938
The model losses, indicating the quality of the learned information, can be visualized from the resulting losses object.
plt.plot(losses.epoch, losses.history["loss"], label="training loss")
plt.plot(losses.epoch, losses.history["val_loss"], label="validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
Finally, the preprocessed data can be reconstructed with the trained model.
For this, the reference data set needs to be specified via a column available in preprocessed_results.obs.
Usually this column is batch, and the reference batch is supplied via the column_value argument. Here we use the pbmc8k data set as the reference.
reconstructed_data = model.reconstruct(
    input_data=preprocessed_results,
    column="batch",
    column_value="pbmc8k",
)
The reconstructed data contains the latent representation of the expression data, as computed by DISCERN, in .obsm["X_DISCERN"] and the reconstructed expression information in .X.
Additionally, information about the estimated counts before applying the estimated dropouts can be found in .layers.
display(reconstructed_data.obsm)
reconstructed_data.layers
AxisArrays with keys: X_DISCERN
Layers with keys: estimated_counts, estimated_dropouts
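Since the latent representation is stored in .obsm["X_DISCERN"], you can optionally build the neighbor graph directly on it instead of on a PCA of the reconstructed expression; a minimal sketch using scanpy's standard use_rep argument:

# Optional alternative to the PCA-based neighbors computed below.
sc.pp.neighbors(reconstructed_data, use_rep="X_DISCERN")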
Finally, the data can be analyzed with a standard pipeline, for example for data visualization.
Thus, we perform PCA and UMAP computation as implemented in scanpy.
See the scanpy documentation for more information.
sc.pp.pca(preprocessed_results)
sc.pp.pca(reconstructed_data)
sc.pp.neighbors(preprocessed_results)
sc.pp.neighbors(reconstructed_data)
sc.tl.umap(preprocessed_results)
sc.tl.umap(reconstructed_data)
sc.pl.umap(
    preprocessed_results,
    color=["batch", "celltype"],
    title=["Batch (uncorrected)", "Cell type (uncorrected)"],
)
sc.pl.umap(
    reconstructed_data,
    color=["batch", "celltype"],
    title=["Batch (DISCERN)", "Cell type (DISCERN)"],
)
As an optional step, we show that with the DISCERN-integrated data it is easier to recover the originally identified cell types in the combined data set.
pip install leidenalg # Required for running sc.tl.leiden but not installed by default
sc.tl.leiden(reconstructed_data, resolution=0.1)
sc.tl.leiden(preprocessed_results, resolution=0.1)
sc.pl.umap(
    preprocessed_results,
    color=["celltype", "leiden"],
    title=["Cell type (uncorrected)", "Cluster (uncorrected)"],
)
sc.pl.umap(
    reconstructed_data,
    color=["celltype", "leiden"],
    title=["Cell type (DISCERN)", "Cluster (DISCERN)"],
)
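To quantify the visual impression, one option (an illustrative sketch, not part of DISCERN itself) is to compare the Leiden clusters against the published cell type labels using the adjusted Rand index from scikit-learn:

from sklearn.metrics import adjusted_rand_score

# Higher ARI means better agreement between clusters and cell type labels.
ari_uncorrected = adjusted_rand_score(
    preprocessed_results.obs["celltype"], preprocessed_results.obs["leiden"]
)
ari_discern = adjusted_rand_score(
    reconstructed_data.obs["celltype"], reconstructed_data.obs["leiden"]
)
print(f"ARI (uncorrected): {ari_uncorrected:.3f}")
print(f"ARI (DISCERN): {ari_discern:.3f}")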