scTab: Scaling cross-tissue single-cell annotation models

Dataset preparation

The dataset used in this paper is based on the CELLxGENE15 census version 2023-05-15 (https://chanzuckerberg.github.io/cellxgene-census/index.html). This census version was selected because it is a long-term supported (LTS) release and will be hosted by CELLxGENE for at least five years, which makes the dataset creation reproducible for the foreseeable future. We subsetted to human datasets and used the human protein-coding genes (19,331) as the feature space. The following criteria were used to filter the human CELLxGENE census data (a simplified sketch of the count-based criteria follows the list):

1.

The census data is subset to primary data only (is_primary_data == True) to prevent label leakage between the train, validation, and test sets.

2.

Only sequencing data from 10x-based sequencing protocols is used. In terms of the CELLxGENE census, this means subsetting the assay metadata column to the following terms: 10x 5' v2, 10x 3' v3, 10x 3' v2, 10x 5' v1, 10x 3' v1, 10x 3' transcription profiling, 10x 5' transcription profiling.

3.

The annotated cell type has to be a subtype of the native cell label based on the underlying cell type ontology.

4.

For each cell type, there have to be at least 5000 unique cells. Otherwise, the whole cell type is dropped from the dataset.

5.

Each cell type has to be observed across at least 30 donors to reliably quantify whether the trained classifier can generalize to new, unseen donors for each cell type. With the 70-15-15 train, validation, and test split used here, this means that each cell type is represented by at least 4-5 donors in the validation and the test set, respectively.

6.

Each cell type needs to have at least seven parent nodes in the cell type ontology. This criterion is used as a heuristic to filter out general cell type labels that do not contain much information.
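As a simplified illustration of the count-based criteria (1, 2, 4 and 5), the sketch below applies them to a cell-level metadata table with one row per cell and the census columns is_primary_data, assay, cell_type and donor_id. The ontology-based criteria (3 and 6) are omitted here, and the function and variable names are ours, not those of the actual pipeline.

```python
import pandas as pd

# The seven 10x-based assay names from criterion 2 (as named in the census metadata).
TEN_X_ASSAYS = {
    "10x 5' v1", "10x 5' v2", "10x 3' v1", "10x 3' v2", "10x 3' v3",
    "10x 3' transcription profiling", "10x 5' transcription profiling",
}

def filter_census_obs(obs: pd.DataFrame) -> pd.DataFrame:
    """Apply criteria 1, 2, 4 and 5 to a cell-level metadata table."""
    # Criterion 1: primary data only, to prevent label leakage between splits.
    obs = obs[obs["is_primary_data"]]
    # Criterion 2: keep only 10x-based sequencing protocols.
    obs = obs[obs["assay"].isin(TEN_X_ASSAYS)]
    # Criterion 4: at least 5000 cells per cell type.
    cells_per_type = obs.groupby("cell_type").size()
    obs = obs[obs["cell_type"].isin(cells_per_type[cells_per_type >= 5000].index)]
    # Criterion 5: each cell type must be observed across at least 30 donors.
    donors_per_type = obs.groupby("cell_type")["donor_id"].nunique()
    obs = obs[obs["cell_type"].isin(donors_per_type[donors_per_type >= 30].index)]
    return obs
```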

To better assess how well the trained classifiers generalize to unseen donors, the data is split into train, validation, and test sets based on donors rather than by random subsampling of cells: each donor is found exclusively in either the training, validation, or test set. Unlike splitting based on, e.g., holdout datasets, donor-based splitting yields proportions of cells across the training, validation, and test sets that are close to those obtained by random subsampling. This is not the case when subsetting the available data based on, e.g., whole datasets, which often results in a very uneven distribution of cells across the splits, as the datasets in the census range anywhere from a few thousand to a few million cells. Furthermore, dataset-based splitting often makes it hard to ensure that each cell type is observed in both the training and the test data. In the end, the data is split such that 70% of the donors are assigned to the training set and 15% of the donors are assigned to the validation and test set, respectively. The data is size-factor normalized to 10,000 counts per cell and log1p-transformed.

The selection described above results in 22,189,056 cells spanning 164 unique cell types, 5052 unique donors, and 56 different tissues. Of these 22.2 million cells, 15,240,192 are assigned to the training set, 3,500,032 to the validation set, and 3,448,832 to the test set. More detailed explanations and references to the code that reproduces the above data selection and splitting exactly can be found in the associated GitHub repository under docs/data.md.

Subsampled datasets

We used a subsampled training dataset in the following settings:

Dataset size scaling:

Random subsampling: 15% subsampling (2.3 million cells), 30% subsampling (4.6 million cells), 50% subsampling (7.6 million cells), 70% subsampling (10.7 million cells), 100% subsampling (15.2 million cells)

Donor-based subsampling: Subsample to 15% of donors (531 donors / 2.1 million cells), Subsample to 30% of donors (1061 donors / 4.3 million cells), Subsample to 50% of donors (1768 donors / 7.4 million cells), Subsample to 70% of donors (2476 donors / 10.4 million cells), Subsample to 100% of donors (3536 donors / 15.2 million cells)

Data augmentation:

In all other cases, the full training dataset is used. All subsampling is done incrementally, e.g. the 30% subsampled dataset includes all cells/donors that are present in the 15% subsampled dataset, and so forth.

Data loading infrastructure

Training machine learning models on large-scale tabular datasets, which is the case for the scRNA-seq data used in this paper, comes with a set of unique challenges. The first challenge is that the entire dataset does not fit into the memory of a typical server used for training deep learning models. Additionally, individual observations of tabular data cannot be loaded efficiently from disk: each observation is rather small, so loading data points one by one creates many random reads that even modern SSDs cannot serve efficiently. Thus, a consecutive block of samples must be loaded at once and then shuffled. Fortunately, Python libraries already exist that do exactly this. The data loading infrastructure used in this paper is based on the Nvidia Merlin dataloader (https://github.com/NVIDIA-Merlin/dataloader), which provides an easy-to-use API, uses the widely adopted Apache Parquet format to store data on disk, and offers performant, GPU-optimized data loaders that load data from disk directly into GPU memory and then do a zero-copy transfer to PyTorch, TensorFlow, or JAX (see Supp. Fig. 8 for details about data loading speed; a minimal usage sketch follows the data augmentation equation below). This data loading infrastructure was fast enough to fully utilize an Nvidia A100 GPU for the models trained in this paper. Moreover, Merlin comes with a wide range of supporting infrastructure such as Docker containers from the NGC container hub (https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-pytorch), which makes it easy to start using Merlin without first setting up Python environments.

Data augmentation

The idea behind the data augmentation strategy developed in this paper is that the difference in raw gene space between the same cell type observed across two donors can be used as a data augmentation vector that simulates how the gene expression of a cell might look for a different donor. The general idea behind data augmentation is to have easy-to-compute transformations that can be applied during model training. Thus, in this case, we pre-compute augmentation vectors that can be added to the observed gene expression of a cell to artificially increase the training data size during model training:$${x}_{{augmented}}={x}_{{original\; cell}}\pm {x}_{{augmentation\; vector}}$$
(1)
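To make the data loading setup described under Data loading infrastructure concrete, the following minimal sketch streams a Parquet-backed dataset with the Merlin dataloader. The path, dataset schema, and exact batch structure are illustrative and may differ between Merlin versions; this is not the project's actual training loop.

```python
import merlin.io
from merlin.dataloader.torch import Loader

# Training split stored as a directory of Parquet files (path is illustrative).
train_data = merlin.io.Dataset("train/", engine="parquet")

# Consecutive blocks of rows are read from disk, shuffled, and delivered as
# on-GPU tensors that can be handed to a PyTorch model without an extra copy.
train_loader = Loader(train_data, batch_size=2048, shuffle=True)

for features, labels in train_loader:
    # `features` maps column names to tensors; the exact structure of a batch
    # depends on the Merlin version and the dataset schema.
    pass
```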
Calculation of augmentation vectors

The augmentation vectors are calculated as follows (a condensed sketch follows the list):

1.

Subsample 500,000 cells from the training data to have an even distribution across cell types.

2.

Calculate the mean centroids grouped by cell type and donor.

3.

Calculate the difference vectors between the mean centroids from step 2 by cell type.

4.

Set all values in the range [−0.25, 0.25] to zero to enforce more sparse augmentation vectors.

5.

Clamp the resulting augmentation vectors to the interval of [−1.5, 1.5] to remove outlier values.

6.

Filter the resulting augmentation vectors for outliers by only sampling the augmentation vectors used during training from the most prominent k-means clusters (clustering is done with 50 clusters), i.e. sample e.g. 5000 augmentation vectors from the biggest k-means clusters (clusters with more than 2000 difference vectors).
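The following condensed sketch illustrates steps 2-6, assuming the balanced 500,000-cell subsample from step 1 is already available as a dense array of normalized expression values. Function and variable names are ours, and the actual implementation in the repository may differ, e.g. in how donor pairs are enumerated.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

def compute_augmentation_vectors(
    x: np.ndarray,           # normalized expression, shape (n_cells, n_genes)
    cell_types: np.ndarray,  # per-cell labels, shape (n_cells,)
    donors: np.ndarray,      # per-cell donor ids, shape (n_cells,)
    n_vectors: int = 5000,
    min_cluster_size: int = 2000,
    seed: int = 0,
) -> np.ndarray:
    """Sketch of steps 2-6 (step 1, the balanced subsampling, is assumed done)."""
    rng = np.random.default_rng(seed)

    # Step 2: mean centroids per (cell type, donor) pair.
    df = pd.DataFrame(x)
    df["cell_type"], df["donor"] = cell_types, donors
    centroids = df.groupby(["cell_type", "donor"]).mean()

    # Step 3: pairwise differences between donor centroids within each cell type.
    diffs = []
    for _, group in centroids.groupby(level="cell_type"):
        values = group.to_numpy()
        for i in range(len(values)):
            for j in range(i + 1, len(values)):
                diffs.append(values[i] - values[j])
    diffs = np.asarray(diffs)

    # Step 4: zero out small entries to enforce sparsity.
    diffs[np.abs(diffs) <= 0.25] = 0.0
    # Step 5: clamp outlier values.
    diffs = np.clip(diffs, -1.5, 1.5)

    # Step 6: keep only vectors that fall into the largest k-means clusters.
    labels = KMeans(n_clusters=50, n_init=10, random_state=seed).fit_predict(diffs)
    counts = np.bincount(labels, minlength=50)
    candidates = diffs[np.isin(labels, np.where(counts > min_cluster_size)[0])]
    idx = rng.choice(len(candidates), size=min(n_vectors, len(candidates)), replace=False)
    return candidates[idx]
```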

Note that Step 6 is used to enforce the selection of cell type-independent augmentation vectors. As can be seen in Supp. Fig. 9a, some of the calculated augmentation vectors are influenced by the cell type from which they are calculated. This can be problematic if the augmentation vectors are applied in a cell type-independent fashion, e.g. by randomly sampling augmentation vectors. To ensure that the augmentation vectors can be applied in a cell type-independent fashion, we filtered them in Step 6 to only select augmentation vectors that are mostly cell type independent. This filtering based on k-means clustering indeed results in mostly cell type-independent augmentation vectors (Supp. Fig. 9b). The associated parameters of our data augmentation strategy should be tuned in the same way one would tune the hyperparameters of a neural network: calculate the augmentation vectors based on the training split, select the best parameter set based on the validation split, and finally report the performance on the holdout test set.

Calculation of augmented gene expression vectors during model training

The augmented gene expression vectors are calculated as follows (a sketch follows the list):

1.

Sample an augmentation vector \({x}_{{augmentation\; vector}}\) from the set of augmentation vectors.

2.

Sample whether the augmentation vector is added to or subtracted from the original gene expression vector \({x}_{{original\; cell}}\).

3.

Add or subtract the sampled augmentation vector to/from the original gene expression vector and clamp all values of the resulting vector to the interval [0, 9].
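A minimal PyTorch sketch of these three steps applied to a batch of normalized expression vectors follows; the augmentation vectors are assumed to be pre-computed and to live on the same device as the batch, and the function name is ours.

```python
import torch

def augment_batch(x: torch.Tensor, augmentation_vectors: torch.Tensor) -> torch.Tensor:
    """Apply the three training-time augmentation steps to a batch.

    x: (batch_size, n_genes) normalized expression;
    augmentation_vectors: (n_vectors, n_genes), on the same device as x.
    """
    batch_size = x.shape[0]
    # Step 1: sample one augmentation vector per cell.
    idx = torch.randint(0, augmentation_vectors.shape[0], (batch_size,), device=x.device)
    v = augmentation_vectors[idx]
    # Step 2: sample whether the vector is added or subtracted.
    sign = (torch.randint(0, 2, (batch_size, 1), device=x.device) * 2 - 1).to(x.dtype)
    # Step 3: apply the vector and clamp to the valid range of normalized values.
    return torch.clamp(x + sign * v, min=0.0, max=9.0)
```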

Explained variance by cell type before and after data augmentation

To estimate how our data augmentation influences the proportion of the overall variance that can be attributed to cell type variation, we fitted a linear regression model (scikit-learn LinearRegression) that predicts the normalized gene expression based on the cell type and donor of each cell. This corresponds to the following design matrix:$$\hat{y}=1+{onehot}({celltype})+{onehot}({donor})$$
(2)
In the next step, the \({R}^{2}\) score of the model fitted on the original/non-augmented data is compared to the one from the model fitted on the augmented data to show how the amount of total variation in gene expression that can be attributed to the cell type changes.

Ontology-corrected cell type classification

The classification performance of the trained models in this paper is evaluated based on the macro average of the F1-scores of the individual cell types. The macro average is used to give each cell type the same weight in the overall classification performance. The F1-score is calculated as follows:$$F1\text{-}{score}=2\,\frac{{precision}\cdot {recall}}{{precision}+{recall}}=\frac{2\cdot {tp}}{2\cdot {tp}+{fp}+{fn}}\quad ({{{\rm{tp}}}}:{{{\rm{true\; positives}}}},\,{{{\rm{fp}}}}:{{{\rm{false\; positives}}}},\,{{{\rm{fn}}}}:{{{\rm{false\; negatives}}}})$$
(3)
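For reference, the macro-averaged F1-score itself can be computed directly with scikit-learn; in the sketch below the predicted labels are assumed to already reflect the ontology-corrected relabeling described next.

```python
from sklearn.metrics import f1_score

def macro_f1(y_true, y_pred) -> float:
    # Macro averaging gives every cell type the same weight, regardless of how
    # many cells of that type are present in the test data.
    return f1_score(y_true, y_pred, average="macro")
```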
In order to deal with the often different granularity of annotations (e.g. the label T cell vs the label CD4-positive, alpha-beta T cell), the following rules are applied to decide whether a prediction is considered correct or wrong. A prediction is considered correct either if the classifier predicts the same label as supplied by the original dataset, or if the classifier predicts a subtype of the label provided by the original dataset; we consider this a correct prediction as it agrees with the true label up to the annotation granularity the author provided. The subtype relations are evaluated based on the Cell Ontology31. An example is if the model predicts the label CD4-positive, alpha-beta T cell when the author-annotated cell type is T cell. Conversely, a prediction is considered wrong if the classifier predicts a parent cell type of the true label; we consider this a wrong prediction as the author supplied a more fine-grained label that the classifier should replicate. An example is if the classifier predicts the label T cell while the cell is labeled as CD4-positive, alpha-beta T cell in the original dataset. In all other cases, the prediction is considered wrong. Furthermore, the lookup of child nodes in the Cell Ontology is based on the Ontology Lookup Service (OLS; https://www.ebi.ac.uk/ols/ontologies/cl)31.

Performance evaluation on coarse cell type labels

To give an impression of how scTab performs on more coarse cell type labels, we evaluated the performance of our scTab model on a set of more coarsely annotated cell type labels. We selected coarse cell type labels based on the information content score provided by the cell ontology (https://github.com/INCATools/ubergraph?tab=readme-ov-file#graph-organization). The information content score is calculated based on the count of terms related to a given cell ontology term and lies in the interval [0, 100], where 100 corresponds to a very specific term with no subclasses. Based on the information content score, we used the following rules to define a set of coarse cell type labels (a sketch follows the list):

1.

Get all cell type labels present in the CELLxGENE census which are a subtype of the native cell cell type label.

2.

Keep all cell type labels with an information content score of less than or equal to 60.

3.

Assign each cell type label to one of the coarse cell type labels from step 2. If, based on the cell type ontology, a cell type label can be assigned to more than one of the coarse labels, we only assign it to the coarse label with the highest information content score. For example, the label alpha-beta T cell would be assigned to the coarse label T cell and not to lymphocyte.

4.

Use the grouping from Step 3 to assign each of the fine-grained cell type labels to a coarse cell type label.
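The sketch below illustrates both the ontology-corrected correctness rule from the previous subsection and the coarse-label assignment above. It assumes the Cell Ontology is available as a networkx.DiGraph with edges pointing from each term to its parents (the orientation produced by, e.g., obonet) and that information content scores have been precomputed into a dictionary; all names are illustrative.

```python
import networkx as nx

def is_subtype(graph: nx.DiGraph, child: str, parent: str) -> bool:
    """True if `child` is a subtype of `parent` in the Cell Ontology.

    With edges oriented child -> parent, the terms reachable from `child`
    in the graph are exactly its ontology ancestors.
    """
    return parent in nx.descendants(graph, child)

def prediction_is_correct(graph: nx.DiGraph, predicted: str, annotated: str) -> bool:
    # Correct if the prediction matches the author annotation or is a more
    # fine-grained subtype of it; predicting a parent term counts as wrong.
    return predicted == annotated or is_subtype(graph, predicted, annotated)

def assign_coarse_label(
    graph: nx.DiGraph, fine_label: str, info_content: dict[str, float]
) -> str:
    """Map a fine-grained label to the applicable coarse label with the
    highest information content score (steps 2 and 3)."""
    coarse_candidates = [
        term
        for term in {fine_label} | nx.descendants(graph, fine_label)
        if info_content.get(term, 101.0) <= 60
    ]
    return max(coarse_candidates, key=lambda term: info_content[term])
```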

Moreover, we would like to note that we did not retrain the model from scratch for the evaluation on the coarse cell type labels. Instead, we aggregated the predictions of the model that was trained on the fine-grained cell type labels. For instance, all predictions of mature T cell subtypes count as predicting the label mature T cell (based on the underlying hierarchy of the Cell Ontology).

Model details

scTab model

Our implementation of scTab is based on the TabNet architecture33 and is mostly taken from the dreamquark-ai/tabnet GitHub repository, with some adaptations towards the single-cell use case. The input to the model is all 19,331 protein-coding genes (GENCODE v38/Ensembl 104) selected from the CELLxGENE census data. Moreover, unlike in the original TabNet model, we normalized the input data before feeding it into the neural network. scRNA-seq data is often normalized to 10,000 counts per cell and then log1p-transformed6,12,22; we applied the same normalization for our scTab model on top of the simple batch normalization layer that the original TabNet model uses to normalize the input features, as such a non-linear normalization cannot be achieved by a batch normalization layer alone.

The adapted TabNet architecture for scTab (Fig. 1b) consists of two key building blocks. The first building block is the feature transformer, a multi-layer perceptron with batch normalization (BN), skip connections, and gated linear unit (GLU) nonlinearities. The feature transformer maps from the input gene expression space to an n_d + n_a dimensional latent space. In the next step, the n_d + n_a dimensional embedding is split into two parts: one with dimension n_d and one with dimension n_a. The part with dimension n_d is used to classify the different cell types, and the part with dimension n_a is used to calculate the attention masks. The feature attention mask is obtained by applying a single linear layer followed by a batch normalization layer, which maps from the feature attention embedding to the input feature space, and then applying the 1.5-entmax46 function to the output of this linear projection. Using the 1.5-entmax function instead of the sparsemax function, which is used in the original TabNet model, improved training dynamics and yielded slightly higher model performance. The 1.5-entmax function is defined as follows:$${H}_{1.5}^{T}(p)=\frac{1}{1.5\cdot (1.5-1)}\sum _{j}\left({p}_{j}-{p}_{j}^{1.5}\right)\quad {{{\rm{for\; any}}}}\; p\in {\Delta }^{d}$$
(4)
$$1.5\text{-}{entmax}(z)={{argmax}}_{p\,\in \,{\Delta }^{d}}\left[{p}^{T}z+{H}_{1.5}^{T}(p)\right]$$
(5)
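A minimal PyTorch sketch of the attention-mask computation described above follows, assuming the standalone entmax package for the 1.5-entmax function (the scTab code base may implement this differently); dimensions follow the hyperparameters listed below, and the class name is ours.

```python
import torch
import torch.nn as nn
from entmax import entmax15

class FeatureAttention(nn.Module):
    """Maps the n_a-dimensional attention embedding to a sparse mask over the input genes."""

    def __init__(self, n_a: int = 64, n_genes: int = 19331):
        super().__init__()
        self.projection = nn.Linear(n_a, n_genes)
        self.bn = nn.BatchNorm1d(n_genes)

    def forward(self, attention_embedding: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Linear projection plus batch normalization into the input feature space ...
        logits = self.bn(self.projection(attention_embedding))
        # ... followed by 1.5-entmax, which yields a sparse attention mask.
        mask = entmax15(logits, dim=-1)
        # The masked genes are what the feature transformer sees in the decision step.
        return mask * x
```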
After obtaining the feature attention mask, the masked input features are fed into the feature transformer to obtain the feature embedding used to classify cell types. Thus, by giving the neural network the ability to mask individual input features, it can focus its capacity on the more reliable input features. In contrast to the original TabNet model, we only used a single decision step, as using more than one decision step yielded only marginal performance improvements that did not justify the increased computational cost.

The objective function used to train scTab is a cross-entropy loss in which each cell type label is weighted according to its relative frequency in the training data to account for the strong class imbalance:$${{weight}}_{{celltype}}=\frac{{n}_{{samples}}}{{n}_{{classes}}\cdot {\sum }_{{cell}\in {cells}}\left[{{label}}_{{cell}}=={celltype}\right]}$$
(6)
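Equation (6) corresponds to the standard "balanced" class-weighting heuristic; a small sketch, assuming integer-encoded cell type labels in which every class occurs at least once:

```python
import numpy as np

def balanced_class_weights(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Per-class weights as in Eq. (6): n_samples / (n_classes * class_count)."""
    counts = np.bincount(labels, minlength=n_classes)
    return len(labels) / (n_classes * counts)
```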
The models for Fig. 1 and Fig. 3 were fitted with our proposed data augmentation strategy. The models for Fig. 2 were fitted without data augmentation to better show the scaling with respect to the training data size.

List of used hyperparameters:

batch_size: 2048
learning_rate: 0.005
learning rate scheduler: torch.optim.lr_scheduler.StepLR (gamma = 0.9, step_size = 1 epoch)
optimizer: torch.optim.AdamW
weight_decay: 0.05
n_d: 128
n_a: 64
n_shared: 3
n_independent: 5
n_steps: 1
lambda_sparse: 1e-5
mask_type: entmax
virtual_batch_size: 256
augment_training_data: True

XGBoost model

The input to the XGBoost model is a 256-dimensional PCA embedding, owing to the high memory usage and runtime of the XGBoost model. The PCA is only fitted on the training data to keep a clear separation between the training and test set. Furthermore, the data is normalized to 10,000 counts per cell and log1p-transformed before calculating the PCA embeddings. The XGBoost model is fitted with the multi:softprob objective function and, as for the scTab model, classes are weighted according to their relative frequency in the training data.

List of non-default hyperparameters:

n_estimators: 800
eta: 0.05
subsample: 0.75
max_depth: 10
early_stopping_rounds: 10

For the benchmarks in this paper, we used XGBoost version 1.6.2.

Multi-layer perceptron model (MLP)

The input to the model is all 19,331 protein-coding human genes selected from the CELLxGENE census data. The model is trained to predict the corresponding cell type label for each cell with a cross-entropy loss in which each cell type is weighted according to its relative frequency (see scTab model). The input count data is normalized to 10,000 counts per cell and log1p-transformed before being fed into the model.

List of used hyperparameters:

batch_size: 2048
learning_rate: 0.002
learning rate scheduler: torch.optim.lr_scheduler.StepLR (gamma = 0.9, step_size = 1 epoch)
optimizer: torch.optim.AdamW
weight_decay: 0.05
n_hidden: 8
hidden_size: 128
dropout: 0.1
augment_training_data: True

Optimized linear model

The input to the model is all 19,331 protein-coding human genes selected from the CELLxGENE census data. The model consists of a single weight matrix and bias vector and is trained to predict the corresponding cell type label for each cell with a cross-entropy loss in which each cell type is weighted according to its relative frequency (see scTab model). The input count data is normalized to 10,000 counts per cell and log1p-transformed before being fed into the model.

List of used hyperparameters:

batch_size: 2048
learning_rate: 0.0005
learning rate scheduler: torch.optim.lr_scheduler.StepLR (gamma = 0.9, step_size = 1 epoch)
optimizer: torch.optim.AdamW
weight_decay: 0.01

CellTypist model

The CellTypist6 model was fitted in accordance with the best-practice tutorial supplied on the CellTypist website, with the difference that the mean-centering step was disabled (with_mean=False), as this negatively impacted model performance and increased memory usage.
Furthermore, the training data was subsampled to 1.5 million cells to keep both the memory usage (350 GB of maximum memory) and the runtime in check.

List of non-default hyperparameters:

feature_selection: True
use_SGD: True
mini_batch: True
batch_number: 1500
epochs: 10
with_mean: False

For the benchmarks in this paper, we used CellTypist version 1.5.3.

scGPT (zero-shot setting)

We evaluated the performance of scGPT in the zero-shot setting, meaning we used the pre-trained whole-human scGPT model to get cell embeddings and used those embeddings as input to a logistic regression classifier. The logistic regression classifier was trained on a random subsample of 1,500,000 cells from the training data.

List of non-default hyperparameters for cuml LogisticRegression:

class_weight: balanced
max_iter: 5000
C: 1000

For the benchmarks in this paper, we used cuml version 23.10 and scgpt version 0.1.7. Whole-human scGPT model: https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y?usp=sharing.

scGPT (fine-tuned)

We fine-tuned scGPT in accordance with the following example notebook provided by the authors of the scGPT paper: https://scgpt.readthedocs.io/en/latest/tutorial_annotation.html. Due to the high memory usage of scGPT, we were only able to fine-tune the scGPT model on a random subsample of 150,000 cells of our training data. For our benchmark, we fine-tuned the whole-human scGPT model: https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y?usp=sharing.

Universal Cell Embedding (UCE) (zero-shot)

As the UCE model47 is very resource intensive, even when used only for inference, we used the pre-computed UCE embeddings hosted by CELLxGENE (https://cellxgene.cziscience.com/census-models) and evaluated them in the linear probing setting, i.e. by fitting a logistic regression model on the embeddings obtained by UCE. Unfortunately, these pre-computed embeddings only exist for census version 2023-12-15, which is missing some datasets that were included in census version 2023-05-15 (the census version used in this paper). In numbers, this means we could only evaluate the UCE model on 736 of the 758 donors from our test data. We fitted the logistic regression classifier on a random subsample of 1,500,000 cells of the training data and then evaluated this classifier on the reduced test data (736 of the original 758 donors).

Uncertainty quantification for scTab model

The uncertainty quantification for scTab is based on deep ensembles34, using \(1-{maximum\; predicted\; probability}\) as an estimate of model uncertainty. Deep ensembles are commonly used to assess the uncertainty in the predictions of neural networks. They are simple, yet achieve state-of-the-art results: one simply averages the predicted probabilities across several networks that were trained independently, each with a different random initialization of the weights. In our case, we averaged the predictions across five models. To assess how well one can identify cell types that are not present in the training data, or cells with wrong predictions, we split the CELLxGENE data into three parts:

Group 1: Correct predictions. Cell types that are present in the training data and which are predicted correctly by the model (this serves as a reference group). This group is referred to as in-distribution (right prediction) or simply as Group 1 below.

Group 2: Incorrect predictions. Cell types that are present in the training data but which the model predicted wrongly, used to assess how well wrong predictions can be distinguished from right predictions based on the uncertainty scores. This group is referred to as in-distribution (wrong prediction) or simply as Group 2 below.

Group 3: Absent from training data. Cell types that are not present in the training data, used to assess how well unknown cell types can be identified based on the uncertainty scores provided by scTab. These are the cell types that we excluded from the CELLxGENE training data because too few observations were present. This group is referred to as out-of-distribution or simply as Group 3 below.

Now, to understand the quality of uncertainty estimates, we want to assess how well Group 2 and Group 3 (incorrect and absent) can be separated from reference Group 1 (correct). Note that the separation between Group 1 and Group 2 (correct vs incorrect) measures how well the uncertainty scores can be used to assess whether a model prediction can be trusted or not, and the separation between Group 1 and Group 3 (correct vs absent) gives an estimate of how well the uncertainty scores can be used to detect new/unseen cell types. The model uncertainty is defined as follows (logits are the outputs of the last layer of the neural network):$${uncertainty}=1.-\max (p)$$
(7)
$$p={softmax}({logits})$$
(8)
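A sketch of how the ensemble uncertainty of Eqs. (7) and (8) and the group separation can be computed, assuming the per-model logits on the evaluation cells are available as arrays of shape (n_cells, n_cell_types); the function names are ours.

```python
import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score

def ensemble_uncertainty(logits_per_model: list[np.ndarray]) -> np.ndarray:
    """Eqs. (7)-(8): average softmax probabilities over the ensemble, then 1 - max prob."""
    probs = np.mean([softmax(logits, axis=1) for logits in logits_per_model], axis=0)
    return 1.0 - probs.max(axis=1)

def group_separation(uncertainty_group1: np.ndarray, uncertainty_other: np.ndarray) -> float:
    """ROC-AUC for separating a comparison group (e.g. Group 2 or 3) from Group 1."""
    scores = np.concatenate([uncertainty_group1, uncertainty_other])
    labels = np.concatenate([np.zeros(len(uncertainty_group1)), np.ones(len(uncertainty_other))])
    return roc_auc_score(labels, scores)
```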
To get a first impression, one can look at the distribution of uncertainty scores conditioned on which of the three groups a cell belongs to (see Supp. Fig. 4c). As expected, the uncertainty scores for Group 2 and Group 3 are usually much higher than for Group 1. To provide a more rigorous benchmark, one can measure how well Group 2 and Group 3 can be distinguished from the reference Group 1 based on the uncertainty scores provided by the scTab model by looking at the area under the receiver operating characteristic curve (ROC-AUC). A ROC-AUC score of 1.0 means that the groups can be perfectly separated, and a score of 0.5 means that there is no separation between the groups based on the uncertainty scores (see Supp. Fig. 4d). The same approach can be used to assess how the quality of the uncertainty estimates improves with the number of models in the deep ensemble. Looking at the results, one can see that in both cases our uncertainty estimates provide a useful way to distinguish between the groups: Group 1 and Group 3 can be separated with a ROC-AUC score of 0.782, and Group 1 and Group 2 can even be separated with a ROC-AUC score of 0.891. In practice, biologists and computational biologists can overlay the uncertainty estimates from scTab, alongside the predicted cell type labels and their defined clustering, on a UMAP or t-SNE visualization of their data to hint at which predictions are associated with higher uncertainty and, hence, should be investigated in more detail (see Supp. Fig. 4a).

Statistics and reproducibility

No statistical method was used to predetermine sample size. We used all the available data from the CELLxGENE data corpus (version 2023-05-15) subject to our filtering criteria described in the Dataset preparation section (Methods).

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
