
How Do Inherently Interpretable AI Models Work? — GAMINET

Author(s): Indraneel Dutta Baruah

Originally published on Towards AI.

(Cover image source: interpretable-ml-book)

The field of deep learning has grown exponentially, and the recent craze around ChatGPT is proof of it. Models are becoming more and more complex, with deeper layers leading to greater accuracy, but this trend often comes at the cost of interpretability. Applying such black-box AI systems to real-life problems is risky, especially in sectors like banking and healthcare. For example, a deep neural net used for a loan application scorecard might deny a customer, and we would not be able to explain why. This is where an explainable neural network based on generalized additive models with structured interactions (GAMI-Net) comes into the picture: it is one of the few deep learning models that are inherently interpretable. If you are not familiar with how neural networks work, it is highly advised that you brush up on that first by going through this blog series. So, without further ado, let's dive in!

What is GAMI-Net?

GAMI-Net is a neural network that consists of multiple additive subnetworks. Each subnetwork contains multiple hidden layers and is designed to capture either one main effect (the direct impact of a single input variable on the target variable) or one pairwise interaction (the combined impact of two input variables on the target variable). These subnetworks are then combined additively to form the final output. GAMI-Net is formulated as follows (Figure 1):

g(\mathbb{E}(y \mid \mathbf{x})) = \mu + \sum_{j \in S_1} h_j(x_j) + \sum_{(j,k) \in S_2} f_{jk}(x_j, x_k)

(Source: Yang, Zhang and Sudjianto (2021, Pattern Recognition): GAMI-Net. arXiv: 2003.07132)

where the h_j(x_j) are main effects, the f_{jk}(x_j, x_k) are pairwise interaction effects, \mu is a bias term, and S_1 and S_2 index the effects kept in the model.

(Figure: GAMI-Net architecture. Source: Yang, Zhang and Sudjianto (2021, Pattern Recognition): GAMI-Net. arXiv: 2003.07132)

As shown in the figure above, each main effect h_j(x_j) is computed by a subnetwork with one input node (taking x_j), several hidden layers, and one output node. Similarly, each pairwise interaction f_{jk}(x_j, x_k) is computed by a subnetwork with two input nodes (taking x_j and x_k). Finally, all the subnetworks are combined, along with a bias term, to predict the target variable. (A minimal code sketch of this architecture appears at the end of the article.)

Now that we know the overall structure of GAMI-Net, let's discuss how certain constraints help make this model interpretable.

GAMI-Net Constraints

To improve the interpretability and identifiability of each input feature's impact, GAMI-Net is built with three constraints:

1. Sparsity constraint: if we have many input features, GAMI-Net can end up with too many main and interaction effects. For example, with 100 input features, there are 100 main effects and 4,950 possible pairwise interactions! We therefore keep only the critical effects, both for efficient computation and for better interpretability. The importance of a main effect or pairwise interaction is quantified by the variance it explains in the data. For the main effect of input feature x_j:

D(h_j) = \frac{1}{n-1} \sum_{i=1}^{n} h_j^2(x_{ij})

(Source: Yang, Zhang and Sudjianto (2021, Pattern Recognition): GAMI-Net. arXiv: 2003.07132)

where n is the sample size. GAMI-Net picks the top s_1 main effects ranked by their D(h_j) values, where s_1 is a user-defined number between 1 and the number of input features. Similarly, the top s_2 pairwise interactions are picked using

D(f_{jk}) = \frac{1}{n-1} \sum_{i=1}^{n} f_{jk}^2(x_{ij}, x_{ik})

(Source: Yang, Zhang and Sudjianto (2021, Pattern Recognition): GAMI-Net. arXiv: 2003.07132)

2. Hereditary constraint: a pairwise interaction (j, k) can be included only if at least one of its parent main effects, h_j or h_k, is among the s_1 selected main effects. This helps prune the number of interaction effects in the model. (A sketch of both selection steps appears after this section.)

3. Marginal clarity constraint: it can become difficult to quantify the impact of an input feature if its main effect is absorbed by its child interactions, or vice versa. For example, with three variables x_1, x_2, and x_3, the interactions f_{12}(x_1, x_2) and f_{13}(x_1, x_3) might absorb the impact of h_1(x_1) itself. The marginal clarity constraint ensures such situations don't occur by requiring that a feature's impact decompose uniquely into orthogonal components, similar to the functional ANOVA decomposition. It is the orthogonality condition between the j-th main effect and the corresponding pairwise interaction (j, k):

\mathbb{E}\big[\, h_j(x_j)\, f_{jk}(x_j, x_k) \,\big] = 0

(Source: Yang, Zhang and Sudjianto (2021, Pattern Recognition): GAMI-Net. arXiv: 2003.07132)

To learn more about orthogonal functions, you can go through this video.
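To make the sparsity and hereditary constraints concrete, here is a minimal sketch of the selection logic in Python. It assumes the fitted, mean-centered effect values have already been collected into an array; the function names and the variable `s1` are illustrative, not part of any GAMI-Net library.

```python
import numpy as np
from itertools import combinations

def top_main_effects(main_effect_values, s1):
    """Rank main effects by explained variance D(h_j) and keep the top s1.

    main_effect_values: array of shape (n, p); column j holds the
    mean-centered h_j(x_ij) for every sample i.
    """
    n = main_effect_values.shape[0]
    d = (main_effect_values ** 2).sum(axis=0) / (n - 1)  # D(h_j)
    return set(np.argsort(d)[::-1][:s1])                 # indices of kept effects

def candidate_interactions(kept_main, p):
    """Hereditary constraint: keep pair (j, k) only if at least one
    parent main effect, j or k, survived the sparsity screening."""
    return [(j, k) for j, k in combinations(range(p), 2)
            if j in kept_main or k in kept_main]

# Toy usage with random "effect" values for 5 features and 200 samples.
rng = np.random.default_rng(0)
effects = rng.normal(size=(200, 5))
effects -= effects.mean(axis=0)          # effects are centered in GAMI-Net
kept = top_main_effects(effects, s1=3)
pairs = candidate_interactions(kept, p=5)
print(kept, len(pairs))
```

The surviving pairs would then be ranked by D(f_jk) in exactly the same way, keeping only the top s_2.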
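This orthogonality cannot hold exactly on finite data, so during training it is enforced softly through a clarity regularization penalty on the empirical inner product between a main effect and its child interaction. A minimal sketch of that empirical penalty (the names are illustrative):

```python
import numpy as np

def clarity_penalty(h_j, f_jk):
    """Empirical marginal clarity penalty: the absolute mean product
    |1/n * sum_i h_j(x_ij) * f_jk(x_ij, x_ik)|, which is zero when the
    main effect and its child interaction are empirically orthogonal."""
    return np.abs(np.mean(h_j * f_jk))

rng = np.random.default_rng(1)
h = rng.normal(size=1000)
# An interaction that merely replays the main effect scores a high penalty...
print(clarity_penalty(h, h))
# ...while an independent interaction scores near zero.
print(clarity_penalty(h, rng.normal(size=1000)))
```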
Now we are ready to understand the model training process!

How does GAMI-Net get trained?

The training process has three stages:

1. Training main effects: train all the main-effect subnetworks for a number of epochs, then remove the trivial main effects according to their variance contributions and validation performance.

2. Training interaction effects: train the pairwise interaction subnetworks that satisfy the hereditary constraint; as with the main effects, prune the weak pairwise interactions according to their variance contributions and validation performance.

3. Fine-tuning: fine-tune all the network parameters together for a number of epochs.

Let's dive deep into each stage next.

Steps for training main effects

Here are the main steps involved in training and identifying the main effects:

(Figure: main-effect training steps. Source: Image by author)

- All main-effect subnetworks are estimated simultaneously; the pairwise interaction subnetworks are set to zero in this stage.
- Mini-batch gradient descent is used for training, along with the Adam optimizer and adaptive learning rates. Training continues until the maximum number of epochs is reached or the validation accuracy stops improving for some epochs.
- Each main effect is centered to have zero mean, so that the bias in the output layer represents the mean of the target variable.
- The top main effects are selected according to the sparsity constraint.
- The next step is to estimate validation performance (see the sketch below). Starting with a null model containing just the intercept term, its validation-set performance is recorded (say l_0). The most important main effect, by variance explained, is then added and its validation performance (say l_1) is recorded. The remaining main effects are added one by one in decreasing order of variance explained, and their validation performances are appended to the list {l_0, l_1, ..., l_p}. The validation performance will start deteriorating when […]
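A minimal sketch of this sequential screening follows. Everything here is illustrative: `validation_loss` is a hypothetical stand-in for refitting and evaluating the additive model with a given set of effects, and the cut-off rule shown (keep the prefix with the lowest validation loss) is a simplification, not necessarily the paper's exact criterion.

```python
import numpy as np

def sequential_screen(effects_by_importance, validation_loss):
    """Add effects one at a time in decreasing order of variance explained,
    recording the validation loss l_0, l_1, ..., l_p after each addition.

    effects_by_importance: effect indices sorted by D(h_j), descending.
    validation_loss: callable mapping a list of included effects to a
    validation metric (hypothetical stand-in for refitting the model).
    """
    included, losses = [], [validation_loss([])]   # l_0: intercept-only model
    for j in effects_by_importance:
        included.append(j)
        losses.append(validation_loss(list(included)))
    # Simplified cut-off rule: keep the prefix that minimizes validation loss.
    best = int(np.argmin(losses))
    return effects_by_importance[:best], losses

# Toy usage: pretend the first two added effects help and the rest add noise.
toy_curve = {0: 1.00, 1: 0.70, 2: 0.55, 3: 0.58, 4: 0.63, 5: 0.66}
loss_fn = lambda included: toy_curve[len(included)]
kept, losses = sequential_screen([4, 2, 0, 1, 3], loss_fn)
print(kept, losses)   # keeps effects [4, 2], where the loss bottoms out
```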
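Finally, to make the architecture from the beginning of the article concrete, here is a minimal PyTorch sketch of the additive structure: one small fully connected subnetwork per main effect, summed together with a bias term. The class names and layer sizes are illustrative and not taken from the GAMI-Net implementation; interaction subnetworks would look the same except for taking two input columns.

```python
import torch
import torch.nn as nn

class MainEffectNet(nn.Module):
    """Subnetwork h_j: one input node (feature x_j), hidden layers, one output node."""
    def __init__(self, hidden=(16, 16)):
        super().__init__()
        layers, width = [], 1
        for h in hidden:
            layers += [nn.Linear(width, h), nn.ReLU()]
            width = h
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, xj):                 # xj: (batch, 1)
        return self.net(xj)

class AdditiveModel(nn.Module):
    """GAMI-Net-style additive combination: sum of main-effect subnetworks plus a bias."""
    def __init__(self, n_features):
        super().__init__()
        self.effects = nn.ModuleList(MainEffectNet() for _ in range(n_features))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):                  # x: (batch, n_features)
        contributions = [net(x[:, j:j + 1]) for j, net in enumerate(self.effects)]
        return self.bias + torch.stack(contributions, dim=0).sum(dim=0)

model = AdditiveModel(n_features=5)
print(model(torch.randn(8, 5)).shape)      # torch.Size([8, 1])
```

Because each h_j is its own subnetwork, plotting its output over a grid of x_j values gives an exact picture of that feature's contribution, which is precisely what makes the model inherently interpretable.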
