Explainable machine learning by SEE-Net: closing the gap between interpretable models and DNNs

Synced explanation-enhanced neural network (SEE-Net)

We develop a model called Relational-Class Logistic Regression (Rec-LR) which achieves high accuracy by exploiting the power of a DNN in training an explainable model. In the case of classification, Rec-LR contains two layers of linear logistic regression. In the case of regression, the first layer is a logistic regression model while the second layer contains linear models. For brevity of terminology, we do not explicitly distinguish the two types of linear models used at the second layer. LR is easy to explain in a global sense because the class decision boundary is given by a linear function of the input variables and the role of each variable can be read directly off its coefficient.

In a Rec-LR model, the first-layer LR predicts a latent class (LC) label, which is to be distinguished from what we ultimately predict, namely the class of instances/objects. The latent class is called a relational class because it determines which linear model will be used to predict the output (instead of giving the output value directly), thus reflecting a coherent relationship between the input and output. Since the LR model can be implemented as one layer of a neural network, Rec-LR is implemented as a shallow neural network (SNN). The Rec-LR model boosts the flexibility of fitting data with little reduction in explainability, provided that the division into relational classes is performed by an explainable model, which is true in our case since we employ an LR model in the first layer. Rec-LR may seem deceptively easy to train as a usual SNN. However, because the relational class labels are unknown, we cannot train Rec-LR in the usual stand-alone fashion. The main novelty of our work is to form relational classes based on a DNN so that the prediction power of the DNN can be leveraged.

Next, we introduce the mathematical formulation of Rec-LR. Consider a predictor vector \(X\in \mathcal {X}\subset \mathbb {R}^p\), where \(\mathbb {R}^p\) is the p-dimensional Euclidean space, and a response \(Y\in \mathcal {Y}\). For regression, Y takes a continuum of values, \(\mathcal {Y}\subset \mathbb {R}\). For classification, Y is categorical. Denote the number of classes by \(M=|\mathcal {Y}|\). Denote a realization of X by \(\textbf{x}\) and that of Y by y. The input data \(\{\textbf{x}_1,\ldots, \textbf{x}_n\}\) are i.i.d. samples of X. Let the input data matrix be \(\textbf{X}=(\textbf{x}_1, \textbf{x}_2, \ldots, \textbf{x}_n)^T\in \mathbb {R}^{n\times p}\), where each row corresponds to one instance. Let \(\textbf{x}^{\prime }\in \mathbb {R}^{p^\prime }\) be the intra-latent-class (intra-LC) explainable features that are input to the linear model within a particular relational class (also called a regional linear model). Definitions of intra-LC features are consistent within each relational class but differ across relational classes. The degree of locality of these features is influenced by the number of relational classes, which in our experiments was no more than six for every dataset. Importantly, the locality is not defined at the level of individual data points, but rather at a broader class level. Let \(\textbf{x}^{\prime \prime }\in \mathbb {R}^{p^{\prime \prime }}\) be the inter-latent-class (inter-LC) explainable features used to predict the relational classes. Depending on the application, \(\textbf{x}^{\prime }\) and \(\textbf{x}^{\prime \prime }\) can simply be the original input \(\textbf{x}\), or they can be processed data from \(\textbf{x}\) that are still interpretable.
The Rec-LR model is given by:$$\begin{aligned} \widetilde{y}_j= & \textbf{b}_j^T\textbf{x}^\prime +a_j, \;\; j=1,\cdots ,J, \\ \textbf{v}= & {\text {softmax}}(\textbf{V}\textbf{x}^{\prime \prime }+\textbf{c}), \\ \hat{y}= & \langle \textbf{v}, \widetilde{\textbf{y}}\rangle , \end{aligned}$$where \(\langle \cdot ,\cdot \rangle\) is the inner product, \(\hat{y}\) is the final predicted value, \(\widetilde{\textbf{y}} =(\widetilde{y}_1,\cdots ,\widetilde{y}_{J})^T\) is the vector of predicted values by the J regional linear models, and \(\textbf{v}\in [0,1]^{J}\) contains the weights estimated from the inter-LC explainable features. In the case of classification, the linear predictions \(\widetilde{y}_j\), \(j=1,\ldots, J\), go through a softmax. Note that \(\textbf{v}\) contains the posterior probabilities of the relational classes. We need to estimate \(\textbf{b}_j \in \mathbb {R}^{p^\prime }\), \(a_j\in \mathbb {R}\), for \(j=1,\cdots ,J\), as well as \(\textbf{V} \in \mathbb {R}^{J\times p^{\prime \prime }}\) and \(\textbf{c} \in \mathbb {R}^{J}\). In particular, to estimate \(\textbf{V}\), we need the relational class labels.
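To make the two-layer structure concrete, the following NumPy sketch implements a Rec-LR forward pass for the regression case (for classification, the regional outputs \(\widetilde{y}_j\) would be vectors of class scores passed through a softmax). The function name, array shapes, and random toy parameters are illustrative assumptions of ours, not the paper's implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def rec_lr_predict(x_intra, x_inter, B, a, V, c):
    """Rec-LR forward pass (regression case).

    x_intra : (p1,)   intra-LC explainable features x'
    x_inter : (p2,)   inter-LC explainable features x''
    B, a    : (J, p1), (J,)   coefficients and intercepts of the J regional linear models
    V, c    : (J, p2), (J,)   logistic-regression parameters for the relational-class posterior v
    """
    y_tilde = B @ x_intra + a      # predictions of the J regional linear models
    v = softmax(V @ x_inter + c)   # posterior probabilities of the relational classes
    y_hat = v @ y_tilde            # final prediction <v, y_tilde>
    return y_hat, v, y_tilde

# toy usage with random parameters (illustrative only)
rng = np.random.default_rng(0)
J, p1, p2 = 3, 5, 4
y_hat, v, _ = rec_lr_predict(rng.normal(size=p1), rng.normal(size=p2),
                             rng.normal(size=(J, p1)), rng.normal(size=J),
                             rng.normal(size=(J, p2)), rng.normal(size=J))
```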
To generate the latent relational class labels, we utilize a DNN. Subsequently, the classification of these latent classes is performed using linear logistic regression to preserve explainability. To facilitate end-to-end training and optimize prediction performance, we propose a new structure called the Synced Explanation-Enhanced neural network (SEE-Net), shown in Fig. 1b-c, to train Rec-LR. SEE-Net contains two connected neural networks called the interpretable NN and the guidance NN. The interpretable NN, functionally a Rec-LR model, is the final prediction model to be used on test data, while the guidance NN, which has no structural restriction, is trained for the ultimate purpose of training the interpretable NN. Specifically, the guidance NN is used to generate latent classes, while the interpretable NN attempts to form relational classes that match those latent classes using linear models. If the linear models in the interpretable NN are bypassed and the latent classes produced by the guidance NN are used directly, we call the resulting network the pre-interpretable NN, which does not offer the same high level of explainability as the interpretable NN. However, the classification performance of the pre-interpretable NN will also be reported to demonstrate that the interpretable NN can achieve similar accuracy without accessing the guidance NN.

When a DNN uses the ReLU function to threshold the output of neurons, the final outcome is a linear function within any of the numerous regions created by the outputs of the intermediate layers of the DNN. We are thus motivated to create relational classes in Rec-LR by clustering intermediate DNN outputs across layers, enabling the DNN to aid in training Rec-LR. However, clustering is not readily realized by neural network modules, a technical challenge that we tackle using stochastic gates so that the interpretable NN in SEE-Net can be trained in an end-to-end fashion. Next, we describe how relational classes are formed in SEE-Net.

Generate relational classes by Guidance NN

The effectiveness of the latent relational classes is crucial for the accuracy of the final Rec-LR model. We utilize the guidance NN, which is a DNN, to form these latent classes. The DNN is embedded within the pre-interpretable NN to facilitate end-to-end training. We first present the basic notation for the guidance NN. Consider a network with L hidden layers and one output layer. Let \(p_l\) be the dimension of the output at the l-th hidden layer, \(l=0, \ldots, L\), with \(p_0 = p\). Let \(\textbf{z}_l \in \mathbb {R}^{p_l}\) be the output of the l-th hidden layer, and set \(\textbf{z}_0 = \textbf{x}\). The mapping at the l-th hidden layer is denoted by \(h_l(\cdot )\), which can be any type of hidden layer, e.g., dense, convolutional, or recurrent. We have \(\textbf{z}_l =h_l(\textbf{z}_{l-1})\), \(l=1,\cdots ,L\). The output layer, \(o(\textbf{z}_L)\), is either a linear or a softmax function depending on whether the task is regression or classification.

We cluster the outputs of the hidden layers of the guidance NN to generate relational classes for the input \(\textbf{x}\). This approach is motivated by the following consideration. In a DNN, each hidden layer is the composite of an affine transformation and an activation function. When the output at each layer is clustered such that the non-linear effect of the activation function can be neglected within a cluster, the composite function within the cluster is linear. In particular, if the activation function is ReLU, the hidden layers are piecewise linear and the entire network is also piecewise linear. Consequently, within each cluster of \(\textbf{x}\), it is natural to predict y using linear models. The clustering process includes the following two major steps.

1. Clustering at one hidden layer of a DNN:
Instead of using conventional clustering to generate the relational classes, which does not integrate well with the modules of a neural network architecture and thus prevents end-to-end training, we opt for clustering through linear operations followed by a softmax (essentially, logistic regression), as shown in Fig. 1c. Let \(k_l\) be a user-specified hyperparameter specifying the number of clusters to be generated from the l-th hidden layer outputs, \(\textbf{z}_l\), in the guidance NN. Let \({\pmb {\rho }}_l\) be the posterior probabilities of the clusters at the l-th hidden layer (that is, we perform soft clustering). Specifically, \({\pmb {\rho }}_l = ({\rho }_1^{(l)},\cdots ,{\rho }_{k_l}^{(l)})^T\), \({\rho }_j^{(l)}\in [0,1]\) for \(j=1,\cdots ,k_l\), and \(\sum \limits _{j=1}^{k_l} {\rho }_j^{(l)}=1\). In SEE-Net, the \({\pmb {\rho }}_l\)'s are computed by an affine transformation of \(\textbf{z}_l\) followed by a softmax function: $$\begin{aligned} (\widetilde{z}_1^{(l)},\cdots ,\widetilde{z}_{k_l}^{(l)})^T = \textbf{U}_l \textbf{z}_l+\textbf{d}_l \;, \quad {\pmb {\rho }}_l = {\text {softmax}}\left( (\widetilde{z}_1^{(l)},\cdots , \widetilde{z}_{k_l}^{(l)})^T\right) \;, \end{aligned}$$ where \(\textbf{U}_l\in \mathbb {R}^{k_l\times p_l}\), \(\textbf{d}_l\in \mathbb {R}^{k_l}\), and \(\widetilde{z}_j^{(l)}\in \mathbb {R}\) for \(j=1,\cdots ,k_l\). The softmax function \({\text {softmax}}(\cdot )\) with a temperature parameter \(\lambda\) is given by the following equation. As \(\lambda\) decreases, the posterior probability converges to one for one cluster and to zero for the others, approaching hard clustering. $$\begin{aligned} {\text {softmax}}\left( (\widetilde{z}_1^{(l)},\cdots , \widetilde{z}_{k_l}^{(l)})^T\right) = \left( \frac{e^{\widetilde{z}_1^{(l)}/\lambda }}{\sum _{j=1}^{k_l} e^{\widetilde{z}_j^{(l)}/\lambda }},\cdots ,\frac{e^{\widetilde{z}_{k_l}^{(l)}/\lambda }}{\sum _{j=1}^{k_l} e^{\widetilde{z}_j^{(l)}/\lambda }}\right) ^T \;. \end{aligned}$$ A code sketch of this step is provided after the list.

2. Clustering across the layers of a DNN:
Since there are \(k_l\) clusters at the l-th layer, the sequence of cluster labels across all L layers has \(K=\prod _{l=1}^{L}k_l\) configurations. We take each configuration as one cluster, which we call a guided cluster since these clusters are generated with the help of the guidance NN. The guided clusters are the Cartesian-product clusters of the clusters generated at each layer. The soft-clustering probabilities of all the guided clusters are computed by the Kronecker product of the \(\pmb {\rho }_l\), \(l=1,\cdots ,L\). Denote by \(\textbf{u}\) the output of the Kronecker product, \(\displaystyle \textbf{u} = \bigotimes _{l=1}^{L}{\pmb {\rho }}_l\). Then \(\textbf{u}=(u_1,\cdots ,u_K)^T\in [0,1]^K\) and \(\sum \limits _{j=1}^K u_j=1\). Each entry of \(\textbf{u}\) is the probability of a corresponding guided cluster. It is impractical to use the guided clusters directly as the relational classes to build a Rec-LR because the number of guided clusters grows exponentially with the number of layers. A Rec-LR with too many relational classes can become difficult to interpret. As shown in Fig. 1c, the guided clusters are therefore merged into the relational classes using stochastic gates. The stochastic gates within SEE-Net are optimized seamlessly through end-to-end learning.
Denote the number of merged clusters, that is, the final relational classes, by J; usually \(J\ll K\). We refer to the soft-clustering probabilities of the relational classes as the component weights \(\textbf{w}=(w_1,\ldots, w_{J})^T\), where \(\textbf{w}\in [0,1]^{J}\) and \(\sum \limits _{j=1}^{J} w_j=1\). The merging of the guided clusters can be represented by a group-mask matrix \(\textbf{M} \in \{0,1\}^{J\times K}\) whose element \(M_{ij}=1\) if the j-th guided cluster belongs to the i-th merged cluster and 0 otherwise; each column satisfies \(\sum _i M_{ij} = 1\). Given \(\textbf{M}\), we obtain the component weights by \(\textbf{w} = \textbf{M}\textbf{u}\). However, given the discrete nature of \(\textbf{M}\), it cannot be trained within the neural-network framework. To overcome this hurdle, we use the CONCRETE (CONtinuous relaxations of disCRETE) distribution31 to generate a soft group-mask matrix \(\textbf{M}^\prime\) that mimics the discrete \(\textbf{M}\) and is trainable in the unified framework of SEE-Net. The idea of using the CONCRETE distribution for regularization has been explored in32. Here, we extend the CONCRETE distribution to handle more than two classes and use the extension to emulate the process of merging clusters.
To train the CONCRETE group-mask matrix \(\textbf{M}^\prime\), we re-parametrize it using real-valued CONCRETE random variables \(\pmb {\Psi }=(\psi _{ij})_{i=1,\cdots ,J;\, j=1,\cdots ,K}\) of the same dimensionality as \(\textbf{M}^\prime\). Every element of \(\textbf{M}^\prime\) is guaranteed to be non-negative and no greater than 1, and can be driven to approach either 0 or 1. Specifically, let \(q(\pmb {\Psi }|\pmb {\phi })\) be a hard CONCRETE multinoulli distribution with parameters \(\pmb {\phi }\); this distribution defines the stochastic gate. The parameter \(\pmb {\phi } = (\pmb {\alpha }, \beta , \gamma , \zeta )\) contains a trainable location parameter \(\pmb {\alpha }=(\alpha _{ij})_{i=1,\cdots ,J;\, j=1,\cdots ,K}\in \mathbb {R}^{J\times K}\) and fixed real-valued hyperparameters \(\beta\), \(\gamma\), \(\zeta\). We also use an element-wise truncation function \(\text {trunc}(x) = \min (1,\max (0,x))\). The hyperparameter \(\beta >0\), called the temperature, controls the extent to which \(\textbf{M}^\prime\) resembles a discrete group-mask matrix (more so at a smaller \(\beta\)), and \(\gamma <0\) and \(\zeta >1\) are stretch parameters that stretch a probability value to the interval \((\gamma , \zeta )\).
To estimate \(\pmb {\phi }\) and construct \(\pmb {\Psi }\), we first sample \(\epsilon _{i j}\), \(i=1, \cdots , J-1\), \(j=1, \cdots , K\), from \((J-1)\times K\) independent uniform distributions on (0, 1). Let \(\epsilon _{(i, j)}\) denote the order statistics of \(\epsilon _{1j}, \cdots , \epsilon _{J-1,j}\), and set \(\epsilon _{(0,j)}=0\) and \(\epsilon _{(J,j)}=1\). Let \(\delta _{(i,j)}= \epsilon _{(i,j)}-\epsilon _{(i-1,j)}\). The random variables \({\psi }_{ij}\), \(i=1,\cdots ,J\), \(j=1,\cdots ,K\), are computed by \(\psi _{ij} = \text {trunc}\left( \widetilde{\psi }_{ij}(\zeta -\gamma )+\gamma \right)\), where $$\begin{aligned} (\widetilde{\psi }_{1j},\cdots ,\widetilde{\psi }_{Jj})^T = {\text {softmax}}\left( \left( \frac{\log (\delta _{(1, j)})-\log (1-\delta _{(1, j)})+\log \alpha _{1 j}}{\beta },\cdots ,\frac{\log (\delta _{(J, j)})-\log (1-\delta _{(J, j)})+\log \alpha _{J j}}{\beta }\right) ^T\right) \,. \end{aligned}$$The location parameters \(\alpha _{ij}\) are key to re-parameterizing \(\textbf{M}^\prime\) and are trained in an end-to-end fashion as part of the pre-interpretable NN. The soft group-mask matrix \(\textbf{M}^\prime\) is obtained by column-wise normalization of \(\pmb {\Psi }\) so that \(\sum _{i} M^\prime _{ij}=1\) for \(j=1,\cdots ,K\). Finally, we compute the relational-class posterior probabilities (mixture component weights) by \(\textbf{w} = \textbf{M}^\prime \textbf{u}\). Because the number of non-empty relational classes is determined during the training of the stochastic gates, it can be smaller than J, the pre-selected target number. A sketch of this gating step is also provided after the list.
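The following NumPy sketch illustrates steps 1-2: per-layer soft clustering by an affine map followed by a temperature softmax, and the Kronecker product that yields the guided-cluster probabilities \(\textbf{u}\). The layer outputs \(\textbf{z}_l\), the layer sizes, and the numbers of clusters are simulated with arbitrary values; in SEE-Net the matrices \(\textbf{U}_l\) and offsets \(\textbf{d}_l\) are trained end-to-end rather than drawn at random.

```python
import numpy as np

def temp_softmax(z, lam=0.5):
    """Softmax with temperature lam; smaller lam pushes the output toward hard clustering."""
    z = (z - z.max()) / lam
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
layer_dims = [16, 8]       # p_l for L = 2 hidden layers (illustrative)
num_clusters = [3, 2]      # k_l clusters per layer, so K = 3 * 2 = 6 guided clusters

# Step 1: per-layer soft clustering rho_l = softmax(U_l z_l + d_l)
rho = []
for p_l, k_l in zip(layer_dims, num_clusters):
    z_l = rng.normal(size=p_l)            # stand-in for the l-th hidden-layer output
    U_l = rng.normal(size=(k_l, p_l))     # trainable in SEE-Net; random here
    d_l = rng.normal(size=k_l)
    rho.append(temp_softmax(U_l @ z_l + d_l))

# Step 2: guided-cluster probabilities u = kron(rho_1, ..., rho_L)
u = rho[0]
for rho_l in rho[1:]:
    u = np.kron(u, rho_l)
assert np.isclose(u.sum(), 1.0)           # u is a probability vector of length K
```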
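The next sketch shows one draw of the multinoulli hard-CONCRETE stochastic gate described above, producing a soft group-mask matrix \(\textbf{M}^\prime\) and the component weights \(\textbf{w}=\textbf{M}^\prime \textbf{u}\). The hyperparameter values \(\beta\), \(\gamma\), \(\zeta\) and the toy \(\pmb{\alpha}\) and \(\textbf{u}\) are illustrative assumptions; in SEE-Net, \(\pmb{\alpha}\) is trained jointly with the rest of the pre-interpretable NN.

```python
import numpy as np

def hard_concrete_group_mask(alpha, beta=0.5, gamma=-0.1, zeta=1.1, rng=None):
    """Sample a soft group-mask matrix M' (J x K) from the multinoulli
    hard-CONCRETE construction described in the text.

    alpha : (J, K) positive location parameters (trainable in SEE-Net).
    """
    rng = rng or np.random.default_rng()
    J, K = alpha.shape
    # spacings delta_(i,j) of J-1 ordered uniform draws per column, padded with 0 and 1
    eps = np.sort(rng.uniform(size=(J - 1, K)), axis=0)
    eps = np.vstack([np.zeros((1, K)), eps, np.ones((1, K))])
    delta = np.diff(eps, axis=0)                                  # (J, K); columns sum to 1
    logits = (np.log(delta) - np.log1p(-delta) + np.log(alpha)) / beta
    psi_tilde = np.exp(logits - logits.max(axis=0, keepdims=True))
    psi_tilde /= psi_tilde.sum(axis=0, keepdims=True)             # column-wise softmax
    psi = np.clip(psi_tilde * (zeta - gamma) + gamma, 0.0, 1.0)   # stretch and truncate
    return psi / psi.sum(axis=0, keepdims=True)                   # column-normalize -> M'

rng = np.random.default_rng(1)
J, K = 2, 6                                 # merge 6 guided clusters into at most 2 relational classes
alpha = rng.uniform(0.5, 2.0, size=(J, K))  # toy location parameters
M_prime = hard_concrete_group_mask(alpha, rng=rng)
u = np.full(K, 1.0 / K)                     # guided-cluster probabilities (uniform placeholder)
w = M_prime @ u                             # relational-class (component) weights
```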

Although the guidance NN is used to produce the weights \(\textbf{w}\) for the relational classes, we do not use \(\textbf{w}\) when applying the final prediction model, because \(\textbf{w}\) is generated by a DNN and is not explainable by the standard of the interpretable NN. In other words, \(\textbf{w}\) is computed only for training, and subsequently, an interpretable module in SEE-Net is trained to infer \(\textbf{w}\). The interpretable NN is also trained in an end-to-end fashion with the capacity of selecting variables from \(\textbf{x}^{\prime }\) and \(\textbf{x}^{\prime \prime }\) for the regional linear models and the global relational-class model, respectively. For end-to-end training, it is difficult to directly apply the \(L_1\) penalty on the linear coefficients, as is done in Lasso. Instead, we employ the technique of the hard CONCRETE Bernoulli distribution introduced by31, which generates a variable-selection mask on the intra- and inter-LC features \(\textbf{x}^{\prime }\) and \(\textbf{x}^{\prime \prime }\).

The possible candidates for \(\textbf{x}^{\prime }\) and \(\textbf{x}^{\prime \prime }\) differ depending on the specific task and data type. For example, in image analysis, instead of raw pixel values, features extracted by low-level filtering, e.g., edge filters, wavelets, or first-convolution-layer outputs, are often used. We use the first convolutional layer output or another relatively simple transformation of the input features \(\textbf{x}\) to find reasonable \(\textbf{x}^{\prime }\) and \(\textbf{x}^{\prime \prime }\) for each task.

Training and optimization criteria for SEE-Net

Several types of losses are considered in SEE-Net, a weighted combination of which is defined as the overall loss. Depending on the purpose, a certain type of loss can be more prominent than the others.

Let \(\hat{y}_i^{I}\) be the prediction of the i-th response, \(i=1, \ldots, n\), by the interpretable NN, and \(\hat{y}_i^{G}\) the prediction by the pre-interpretable NN. Note that because the guidance NN does not include an output layer that predicts the response Y, the pre-interpretable NN is a combination of the guidance NN, the stochastic gates, and the linear models within each relational class. The difference between the processes that predict \(\hat{y}_i^{G}\) and \(\hat{y}_i^{I}\), respectively, lies in whether the "original" or the "predicted" relational classes are used. If we replace the module of the interpretable NN that contains the linear models predicting the relational classes by the concatenation of the guidance NN and the stochastic gates, we obtain \(\hat{y}_i^{G}\) rather than \(\hat{y}_i^{I}\). While both the interpretable and pre-interpretable NNs incorporate a module of regional linear models, as we will shortly elaborate, these models are trained at different stages. This distinction results in variations in the parameters of the regional models in the two networks.

We measure the total loss between the prediction of the interpretable NN and the true response by \(\sum _{i=1}^{n}\mathcal {L}(\hat{y}_i^{I}, y_i)\). Similarly, the total prediction loss for the pre-interpretable NN is \(\sum _{i=1}^{n}\mathcal {L}(\hat{y}_i^{G}, y_i)\), which is the objective function used to train this network. The connection between the interpretable NN and the guidance NN is captured in the relational classes: the posterior \(\textbf{v}\) estimated by the interpretable NN aims to be aligned with the posterior \(\textbf{w}\) estimated by the guidance NN.
We measure the distance between the posterior vectors by the relative entropy \(\mathcal {K}(\textbf{w}_i||\textbf{v}_i)\) and define \(\sum _{i=1}^{n}\mathcal {K}(\textbf{w}_i||\textbf{v}_i)\) as a loss. Consider the objective of achieving high prediction accuracy through the interpretable NN. Specifically, predictions are provided by the Rec-LR, which is derived from the interpretable NN, while the guidance NN identifies the relational classes. To maximize prediction accuracy by effectively leveraging the relational classes, we incorporate two loss functions: the first is the empirical prediction loss, \(\sum _{i=1}^{n}\mathcal {L}(\hat{y}_i^{I}, y_i)\), and the second is \(\sum _{i=1}^{n}\mathcal {K}(\textbf{w}_i||\textbf{v}_i)\). Let \(\mathcal {R}(\theta )\) be the overall loss of SEE-Net, where \(\theta\) is the collection of parameters. We define \(\mathcal {R}(\theta )\) by$$\begin{aligned} \mathcal {R}(\theta ) = \frac{1}{n}\sum \limits _{i=1}^n\left[ \mathcal {L}(\hat{y}_i^{I},y_i) +\lambda \mathcal {K}(\textbf{w}_i||\textbf{v}_i)\right] \,, \end{aligned}$$
(1)
where \(\lambda\) is a hyperparameter. We conduct training in two stages to optimize SEE-Net. In the first stage, we train the pre-interpretable NN to produce the \(\textbf{w}_i\)'s by minimizing the average loss \(\frac{1}{n}\sum _{i=1}^{n}\mathcal {L}(\hat{y}_i^{G}, y_i)\). After establishing the guidance NN and setting the stochastic gates, we proceed to the second stage, where the interpretable NN is trained. Here, we focus on minimizing the objective function in Eq. (1). During this stage, we train the linear models that predict the relational classes, and concurrently we refit the regional models within each relational class for the optimal performance of the interpretable NN. Theoretically, it would be possible to train the entire SEE-Net in a single stage. However, our experimental study indicates that SEE-Net may not train effectively if we attempt to optimize the linear models that predict the relational classes without first establishing a useful set of relational classes. By adopting a two-stage training strategy, we effectively circumvent this challenge.
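As a concrete reference for the two-stage procedure, the NumPy sketch below evaluates the second-stage objective of Eq. (1) from arrays of predictions and posteriors; in the first stage, only the average prediction loss of the pre-interpretable NN, \(\frac{1}{n}\sum _{i=1}^{n}\mathcal {L}(\hat{y}_i^{G}, y_i)\), would be minimized. The squared-error choice of \(\mathcal{L}\) and the array shapes are illustrative assumptions for the regression case.

```python
import numpy as np

def see_net_stage2_objective(y_hat_I, y, w, v, lam=0.1, eps=1e-12):
    """Eq. (1): average prediction loss of the interpretable NN plus
    lam times the relative entropy KL(w_i || v_i) between the relational-class
    posteriors from the guidance side (w) and the interpretable side (v).

    y_hat_I, y : (n,)    predictions of the interpretable NN and true responses
    w, v       : (n, J)  posterior probabilities over the J relational classes
    """
    pred_loss = np.mean((y_hat_I - y) ** 2)                        # squared-error L for regression
    kl = np.sum(w * (np.log(w + eps) - np.log(v + eps)), axis=1)   # K(w_i || v_i) per instance
    return pred_loss + lam * np.mean(kl)

# toy usage with random posteriors and responses
rng = np.random.default_rng(2)
n, J = 8, 3
w = rng.dirichlet(np.ones(J), size=n)
v = rng.dirichlet(np.ones(J), size=n)
loss = see_net_stage2_objective(rng.normal(size=n), rng.normal(size=n), w, v)
```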
