Tabular deep learning: a comparative study applied to multi-task genome-wide prediction | BMC Bioinformatics

Definitions

The key part of tabular deep learning is a deep neural network, usually in the form of a feed-forward network. A deep neural network defines a mapping \(\hat{f}\) [7] following$$\begin{aligned} Y = f(X) \approx \hat{f}(X; W), \end{aligned}$$
(1)
that learns the model parameters W (i.e. the weights of a neural network) that result in the best approximation of the true underlying and unknown function f. In this case, \(X \in \mathcal {R}^{p}\) is the input feature data with corresponding output target \(Y \in \mathcal {Z}^{k}\) for k classes or \(Y \in \mathcal {R}^{k}\) for k regression tasks, represented as a set of tuples \(\{x_i, y_i\}_{i\in n}\). Throughout this study, we focus on genomic data X of dimension \(n \times p\) with associated multiple traits Y of dimension \(n \times k\), with the restriction that all traits are of the same modality (i.e. either continuous or binary). However, the input data can be of different modalities (for example a mixture of DNA marker and RNA expression data), which fundamentally contrasts with the homogeneous nature of image, audio or text data. The network used in this work is feed-forward, which means that the input information flows in one direction to the output without any feedback connections [7]. Multi-task tabular deep learning involves training a model to simultaneously minimize a multivariate loss function \(\mathcal {L}(\hat{f}(X; W), Y)\). The goal is to capitalize on possible shared representations across related tasks in order to enhance the model’s performance.

Tabular neural networks

NODE

Inspired by CatBoost [13], the neural oblivious decision ensemble (NODE) [38] is designed to leverage the interpretability of decision trees while benefiting from the expressive prediction power of neural networks.

Each oblivious decision module in NODE consists of several layers of decision nodes. At each layer l, each node makes a decision based on a splitting feature g and a learnable threshold b. For the l-th layer and node j, the decision function can be expressed as$$\begin{aligned} d_{l,j}(X)=\mathcal {H}(g_{l,j}-b_{l,j}), \end{aligned}$$
(2)
where \(\mathcal {H}(\cdot )\) denotes the Heaviside step function and \(d_{l,j}(X)\) is the output of the decision node. All r decision nodes in the l-th layer then produce a decision vector \(v_{l} = [d_{l,1}(X), d_{l,2}(X), \cdots , d_{l,r}(X)]\). NODE uses an ensemble of such decision trees through these output nodes.

To make the oblivious decision trees (ODTs) differentiable, the splitting feature choice \(g_{l,j}\) and the comparison operator \(\mathcal {H}(g_{l,j}-b_{l,j})\) are replaced by their differentiable counterparts. The choice function is replaced by a weighted sum of features, with weights computed by the \(\alpha \text {-entmax}\) transformation [49] over the learnable feature selection matrix. The Heaviside function \(\mathcal {H}(\cdot )\) is relaxed to a two-class \(\text {entmax}\), which is denoted as \(\sigma _{\alpha }(v) = \text {entmax}_{\alpha }([v, 0])\). Because feature scales can vary, its scaled version \(c_{l,j}(X) = \sigma _{\alpha }(\frac{g_{l,j}(X) - b_{l,j}}{\tau _{l,j}})\) is used, where the threshold \(b_{l,j}\) and the scale \(\tau _{l,j}\) are learnable parameters. The computed \(c_{l,j}(X)\) are combined into a “choice” tensor \(C \in \mathcal {R}^{2^{d}}\) for a tree of depth d. The final prediction is then computed as a weighted linear combination of the entries of the response tensor R, with weights given by the entries of the choice tensor C. Assume that the tree outputs \(\hat{h}(X)\) are one-dimensional and that each NODE layer contains m individual trees whose outputs are concatenated as \([\hat{h}_{1}(X), \cdots , \hat{h}_{m}(X)]\). A NODE layer can then be trained alone or within a more complex structure, just like fully-connected layers that serve as input for the subsequent layers. Similar to DenseNet [50], this architecture is a sequence of l NODE layers, where each layer uses a concatenation of all previous layers as its input.
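The relaxed split and the choice tensor described above can be illustrated with a small NumPy sketch. It is only an illustration under stated assumptions: it uses the \(\alpha = 2\) special case of entmax (sparsemax), for which the two-class form has the closed expression \(\text{clip}((v+1)/2, 0, 1)\), and it takes the feature-selection weights as given instead of computing them with entmax. The function names `relaxed_split` and `soft_oblivious_tree` are hypothetical and do not come from the NODE implementation.

```python
import numpy as np

def relaxed_split(v):
    """First component of the two-class sparsemax([v, 0]), i.e. the
    alpha = 2 special case of entmax, used as a stand-in for sigma_alpha."""
    return np.clip((v + 1.0) / 2.0, 0.0, 1.0)

def soft_oblivious_tree(x, feature_weights, thresholds, temperatures, response):
    """One differentiable oblivious tree of depth d (illustrative only).

    x               : (p,)   input features
    feature_weights : (d, p) one row of selection weights per tree level
                      (assumed given; NODE obtains them with entmax)
    thresholds      : (d,)   learnable thresholds b
    temperatures    : (d,)   learnable scales tau
    response        : (2**d,) leaf responses R
    """
    g = feature_weights @ x                                 # soft feature choice
    c = relaxed_split((g - thresholds) / temperatures)      # soft decisions
    choice = np.ones(1)
    for c_l in c:
        # Each level doubles the number of leaves: go "right" with weight
        # c_l and "left" with weight 1 - c_l.
        choice = np.kron(choice, np.array([c_l, 1.0 - c_l]))
    # Tree output: response entries weighted by the choice tensor.
    return float(choice @ response), choice
```

Because each level contributes weights that sum to one, the flattened choice tensor is a probability distribution over the \(2^{d}\) leaves; with saturated splits it collapses to a single leaf, which recovers an ordinary oblivious decision tree.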
This aggregated output is then passed through a final fully connected layer to produce the final prediction \(\hat{Y}\), with the structure of this final layer depending on the task (regression or classification).

TabR

TabR [39] is a feed-forward network incorporating a customized t-nearest-neighbors-like component in the middle layer to produce a better prediction. Its main idea is to utilize the self-attention mechanism of transformers to capture complex interactions between features in tabular data.

With the feature matrix X, a feed-forward retrieval-free network \(f(X)=P(E(X))\) is first partitioned into two parts: an encoder \(E:X\rightarrow \mathcal {R}^{p'}\) and a predictor \(P: \mathcal {R}^{p'}\rightarrow \hat{Y}\). To make the model incrementally retrieval-based, a retrieval module R in a residual branch is added after E, where \(\tilde{X} \in \mathcal {R}^{p'}\) is the intermediate representation of the target object, \(\{\tilde{x}_{i}\}_{i\in I_{cand}} \subset \mathcal {R}^{p'}\) are the intermediate representations of the candidates and \(\{y_{i}\}_{i\in I_{cand}}\subset Y\) are the labels of the candidates.

The retrieval module R is defined in the spirit of k-nearest neighbors. For the target object’s representation, the retrieval module takes the t nearest neighbors among the candidates \(\tilde{x}_{i}\) according to the similarity module \(\mathcal {S}\) and aggregates their values produced by the value module \(\mathcal {V}\) with the definitions$$\begin{aligned} \mathcal {S}(\tilde{X},\tilde{x}_{i}) = W_{Q}(\tilde{X})^{T}W_{K}(\tilde{x}_{i}) \quad \quad \mathcal {V}(\tilde{X},\tilde{x}_{i},y_{i}) = W_{V}(\tilde{x}_{i}), \end{aligned}$$
(3)
where \(W_{Q}\), \(W_{K}\), \(W_{V}\) are the weights of the corresponding transformations. They play a critical role in transforming the inputs to better capture the similarities between entries, contributing to the model’s ability to learn complex patterns and relationships within the data. By adding context labels, the performance of the similarity module \(\mathcal {S}\) and the value module \(\mathcal {V}\) can be improved. Finally, the complete formal description of the R module as implemented in TabR is$$\begin{aligned} \mathcal {S}(\tilde{X},\tilde{x}_{i}) = -\left\| t-t_{i} \right\| ^{2} \quad \quad \mathcal {V}(\tilde{X},\tilde{x}_{i},y_{i}) = W_{y_{i}} + O(t-t_{i}), \end{aligned}$$
(4)
where \(t=W_{K}(\tilde{X})\), \(t_{i}= W_{K}(\tilde{x}_{i})\) and the operation O is defined as \(O(\cdot )=\text {LinearWithoutBias}(\text {Dropout}(\text {ReLU}(\text {Linear}(\cdot ))))\). The retrieval module R enriches the target object’s representation by retrieving and processing relevant objects from the candidates. Finally, the predictor P makes a prediction.

TabNet

TabNet [37] combines the strengths of both tree-based methods and deep neural networks using a sequential attention mechanism. It is a deep learning model that embodies the feature selection principles of decision trees, with an encoder comprising a feature transformer, an attentive transformer, and feature masking.

The features in X are the input to a batch normalization layer, which yields \(X' \in \mathcal {R}^{w \times z}\), where w is the batch size and z is the feature dimension. Assuming that the number of hidden layers is l, the output of the feature transformer has dimension \(w \times l\) and is split into two parts: a decision step output \(\rho [i]\in \mathcal {R}^{w \times l_{\rho }}\) and a part \(a[i] \in \mathcal {R}^{w \times l_{a}}\) that is shared across decision steps. The former is used for the final output of TabNet, and the latter is used as input to the attentive transformer. Each block of the feature transformer is composed of a fully-connected (FC) layer, batch normalization (BN) and a gated linear unit (GLU). The main function of the attentive transformer is to obtain a learnable mask layer \(M[j] \in \mathcal {R}^{w \times z}\) according to$$\begin{aligned} M[j] = \text {sparsemax}(P[j-1] \cdot \gamma _{i}(a[j-1])), \end{aligned}$$
(5)
where \(\gamma _{i}(a[j-1])\) denotes the FC layer followed by BN, and sparsemax is a mapping from a vector onto the simplex that produces sparse outputs. The scaling prior P[j-1] is closely connected to the mask M[j] through the features used in previous steps, and it is initialized as P[0] = 1. To encourage sparsity of M[j], a regularization constraint is imposed on the parameters to make the distribution of M[j] more reasonable.

The outputs of all decision steps are summed to give the final output through the FC layer. For multi-task learning, each task-specific branch ends with an output layer that produces a scalar output for the regression task. The shared layers facilitate the extraction of relevant features for multi-task learning, and the task-specific branches capture patterns specific to each regression target. The TabNet decoder is composed of a feature transformer block at each step. After reconstructing the features from the encoded representation, the aggregated features are passed through a fully connected layer to produce the predictions \(\hat{Y}\).

TabTransformer

Considering the characteristics of context embeddings, TabTransformer [42] is built upon self-attention based transformers. This model comprises a column embedding layer, a stack of l Transformer layers, and a multi-layer perceptron. Each Transformer layer consists of a multi-head self-attention layer followed by a position-wise feed-forward layer. For the tuples \(\{x_{i}, y_{i}\}_{i\in n}\), each of the \(x_{i}\) is embedded into a parametric embedding of dimension s using column embedding. Let \(e_{\phi _{i}}(x_{i}) \in \mathcal {R}^{s}\) be the embedding of the feature \(x_i\), and let \(E_{\phi }(x_{cat}) = \{e_{\phi _{1}}(x_{1}), \cdots , e_{\phi _{s'}}(x_{s'})\}\) be the set of embeddings for all the categorical features.
Then \(E_{\phi }(x_{cat})\) serves as input to the sequential Transformer layers \(f_{\theta }\), which operate on the parametric embeddings and return the corresponding contextual embeddings \(h_{1}, \cdots , h_{s'}\), where each \(h\in \mathcal {R}^{s}\). These contextual embeddings are concatenated with the continuous features \(x_{cont}\) to form a vector, which serves as the input to an MLP that is used to predict the target \(\hat{Y}\).

A self-attention layer in TabTransformer comprises three parametric matrices: Key (K), Query (Q) and Value (V). Each input embedding is projected onto these matrices to generate the corresponding vectors and attends to all other embeddings through an attention head, which is computed as \(\text {Attention}(K, Q, V) = A \cdot V\), where \(A = \text {softmax}((QK^{T})/\sqrt{k'})\) and \(k'\) is the dimension of the Key. The output of the attention head is projected back to the embedding dimension through an FC layer, which in turn is passed through two position-wise feed-forward layers. If we let \(\delta\) be the cross-entropy for classification and the mean squared error for regression tasks, the prediction \(\hat{Y}\) can be obtained by minimizing the loss function \(\mathcal {L}(\hat{f}(X;W), Y)=\delta (\text {MLP}(\text {Transformer}(E_{\phi }(x_{cat})),x_{cont}),Y)\).

FT-Transformer

FT-Transformer performs feature transformations that enhance the model’s ability to capture complex patterns [43]. It handles individual features independently before combining them to make predictions. There are two important parts of FT-Transformer: the feature tokenizer and the transformer.

The feature tokenizer first transforms the input features X to embeddings \(G \in \mathcal {R}^{m' \times n'}\). The embedding for the feature \(x_i\) is computed as$$\begin{aligned} G_{i} = f_{i}(x_i) + b_{i}, \end{aligned}$$
(6)
where \(b_{i}\) is the i-th feature bias and \(f_{i}\) is implemented as the element-wise multiplication with the weight matrix \(W_{i}\). For categorical features, there is also a function \(f_{i}^{(cat)}\) implemented as a lookup table \(W_{i}^{(cat)}\) that is indexed by the one-hot vectors \(e_{i}^{T}\) of the corresponding categorical features. The vectors are then stacked as \(G = \text {stack}[G_{1}, \cdots , G_{i}, G_{1}^{(cat)}, \cdots , G_{i}^{(cat)}]\), the embedding of the [CLS] token (or “output token”) is appended to G, and l Transformer layers \(F_{1}, F_{2}, \cdots , F_{l}\) are applied as$$\begin{aligned} G_{0} = \text {stack}[[CLS], G] \quad \quad G_{i}=F_{i}(G_{i-1}). \end{aligned}$$
(7)
After using the PreNorm setting, the final representation of the [CLS] token is used for prediction. For multi-task learning, the initial part of FT-Transformer consists of a shared transformer encoder that processes the input features, and its output is propagated to task-specific heads, one for each regression task. These heads are small MLPs that take the output of the shared encoder and generate task-specific predictions. Denoting the final representation of the [CLS] token as \(G_{l}^{[CLS]}\), the prediction is \(\hat{Y}=f(G_{l}^{[CLS]};W)\).

AutoInt

Given the limited ability of shallow networks to model interactions, AutoInt [40] is designed based on transformer mechanisms that enhance the modelling capabilities for feature interactions. The main idea of AutoInt is to map the original sparse features into a low-dimensional space and to model high-order feature interactions in that space.

With an embedding vector \(\upsilon _{i}\) for field i, the original feature \(x_i\) is embedded into a low-dimensional dense vector through the embedding layer as \(\sigma _{i} = \upsilon _{i} x_{i}\). The output of the embedding layer is a concatenation of multiple embedding vectors, which are the input of an interaction layer. In the interaction layer, a multi-head mechanism is used to map the features into multiple subspaces and to generate different feature interaction patterns in these spaces. Higher-order interactions are produced by stacking interaction layers. For the feature \(\sigma _{i}\) in attention space I, there are three projection matrices: \(W_{Q}^{I}\) for the query, \(W_{K}^{I}\) for the key, and \(W_{V}^{I}\) for the value. The similarity between the feature \(\sigma _{i}\) and the feature \(\sigma _{j}\) is first obtained as \(\phi ^{I} = \langle W_{Q}^{I}\sigma _{i}, W_{K}^{I}\sigma _{j}\rangle\), and then the distribution of the attention is produced using softmax.
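A single attention head of this interaction layer can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions and not the AutoInt reference implementation: the embeddings \(\sigma _{i}\) are stacked as rows of a matrix, the projection matrices are passed in as plain arrays, and the function names `softmax` and `interacting_head` are chosen here for illustration. It also includes the weighted sum over value projections that follows the softmax.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def interacting_head(sigma, W_Q, W_K, W_V):
    """One attention head over field embeddings (illustrative only).

    sigma         : (p, s)  one embedding sigma_i per field
    W_Q, W_K, W_V : (s_head, s) projection matrices of this attention space
    Returns the updated per-field representations and the attention weights.
    """
    Q = sigma @ W_Q.T          # queries W_Q sigma_i
    K = sigma @ W_K.T          # keys    W_K sigma_j
    V = sigma @ W_V.T          # values  W_V sigma_j
    scores = Q @ K.T           # scores[i, j] = <W_Q sigma_i, W_K sigma_j>
    attn = softmax(scores, axis=1)   # attention distribution over fields j
    return attn @ V, attn      # weighted sum of values for each field i
```

Each row of the returned attention matrix is a distribution over fields, so entry (i, j) can be read as how strongly field j contributes to the new representation of field i.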
With a weighted sum, the new feature of \(\sigma _{i}\) can be acquired as \(\hat{\sigma }_{i}^{I}\).

For multiple attention spaces, the new features from each space are concatenated to obtain the final representation of \(\sigma _{i}\) as \(\hat{\sigma }_{i}\). To preserve the previously learned combinatorial features, including the raw individual features, a standard residual connection is added to the network$$\begin{aligned} \sigma _{i}^{Res} = \text {ReLU}(\hat{\sigma }_{i} + W_{Res} \sigma _{i}), \end{aligned}$$
(8)
where \(W_{Res}\) is the projection matrix and \(\text {ReLU}(z) = \text {max}(0, z)\) is the standard non-linear activation function. Thus, the representation of each feature \(\sigma _{i}\) is updated into a new representation \(\sigma _{i}^{Res}\). By stacking multiple such layers, combinatorial features of arbitrary order can be modeled. The output of the interacting layer is a set of feature vectors \(\{\sigma _{i}^{Res}\}_{i=1}^{p}\). After concatenating all of the learned feature interactions, the aggregated representation is passed through a final layer to produce the predictions \(\hat{Y}\).

LassoNet

LassoNet is based on the sparsity idea of the Lasso and achieves feature sparsity by allowing a feature to participate in a hidden unit only if its input connection is active [35]. For the features X, the class of residual feed-forward neural networks \(\mathcal {F}\) with arbitrary width and depth [51] can be described as$$\begin{aligned} \mathcal {F} = \{\hat{f} \equiv \hat{f}_{\theta , W}: X \mapsto \theta ^{T}X + g_{W}(X)\}, \end{aligned}$$
(9)
where \(g_{W}\) is a fully connected feed-forward network with weights W. The objective function of LassoNet for multi-task learning is$$\begin{aligned} \arg \min \limits _{\theta , W} \mathcal {L}(\theta , W) + \lambda \left\| \theta \right\| _{1} \quad s. t. \quad \quad \left\| W_{i}^{(1)} \right\| _{\infty }\le \nu \left| \theta _{i} \right| , i=1,\cdots ,p, \end{aligned}$$
(10)
where \(\mathcal {L}(\theta ,W)\) is the loss on the training data set, and \(W_{i}^{(1)}\) denotes the weights for feature i in the first hidden layer. The constraint means that the total amount of non-linearity involving feature i is budgeted according to the relative importance of \(x_{i}\) as a main effect. The residual link and the first hidden layer jointly pass through a hierarchical soft-thresholding optimizer \(\mathcal {S}(x)=\text {sign}(x)\cdot \max \left\{ \left| x \right| - \lambda , 0 \right\}\). For multi-task learning, the layers of the neural network remain common across all tasks to capture shared representations. The sparsity of the input layer weights gives complete control of the feature sparsity of the network. When \(\nu =0\), all the hidden units are inactive and only the skip connection remains, which means that the formulation recovers exactly the Lasso. On the other hand, when \(\nu \rightarrow \infty\), one recovers a standard unregularized feed-forward neural network. The linear and nonlinear components are optimized jointly to capture arbitrary nonlinearity.

GANDALF

Inspired by gated recurrent units (GRUs) [52] for representation learning, the gated adaptive network for deep automated learning of features (GANDALF) is designed for tabular data and is based on a gating mechanism with built-in feature selection, called the Gated Feature Learning Unit (GFLU) [41].

A learnable mask \(M_n \in \mathcal {R}^{p}\) is used for the soft sparse selection of important features at each stage n of feature learning in the GFLU. The mask is constructed by applying a sparse transformation to a learnable parameter vector \(\Im _{n}\in \mathcal {R}^{p}\), using the t-softmax activation to encourage sparse selection [53]. Letting \(X_{n}\) be the selected input features and \(M_{n}\) the mask, the feature selection is defined by$$\begin{aligned} X_{n}&= M_n \odot X \end{aligned}$$
(11)
$$\begin{aligned} M_n&= t\text {-softmax}(\Im _{n}, t), \end{aligned}$$
(12)
where \(\Im _{n}\) and t are learnable parameters and \(\odot\) denotes an element-wise multiplication operation. The weight matrix W depends on the value of t. The gating mechanism has a reset gate \(r_{n}\) and an update gate \(z_n\). The update gate decides how much information to update in its internal feature representation, which can be defined as$$\begin{aligned} z_n = \sigma (W_{n}^{z} \cdot [\varphi _{n-1}; X_{n}]), \end{aligned}$$
(13)
where \(\varphi _{n-1}\) is the \((n-1)\)-th stage of the GFLU and \(W_{n}^{z}\) is a learnable parameter for the weight at stage n. Then the candidate feature representation \(\hat{\varphi }_{n}\) is computed as$$\begin{aligned} \hat{\varphi }_{n} = tanh(W_{n}^{O} \cdot [r_n \odot \varphi _{n-1}; X]), \end{aligned}$$
(14)
where \(r_{n}\) decides how much information to forget from the previous feature representation, [] represents a concatenation operation, and \(W_{n}^{O}\) represents a learnable parameter. The reset gate is computed in the same way as the update gate: \(r_{n} = \sigma (W_{n}^{r}\cdot [\varphi _{n-1};X_{n}])\).

GANDALF can be viewed as a stack of GFLUs arranged sequentially, where each stage n selects a subset of features and learns a representation of them. The multiple stages therefore act in a hierarchical way to build up the optimal representation for the prediction task. This representation is then fed to a multi-layer perceptron for the final prediction.

SAINT

SAINT (self-attention and intersample attention transformer) [44] is inspired by the transformer encoder, where the model takes in a sequence of feature embeddings and outputs contextual representations of the same dimension. Its main idea is to leverage several mechanisms to overcome the difficulties of training on tabular data. In the embedding layer, each feature in the input row is embedded into an e-dimensional space as$$\begin{aligned} \textbf{E} = \text {Embedding}(X), \end{aligned}$$
(15)
where \(\textbf{E}\in \mathbb {R}^{n\times p \times e}\), and e is the embedding dimension. The model stacks L identical stages, each consisting of one self-attention transformer block and one intersample attention transformer block. The contextual representations of the input of batch b can be given as \(\{\mathbf {r_{i}}\}_{i=1}^{b} = \textbf{S}(\{\textbf{E}(x_{i})\}_{i=1}^{b})\). When \(L=1\), \(\mathbf {r_{i}}\) is obtained by the following procedure $$\begin{aligned} \begin{aligned} \textbf{z}_{i}^{(1)}&= \text {LN}(\text {MSA}(\textbf{E}(x_{i}))) + \textbf{E}(x_{i}) \\ \textbf{z}_{i}^{(2)}&= \text {LN}(\text {FF}_{1}(\textbf{z}_{i}^{(1)})) + \textbf{z}_{i}^{(1)}\\ \textbf{z}_{i}^{(3)}&= \text {LN}(\text {MISA}(\{\textbf{z}_{i}^{(2)}\}_{i=1}^{b})) + \textbf{z}_{i}^{(2)} \\ \textbf{r}_{i}&= \text {LN}(\text {FF}_{2}(\textbf{z}_{i}^{(3)})) + \textbf{z}_{i}^{(3)} \end{aligned}, \end{aligned}$$
(16)
where MSA is a multi-head self-attention layer with h heads, FF is a fully-connected feed-forward layer with a GELU non-linearity, LN is a normalization layer with skip connection and MISA is an intersample attention transformer block. Intersample attention is computed across the different data points (i.e. rows of the tabular data matrix) in the batch. This can help to improve the representation of a given data point by inspecting other points. For self-supervised pretraining, CutMix is used to augment samples in the input space and mixup is used in the embedding space for the augmented representation. At the final prediction stage, the corresponding embedding is passed through a single-layer MLP with ReLU activation to get the output \(\hat{Y}\).

Implementation details

Tuning

For each dataset, we tuned the hyperparameters of each model using Bayesian optimization (BO) with 100 iterations. The hyperparameter search was conducted on the validation folds of the training set, ensuring that the test set remained untouched and independent. To optimize the hyperparameters, we used a 5-fold cross-validation (CV) approach on the training set. For each fold, the model was trained on 4 folds and validated on the remaining fold. This process was repeated 5 times, with each fold serving as the validation set once. The performance metrics from the 5 folds were averaged to obtain a single performance measure for the given set of hyperparameters. Various combinations of hyperparameters were evaluated, and the set that provided the best average performance across the 5 folds was selected as the optimal set. Subsequently, we executed the models in parallel across the folds, independently calculating the test MSE or test accuracy for each. The performance metrics from this parallelized execution were then collected and averaged to drive the Bayesian optimization process, which used the Tree-structured Parzen Estimator (TPE) for parameter suggestions.
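The fold-averaged objective described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code used in the study: a closed-form multi-task ridge regression stands in for the neural networks, the single hyperparameter `lam` is hypothetical, and the helper names `kfold_indices`, `fit_ridge`, `predict_ridge` and `cv_objective` are chosen here. In the actual workflow, a TPE-based optimizer would call such an objective once per suggested hyperparameter set.

```python
import numpy as np

def kfold_indices(n, k=5, seed=0):
    """Shuffle the n sample indices and split them into k disjoint folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), k)

def fit_ridge(X, Y, lam):
    """Stand-in model: multi-task ridge regression in closed form."""
    p = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(p), X.T @ Y)

def predict_ridge(W, X):
    return X @ W

def cv_objective(X, Y, params, fit=fit_ridge, predict=predict_ridge,
                 k=5, seed=0):
    """Average validation MSE over k folds for one hyperparameter set."""
    folds = kfold_indices(len(X), k, seed)
    scores = []
    for i in range(k):
        val = folds[i]
        # Train on the other k - 1 folds, validate on the held-out fold.
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        model = fit(X[train], Y[train], **params)
        pred = predict(model, X[val])
        scores.append(np.mean((pred - Y[val]) ** 2))
    return float(np.mean(scores))
```

A search procedure then only needs to minimize `cv_objective(X_train, Y_train, params)` over candidate `params`; the test set is never passed to this function, which keeps it independent of the tuning.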
The best hyperparameters were selected based on the loss criterion (i.e. MSE or accuracy) on the validation set. This iterative process continued until the predefined stopping criterion was reached. For the TPE method, we relied on the stochasticity inherent in draws from the models, ensuring diverse candidate suggestions from one iteration to the next while incorporating new recommendations from BO [54]. To balance time consumption against the precision of the performance metric results, we set the BO stopping criterion to 1e-5. The experiments were conducted using 5 NVIDIA Tesla V100 GPUs. Each GPU is equipped with 32 GB of HBM2 memory. The initial ranges of the hyperparameters of the models are publicly available online along with our code.

There are two important hyperparameters in LassoNet: the \(l_1\)-penalty coefficient \(\lambda\) and the hierarchy coefficient M (denoted \(\nu\) in Eq. 10), which control the complexity of the fitted model and the relative strength of the linear and nonlinear components, respectively. First, we performed some initial test runs to determine a suitable range of M and \(\lambda\). For \(\lambda\), we made sure that the initial dense model with \(\lambda =0\) trained well before starting the regularization path. The step size over \(\lambda\) was then implemented following the same strategy as the original paper.

Evaluation

For each tuned configuration, ensemble predictions were generated by conducting 10 experiments with different random seeds, and the average results are reported on the test set. For the multi-trait classification task, evaluation metrics include average classification accuracy with standard deviation (stddev), Brier scores and the area under the curve (AUC) with standard deviation. For the regression task, the metrics reported are the test mean squared error (MSE) with standard deviation and the Pearson correlation coefficient r, averaged across traits for each dataset.
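For reference, the per-trait metrics named above can be written in a few lines of NumPy. This is a minimal sketch assuming that targets and predictions are arrays of shape (n, k) with one column per trait; the function names are chosen here for illustration and are not taken from the study's code.

```python
import numpy as np

def mse_per_trait(y_true, y_pred):
    """Mean squared error of each trait (column)."""
    return np.mean((y_true - y_pred) ** 2, axis=0)

def pearson_per_trait(y_true, y_pred):
    """Pearson correlation coefficient r of each trait (column)."""
    yt = y_true - y_true.mean(axis=0)
    yp = y_pred - y_pred.mean(axis=0)
    denom = np.sqrt((yt ** 2).sum(axis=0) * (yp ** 2).sum(axis=0))
    return (yt * yp).sum(axis=0) / denom

def brier_per_trait(y_true, p_pred):
    """Brier score of each binary trait: mean squared difference between
    the predicted probability and the 0/1 label."""
    return np.mean((p_pred - y_true) ** 2, axis=0)
```

Averaging the returned vectors with `.mean()` gives a single value per dataset, which corresponds to the averaging across traits described for the regression metrics.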
