scGAA: a general gated axial-attention model for accurate cell-type annotation of single-cell RNA-seq data

The scGAA modelThe cell type annotation method for the scRNA-seq data based on the gated axial-attention mechanism (Fig .8) includes the involves steps:Randomly divide the expression matrix including the top n highly variable genes of all cells from the gene expression matrix to obtain several gene sets, so that each gene set represents different feature information, and obtains the input tensor by embedding the gene set. Using the change weight matrix W (learned during the training of the transformer model) and mask matrix M (composed of 0 and 1 randomly and with the same dimension as W), multiply the corresponding positions of W and M, and then multiply them with the embedding representation G of each gene set to obtain the feature information t of each gene set. The formula is as follows:$$\begin{aligned} t = W \cdot M \times G. \end{aligned}$$
(1)
where \(\cdot\) denotes element-wise multiplication, \(\times\) denotes matrix multiplication, and the mask matrix M randomly masks entries of W. The feature computation for each gene set is repeated m times to increase the dimensionality of the feature space, and the m resulting feature maps are then merged to obtain the input tensor T of each gene set. The formula is as follows:$$\begin{aligned} T = concat(t_1, t_2, \ldots , t_m). \end{aligned}$$
(2)
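To make this construction concrete, the following is a minimal PyTorch-style sketch of Eqs. (1)-(2). The function name, the masking probability keep_prob, the Bernoulli sampling of M and the tensor shapes are illustrative assumptions rather than details taken from the scGAA implementation.

```python
import torch

def gene_set_features(G, W, m, keep_prob=0.5):
    """Illustrative sketch of Eqs. (1)-(2) for one gene set.

    G : (d_in, L)     embedding representation of the gene set (L genes)
    W : (d_out, d_in) transformation weight matrix learned during training
    m : number of masked projections merged into the input tensor
    """
    feats = []
    for _ in range(m):
        # M: random 0/1 mask with the same shape as W (assumed Bernoulli sampling)
        M = (torch.rand_like(W) < keep_prob).float()
        # Eq. (1): element-wise product W * M, then matrix product with G
        feats.append((W * M) @ G)
    # Eq. (2): merge the m feature maps; the per-set tensors are later
    # arranged into the input tensor of shape (N, V, H, C)
    return torch.cat(feats, dim=0)

# example: a 16-gene set embedded in 32 dimensions, with m = 4 repetitions
T_set = gene_set_features(torch.randn(32, 16), torch.randn(32, 32), m=4)
```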
In this context, the concat() function performs the merge operation, and the input tensor T has shape (N, V, H, C), where N is the batch size, V is the height, H is the width, and C is the number of channels. Dividing the genes into several gene sets converts them into the smallest units the model processes and generates the input tensors; the experiments below show that this input tensor is effective at capturing gene-gene interactions.

In this study, the gene expression vector of every cell is converted into a two-dimensional “image” containing \(H \times W\) “pixels”, where each “pixel” corresponds to an element of the original vector. The input of self-attention is represented by a matrix X, the Query, Key and Value matrices are computed with the linear transformation matrices \(W_{Q}\), \(W_{K}\) and \(W_{V}\), and the output of self-attention is then calculated from the resulting matrices Q, K and V.

Furthermore, this study introduces an axial self-attention mechanism in the gated axial-attention module of scGAA. The axial attention mechanism splits the attention module into two parts, a horizontal attention module (Horizontal-Attention) and a vertical attention module (Vertical-Attention), obtains two output matrices (a row matrix and a column matrix), and then merges them into a single output matrix. By performing self-attention separately along the horizontal and vertical directions, axial attention closely approximates the original self-attention mechanism while greatly improving computational efficiency and reducing computational complexity. In addition, this mechanism captures gene-gene interactions more effectively, which helps improve the model’s adaptability to different datasets.

The horizontal-attention output matrix is calculated in the horizontal attention module. First, the input tensor T is expanded along the row axis into a row input tensor \(T_{row}\), and each row is treated as an independent sequence, expressed as:$$\begin{aligned} T_{row} = reshape(T, (N \cdot H, V, C)). \end{aligned}$$
(3)
Here, reshape denotes the reshaping operation applied to the input tensor T, \(\cdot\) denotes multiplication, N is the batch size, V is the height, H is the width, and C is the number of channels. The query \(Q_{h}\), key \(K_{h}\) and value \(V_{h}\) are then calculated as:$$\begin{aligned} Q_{h} = W_{Qh} \times T_{row}, \end{aligned}$$
(4)
$$\begin{aligned} K_{h} = W_{Kh} \times T_{row}, \end{aligned}$$
(5)
$$\begin{aligned} V_{h} = W_{Vh} \times T_{row}. \end{aligned}$$
(6)
Here, \(W_{Qh}\), \(W_{Kh}\) and \(W_{Vh}\) are linear transformation matrices and \(\times\) is matrix multiplication. The horizontal-attention similarity score matrix \(A_{h}\) is then calculated from the query \(Q_{h}\) and key \(K_{h}\):$$\begin{aligned} A_{h} = softmax\left( \frac{Q_{h}K_{h}^T}{\sqrt{d_h}}\right) . \end{aligned}$$
(7)
Here, softmax is the softmax activation function, \(d_h\) is the dimension of \(Q_{h}\) and \(K_{h}\), and \(\sqrt{d_h}\) serves as a scaling factor that prevents the dot product \(Q_h K_h^T\) from becoming too large; the superscript T denotes the matrix transpose. The horizontal-attention output matrix \(O_{h}\) is then calculated as:$$\begin{aligned} O_{h} = A_{h} \times V_{h}. \end{aligned}$$
(8)
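For readers who prefer code, the horizontal branch (Eqs. (3)-(8)) can be sketched in PyTorch as follows. The explicit permute before the reshape, the application of the projections on the channel dimension (the paper writes them as left matrix products), and the projection shapes are implementation assumptions made for the sketch.

```python
import math
import torch
import torch.nn.functional as F

def horizontal_attention(T, W_Qh, W_Kh, W_Vh):
    """Sketch of Eqs. (3)-(8). T has shape (N, V, H, C); each row is treated
    as an independent sequence of length V, matching the (N*H, V, C) layout.
    W_Qh, W_Kh, W_Vh : (C, d_h) linear projection matrices.
    """
    N, V, H, C = T.shape
    # Eq. (3): regroup so that the row index joins the batch dimension
    # (a permute is assumed here so that elements stay aligned).
    T_row = T.permute(0, 2, 1, 3).reshape(N * H, V, C)
    # Eqs. (4)-(6): linear projections on the channel dimension
    Q_h = T_row @ W_Qh
    K_h = T_row @ W_Kh
    V_h = T_row @ W_Vh
    d_h = Q_h.shape[-1]
    # Eq. (7): scaled dot-product similarity followed by softmax
    A_h = F.softmax(Q_h @ K_h.transpose(-2, -1) / math.sqrt(d_h), dim=-1)
    # Eq. (8): attention-weighted values, shape (N*H, V, d_h)
    return A_h @ V_h
```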
Similarly, the vertical-attention output matrix \(O_{v}\) is calculated in the vertical attention module. First, the input tensor T is expanded along the column axis into a column input tensor \(T_{col}\), and each column is treated as an independent sequence, expressed as:$$\begin{aligned} T_{col} = reshape(T, (N \cdot V, H, C)). \end{aligned}$$
(9)
The query \(Q_{v}\), key \(K_{v}\) and value \(V_{v}\) are then calculated as:$$\begin{aligned} Q_{v} = W_{Qv} \times T_{col}, \end{aligned}$$
(10)
$$\begin{aligned} K_{v} = W_{Kv} \times T_{col}, \end{aligned}$$
(11)
$$\begin{aligned} V_{v} = W_{Vv} \times T_{col}. \end{aligned}$$
(12)
Here, \(W_{Qv}\), \(W_{Kv}\) and \(W_{Vv}\) are linear transformation matrices and \(\times\) is matrix multiplication. The vertical-attention similarity score matrix \(A_{v}\) is then calculated from the query \(Q_{v}\) and key \(K_{v}\):$$\begin{aligned} A_{v} = softmax\left( \frac{Q_{v}K_{v}^T}{\sqrt{d_v}}\right) . \end{aligned}$$
(13)
Here, softmax is the softmax activation function, \(d_v\) is the dimension of \(Q_{v}\) and \(K_{v}\), and \(\sqrt{d_v}\) serves as a scaling factor that prevents the dot product \(Q_v K_v^T\) from becoming too large; the superscript T denotes the matrix transpose. The vertical-attention output matrix \(O_{v}\) is then calculated as:$$\begin{aligned} O_{v} = A_{v} \times V_{v}. \end{aligned}$$
(14)
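The vertical branch (Eqs. (9)-(14)) mirrors the horizontal one. Assuming the hypothetical horizontal_attention helper from the previous sketch, it can be obtained simply by swapping the height and width axes before the call:

```python
def vertical_attention(T, W_Qv, W_Kv, W_Vv):
    """Sketch of Eqs. (9)-(14): identical to the horizontal branch after
    swapping the height and width axes, so each column becomes a sequence
    and the layout reduces to (N*V, H, C) as in Eq. (9)."""
    # (N, V, H, C) -> (N, H, V, C): columns of T become rows of T_swapped
    T_swapped = T.permute(0, 2, 1, 3)
    # reuse the row-wise computation from the previous sketch
    return horizontal_attention(T_swapped, W_Qv, W_Kv, W_Vv)
```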
On this basis, to further optimize the performance of the scGAA model, six gating units \(G_{Qh}\), \(G_{Kh}\), \(G_{Vh}\), \(G_{Qv}\), \(G_{Kv}\) and \(G_{Vv}\) are introduced to control the information carried by Q, K and V in the horizontal and vertical attention, respectively. These gating units are learnable parameters for extracting important features: depending on how useful the learned information is, each gating unit produces weights between 0 and 1 that control how much information is passed through, so that the most important features are retained and the prediction accuracy of the model is improved. \(G_{Qh}\), \(G_{Kh}\) and \(G_{Vh}\) are applied to \(Q_{h}\), \(K_{h}\) and \(V_{h}\) respectively, and the horizontal-attention similarity score matrix \(A_{h}\) is calculated as follows:$$\begin{aligned} A_h = softmax \left( \frac{(Q_h \cdot G_{Qh})(K_h \cdot G_{Kh})^T}{\sqrt{d_h}} \right) . \end{aligned}$$
(15)
Finally, the horizontal-attention output matrix \(O_{h}\) is calculated:$$\begin{aligned} O_h = A_h \times (V_h \cdot G_{Vh}). \end{aligned}$$
(16)
Similarly, \(G_{Qv}\), \(G_{Kv}\) and \(G_{Vv}\) are applied to \(Q_{v}\), \(K_{v}\) and \(V_{v}\) respectively, and the vertical-attention similarity score matrix \(A_{v}\) is calculated as:$$\begin{aligned} A_v = softmax \left( \frac{(Q_v \cdot G_{Qv})(K_v \cdot G_{Kv})^T}{\sqrt{d_v}} \right) . \end{aligned}$$
(17)
Finally, the vertical-attention output matrix \(O_{v}\) is calculated:$$\begin{aligned} O_v = A_v \times (V_v \cdot G_{Vv}). \end{aligned}$$
(18)
The outputs of the horizontal attention and vertical attention are combined to better integrate global information, yielding the output matrix O:$$\begin{aligned} O = O_h + O_v . \end{aligned}$$
(19)
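Putting the pieces together, the following is an illustrative PyTorch module for the gated variant (Eqs. (15)-(19)). The use of nn.Linear projections, gate vectors of size dim initialised to ones, and the reshapes that restore the (N, V, H, ·) layout before the final sum are assumptions made for this sketch, not details confirmed by the paper.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAxialAttention(nn.Module):
    """Illustrative sketch of Eqs. (15)-(19): learnable gates scale Q, K and V
    in each direction, and the two directional outputs are summed."""

    def __init__(self, channels, dim):
        super().__init__()
        # per-direction projection layers (assumed implementation detail)
        self.q_h, self.k_h, self.v_h = (nn.Linear(channels, dim) for _ in range(3))
        self.q_v, self.k_v, self.v_v = (nn.Linear(channels, dim) for _ in range(3))
        # six gating units, initialised to 1 (fully open) and learned during training
        self.gates = nn.ParameterDict({
            name: nn.Parameter(torch.ones(dim))
            for name in ("G_Qh", "G_Kh", "G_Vh", "G_Qv", "G_Kv", "G_Vv")
        })

    @staticmethod
    def _attend(q, k, v):
        # scaled dot-product attention shared by both branches
        a = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
        return a @ v

    def forward(self, T):                      # T: (N, V, H, C)
        N, V, H, C = T.shape
        g = self.gates
        # horizontal branch, Eqs. (15)-(16): rows as sequences of length V
        t_row = T.permute(0, 2, 1, 3).reshape(N * H, V, C)
        o_h = self._attend(self.q_h(t_row) * g["G_Qh"],
                           self.k_h(t_row) * g["G_Kh"],
                           self.v_h(t_row) * g["G_Vh"])
        # vertical branch, Eqs. (17)-(18): columns as sequences of length H
        t_col = T.reshape(N * V, H, C)
        o_v = self._attend(self.q_v(t_col) * g["G_Qv"],
                           self.k_v(t_col) * g["G_Kv"],
                           self.v_v(t_col) * g["G_Vv"])
        # Eq. (19): restore the (N, V, H, dim) layout and sum the two outputs
        o_h = o_h.reshape(N, H, V, -1).permute(0, 2, 1, 3)
        o_v = o_v.reshape(N, V, H, -1)
        return o_h + o_v
```

A typical call would be GatedAxialAttention(channels=C, dim=64)(T) with T shaped (N, V, H, C); summing the two directional outputs corresponds to Eq. (19).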
Fig. 8 Gated axial-attention detailed structure. First, the model embeds all gene sets and applies the horizontal and vertical gated attention mechanisms, respectively. By calculating the horizontal (\(Q_{h} \times K_{h}\)) and vertical (\(Q_{v} \times K_{v}\)) attention scores, the model is able to identify key gene sets. With these key gene sets, further studies such as differential analysis, enrichment analysis and gene characterisation can be conducted, providing a basis for subsequent analysis and interpretation. The model then fuses the outputs of the horizontal and vertical gated attention and computes a score for each category through a linear layer. Finally, these scores are converted into the corresponding probabilities via a softmax function.

Data preprocessing

Since most scRNA-seq data are imperfect, quality control is performed to filter out low-quality cells. Genes expressed in fewer than 3 cells and cells expressing fewer than 200 genes are filtered out. The data are then normalised; in this paper, normalisation based on Pearson residuals is used, which preserves inter-cellular variation, including biological heterogeneity, and helps to identify rare cell types [49]. We use the preprocessing function pp.preprocess provided in scanpy [50], which can directly calculate the Pearson residuals for normalisation [51]. First, the sum of gene expression per cell and the median of these sums across all cells are calculated, and the expression \(X_{ij}\) of each gene in each cell is normalised:$$\begin{aligned} X_{ij}^{{norm}} = X_{ij} \times \frac{S_{median}}{S_i}. \end{aligned}$$
(20)
where \(X_{ij}\) is the raw expression of gene j in cell i, \(S_{i}\) is the sum of the expression of all genes in cell i, \(S_{median}\) is the median of these per-cell sums across all cells, and \(X_{ij}^{{norm}}\) is the normalised gene expression. Afterwards, the normalised gene expression is log-transformed:$$\begin{aligned} X_{ij}^{{log}} = \log (1 + X_{ij}^{{norm}}). \end{aligned}$$
(21)
\(X_{ij}^{{log}}\) is the final log-transformed gene expression. This two-step process makes the data comparable across cells and reduces the impact of extreme values on subsequent analyses.

The normalised data are then used as input to train the scGAA model. If one category in the dataset is much larger than the others, the model may be biased towards predicting the category with the most samples, leading to poorer performance on categories with fewer samples. Balancing the dataset by category exposes the model to a more even distribution of samples during training, which helps it learn the differences between categories and improves its prediction performance on unseen data. Then 80% of the dataset is randomly split into the training and test sets, and the remaining data form the validation set.

Loss function

The model training process includes two stages: self-supervised learning on unlabeled data to obtain a pre-trained model, and supervised learning on the specific cell-type labeling task to obtain a fine-tuned model, which is used to predict the probability of each cell type. Cross-entropy loss is used as the cell-type label prediction loss, and the loss function is calculated as follows:$$\begin{aligned} L = -\sum _{i=1}^{N} \sum _{c=1}^{C} y_{i,c} log(p_{i,c}). \end{aligned}$$
(22)
where L is the value of the loss function, N is the total number of cells in the dataset, C is the total number of cell types, \(y_{i,c}\) is an indicator that equals 1 when the true category of cell i is c and 0 otherwise, and \(p_{i,c}\) is the predicted probability that cell i belongs to category c. This formula quantifies prediction accuracy by taking the negative logarithm of the predicted probability assigned to the actual cell type: the closer the predicted probability is to the true category (i.e., \(p_{i,c}\) close to 1 when \(y_{i,c} = 1\)), the smaller the cross-entropy loss, and vice versa. We use stochastic gradient descent (SGD) as the optimization algorithm and employ a cosine learning-rate decay strategy during training to avoid problems caused by large update steps in the late stages of training. Accuracy and the macro F1 score are used to evaluate the performance of each method in cell-type annotation at both the cell level and the cell-type level.
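As a rough illustration of this training objective and optimisation setup, the snippet below wires a placeholder classifier to cross-entropy loss (Eq. (22)), SGD with cosine learning-rate decay, and accuracy/macro-F1 evaluation. The toy model, random data and hyperparameters are stand-ins, not the actual scGAA configuration.

```python
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score

# Stand-in for the fine-tuned scGAA classifier: any module mapping a cell's
# expression features to per-cell-type logits fits this recipe.
n_genes, n_types, num_epochs = 2000, 10, 30
model = nn.Sequential(nn.Linear(n_genes, 256), nn.ReLU(), nn.Linear(256, n_types))

# toy data in place of the real (preprocessed, balanced) expression matrix
x_train, y_train = torch.randn(512, n_genes), torch.randint(0, n_types, (512,))
x_val, y_val = torch.randn(128, n_genes), torch.randint(0, n_types, (128,))

criterion = nn.CrossEntropyLoss()                      # Eq. (22)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)          # -sum_i sum_c y_ic * log p_ic
    loss.backward()
    optimizer.step()
    scheduler.step()                                   # cosine learning-rate decay

# evaluation: accuracy at the cell level, macro F1 at the cell-type level
model.eval()
with torch.no_grad():
    preds = model(x_val).argmax(dim=1)
print(accuracy_score(y_val.numpy(), preds.numpy()),
      f1_score(y_val.numpy(), preds.numpy(), average="macro"))
```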
