- Use simple interpretable models. This approach was covered in the previous posts where we looked at logistic regression and decision trees as examples of white box models.
- Conduct post-hoc interpretation on models. There are two types of post-hoc analysis which can be done: model specific and model agnostic.
Direction of post
In the next few posts, we will look at model specific post-hoc analysis, which involves ranking the variables according to their importance to the model. Though these interpretations can be applied to both white box and black box models, we will be explaining black box models in these posts, as white box models can be explained by the models themselves. Specifically, we will explain random forest in this post and gradient boosting in future posts.
Similar to the previous posts, the Cleveland heart dataset will be used as well as principles of `tidymodels`.
# library
library(tidyverse)
library(tidymodels)

# import
heart <- read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data", col_names = FALSE)

# renaming variables
colnames(heart) <- c("age", "sex", "rest_cp", "rest_bp", "chol", "fast_bloodsugar", "rest_ecg", "ex_maxHR", "ex_cp", "ex_STdepression_dur", "ex_STpeak", "coloured_vessels", "thalassemia", "heart_disease")

# elaborating categorical variables
## simple ifelse conversion
heart <- heart %>%
  mutate(sex = ifelse(sex == "1", "male", "female"),
         fast_bloodsugar = ifelse(fast_bloodsugar == "1", ">120", "<120"),
         ex_cp = ifelse(ex_cp == "1", "yes", "no"),
         heart_disease = ifelse(heart_disease == "0", "no", "yes"))

## multi-level conversion
heart <- heart %>%
  mutate(
    rest_cp = case_when(rest_cp == "1" ~ "typical",
                        rest_cp == "2" ~ "atypical",
                        rest_cp == "3" ~ "non-CP pain",
                        rest_cp == "4" ~ "asymptomatic"),
    rest_ecg = case_when(rest_ecg == "0" ~ "normal",
                         rest_ecg == "1" ~ "ST-T abnorm",
                         rest_ecg == "2" ~ "LV hyperthrophy"),
    ex_STpeak = case_when(ex_STpeak == "1" ~ "up/norm",
                          ex_STpeak == "2" ~ "flat",
                          ex_STpeak == "3" ~ "down"),
    thalassemia = case_when(thalassemia == "3.0" ~ "norm",
                            thalassemia == "6.0" ~ "fixed",
                            thalassemia == "7.0" ~ "reversable"))

# convert missing value "?" into NA
# (a lambda replaces funs(), which is deprecated in dplyr)
heart <- heart %>%
  mutate_if(is.character, ~ replace(., . == "?", NA))

# convert char into factors
heart <- heart %>%
  mutate_if(is.character, as.factor)

# train/test set
set.seed(4595)
data_split <- initial_split(heart, prop = 0.75, strata = "heart_disease")
heart_train <- training(data_split)
heart_test <- testing(data_split)

# create recipe object
heart_recipe <- recipe(heart_disease ~ ., data = heart_train) %>%
  # newer recipes releases rename this step to step_impute_knn()
  step_knnimpute(all_predictors()) %>%
  # dummy variables are needed for some `randomForestExplainer` functions,
  # though the random forest model itself doesn't need explicit one hot encoding
  step_dummy(all_nominal(), -heart_disease)

# process the training set/ prepare recipe (non-cv)
heart_prep <- heart_recipe %>%
  prep(training = heart_train, retain = TRUE)
Ensemble Models (tree-based)
Random forests and gradient boosting are examples of tree based ensemble models. In tree based ensemble models, multiple variants of a simple tree are combined to develop a more complex model with higher predictive power. The trade-off of the ensemble is that the model becomes a black box and cannot be interpreted as simply as a decision tree. Thus, post-hoc analysis is required to explain the ensemble model.
Random forest mechanism
Random forest builds on bootstrap aggregating (bagging), which in turn builds on the decision tree. In bagging, multiple bootstrap samples of the training data are used to train multiple trees, one tree per sample. As bootstrap samples are drawn with replacement, the same observation from the training set can be used more than once to train a tree. For regression problems, the ensemble occurs when the predictions of each tree are averaged to provide the predicted value for the bagging model. For classification problems, the ensemble occurs when the class probabilities of each tree are averaged to provide the predicted class. Alternatively, the majority predicted class of the ensemble of trees will be the predicted class for the model.

Random forest is similar to bootstrap aggregating except that only a random subset of variables is considered as candidates at each split. The number of variables in this random subset is known as mtry. The default mtry in regression problems is a third of the total number of available variables in the dataset. The default mtry in classification problems is the square root of the total number of available variables.
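The mechanism above can be sketched in a few lines of base R. This is only an illustration under assumed toy numbers: `n_obs`, `tree_probs` and the other object names are hypothetical, and 13 is the number of predictors in the Cleveland data before dummy coding. The mtry formulas mirror the defaults used by the `randomForest` package.

```r
set.seed(1)
n_obs <- 10          # hypothetical number of training rows
n_predictors <- 13   # predictors in the Cleveland data before dummy coding

# A bootstrap sample draws row indices WITH replacement, so some rows
# repeat and the rows never drawn form the "out of bag" (OOB) sample.
boot_idx <- sample(n_obs, size = n_obs, replace = TRUE)
oob_idx <- setdiff(seq_len(n_obs), boot_idx)

# Default mtry in the randomForest package
mtry_classification <- floor(sqrt(n_predictors))        # 3
mtry_regression <- max(floor(n_predictors / 3), 1)      # 4

# At each split only mtry randomly chosen variables are candidates
split_candidates <- sample(n_predictors, mtry_classification)

# Aggregating three hypothetical trees for a single observation:
# each value is one tree's predicted probability of class "yes"
tree_probs <- c(0.9, 0.4, 0.7)
avg_prob <- mean(tree_probs)                             # averaged probability
votes <- ifelse(tree_probs > 0.5, "yes", "no")
majority_vote <- names(which.max(table(votes)))          # "yes"
```

Both aggregation rules agree here, but they can differ when a few trees are very confident and the rest are narrowly split.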
Training random forest models
We will train a random forest and rank the features with two different importance measures. The two ranking measurements are:
Permutation based. Permuting the values of a variable breaks any relationship between that predictor and the outcome, so the variable remains in the model but no longer carries information. When the out-of-bag sample is passed down the model with the permuted variable, the increase in error/decrease in accuracy relative to the baseline accuracy/error can be calculated (for regression problems, the measurement will be a regression metric such as MSE or R2), and this difference is the importance score. This computation is repeated for all variables. The variable with the largest decrease in accuracy/largest increase in error is considered the most important variable because losing its information has the largest penalty on prediction. In the case of the `randomForest` package, the score measures the decrease in accuracy and is referred to as `MeanDecreaseAccuracy`.
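The permutation idea can be illustrated with a minimal base R sketch. It makes several simplifying assumptions that differ from what the `randomForest` package does internally: the data are simulated, a logistic regression stands in for the forest, and a held-out set replaces the out-of-bag samples. All object names (`simulate_data`, `x_signal`, `x_noise`, and so on) are hypothetical.

```r
set.seed(42)

# Simulated data: the outcome depends on x_signal only; x_noise is irrelevant
simulate_data <- function(n) {
  x_signal <- rnorm(n)
  x_noise <- rnorm(n)
  y <- rbinom(n, size = 1, prob = plogis(2 * x_signal))
  data.frame(x_signal, x_noise, y)
}
train <- simulate_data(500)
test <- simulate_data(500)

# Stand-in model (assumption: any fitted classifier could be used here)
fit <- glm(y ~ x_signal + x_noise, data = train, family = binomial)

accuracy <- function(model, data) {
  pred <- as.integer(predict(model, newdata = data, type = "response") > 0.5)
  mean(pred == data$y)
}

baseline <- accuracy(fit, test)

# For each variable: shuffle its values, re-score, record the drop in accuracy
permutation_importance <- sapply(c("x_signal", "x_noise"), function(var) {
  permuted <- test
  permuted[[var]] <- sample(permuted[[var]])
  baseline - accuracy(fit, permuted)
})

permutation_importance
```

Shuffling `x_signal` should cost far more accuracy than shuffling `x_noise`, which is the behaviour the importance ranking relies on.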
Impurity based. Impurity refers to gini impurity/gini index. The concept of impurity for random forest is the same as for a single decision tree. Features which are more important produce purer child nodes when they are used for a split, i.e. a lower impurity score/higher purity score/higher decrease in impurity score. The `randomForest` package adopts the latter score, which is known as `MeanDecreaseGini`.
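To make the decrease in impurity concrete, here is a small base R sketch for a single split on made-up labels. The helper `gini()` and the toy vectors are hypothetical. A forest accumulates such decreases over every split that uses a variable and averages them over the trees, so this shows one ingredient of the score rather than the package's exact computation.

```r
# Gini impurity of a node: 1 - sum of squared class proportions
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Hypothetical parent node and the two child nodes produced by a split
parent <- c(rep("yes", 5), rep("no", 5))
left <- c(rep("yes", 4), "no")
right <- c("yes", rep("no", 4))

# Impurity after the split, weighted by the size of each child node
weighted_child <- (length(left) * gini(left) + length(right) * gini(right)) /
  length(parent)

# Decrease in impurity attributable to this split
gini_decrease <- gini(parent) - weighted_child
gini_decrease  # 0.5 - 0.32 = 0.18
```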
set.seed(69)
rf_model <- rand_forest(trees = 2000, mtry = 4, mode = "classification") %>%
  set_engine("randomForest",
             # importance = TRUE to have permutation score calculated
             importance = TRUE,
             # localImp = TRUE for randomForestExplainer (next post)
             localImp = TRUE) %>%
  fit(heart_disease ~ ., data = juice(heart_prep))
We can extract permutation based importance as well as gini based importance directly from the model trained by the `randomForest` package. When using a `tidyverse` approach to extract the results, remember to convert `MeanDecreaseAccuracy` from character to numeric form for `arrange` to sort the variables correctly (binding the row names to the importance matrix with `cbind` coerces every column to character). Otherwise, `R` sorts the values as text, character by character, and ignores the exponent of values in scientific notation. For instance, if `MeanDecreaseAccuracy` was in character format, `rest_ecg_ST.T.abnorm` would have the highest value, as `R` compares the leading digit in `4.1293934444743e-05` and fails to recognise that e-05 renders that value close to 0.
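The sorting pitfall is easy to reproduce in base R. The three values below are taken from this post's own output; the object names are hypothetical.

```r
# Importance scores stored as text, as they are after cbind() with row names
scores_chr <- c("0.0365", "4.1293934444743e-05", "0.0111")

# Character sort compares text left to right, so the leading "4" wins
chr_sorted <- sort(scores_chr, decreasing = TRUE)
chr_sorted[1]   # "4.1293934444743e-05"

# Numeric sort respects the exponent and puts the tiny value last
num_sorted <- sort(as.numeric(scores_chr), decreasing = TRUE)
num_sorted[1]   # 0.0365
```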
cbind(rownames(rf_model$fit$importance), rf_model$fit$importance) %>%
  as_tibble() %>%
  select(predictor = V1, MeanDecreaseAccuracy) %>%
  mutate(MeanDecreaseAccuracy = as.numeric(MeanDecreaseAccuracy)) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  head()
## # A tibble: 6 x 2
##   predictor              MeanDecreaseAccuracy
##   <chr>                                 <dbl>
## 1 ex_STdepression_dur                  0.0365
## 2 thalassemia_norm                     0.0332
## 3 thalassemia_reversable               0.0323
## 4 ex_maxHR                             0.0226
## 5 sex_male                             0.0111
## 6 coloured_vessels_X2.0                0.0109
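As the model is trained using the `randomForest` package, we can use `randomForest::importance` to extract the results too. Use argument `type=1` to extract permutation-based importance and set `scale=FALSE` to prevent normalization (when scaling is on, the permutation score is divided by its standard error). There are benefits to not scaling the permutation-based score, and the unscaled values match those stored in the fitted model object.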
cbind(rownames(randomForest::importance(rf_model$fit, type = 1, scale = FALSE)) %>% as_tibble(),
      randomForest::importance(rf_model$fit, type = 1, scale = FALSE) %>% as_tibble()) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  head()
##                    value MeanDecreaseAccuracy
## 1    ex_STdepression_dur           0.03651391
## 2       thalassemia_norm           0.03321383
## 3 thalassemia_reversable           0.03231399
## 4               ex_maxHR           0.02256749
## 5               sex_male           0.01106189
## 6  coloured_vessels_X2.0           0.01087665
Likewise, we can use both approaches to calculate impurity-based importance. Using the `tidyverse` approach:
cbind(rownames(rf_model$fit$importance), rf_model$fit$importance) %>%
  as_tibble() %>%
  select(predictor = V1, MeanDecreaseGini) %>%
  mutate(MeanDecreaseGini = as.numeric(MeanDecreaseGini)) %>%
  arrange(desc(MeanDecreaseGini)) %>%
  head(10)
## # A tibble: 10 x 2
##    predictor              MeanDecreaseGini
##    <chr>                             <dbl>
##  1 ex_maxHR                          14.9
##  2 ex_STdepression_dur               14.5
##  3 thalassemia_norm                  11.1
##  4 thalassemia_reversable             9.90
##  5 rest_bp                            8.84
##  6 age                                8.69
##  7 chol                               8.00
##  8 ex_cp_yes                          5.56
##  9 coloured_vessels_X2.0              3.77
## 10 rest_cp_typical                    3.34
Using the `randomForest::importance` function, we will change the argument `type` to 2. The `scale` argument has no effect on impurity-based scores, as it only applies to the permutation-based measure.
cbind(rownames(randomForest::importance(rf_model$fit, type = 2)) %>% as_tibble(),
      randomForest::importance(rf_model$fit, type = 2) %>% as_tibble()) %>%
  arrange(desc(MeanDecreaseGini)) %>%
  head(10)
##                     value MeanDecreaseGini
## 1                ex_maxHR        14.850652
## 2     ex_STdepression_dur        14.537241
## 3        thalassemia_norm        11.128131
## 4  thalassemia_reversable         9.902711
## 5                 rest_bp         8.838706
## 6                     age         8.686010
## 7                    chol         7.997923
## 8               ex_cp_yes         5.557652
## 9   coloured_vessels_X2.0         3.769974
## 10        rest_cp_typical         3.342365
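Both extraction routes above bind the row names back onto the importance matrix by hand. As a possible alternative, the base R sketch below wraps that step in a small helper that keeps the scores numeric, so no character conversion is needed before sorting. The helper name `importance_to_df` is hypothetical, and the mock matrix only mimics the shape of a `randomForest` importance matrix using three values from the output above.

```r
# Hypothetical helper: turn an importance matrix into a sorted data frame
importance_to_df <- function(imp_matrix, measure) {
  out <- data.frame(
    predictor = rownames(imp_matrix),
    score = unname(imp_matrix[, measure]),   # stays numeric
    stringsAsFactors = FALSE
  )
  out <- out[order(out$score, decreasing = TRUE), ]
  rownames(out) <- NULL                      # reset row labels after sorting
  out
}

# Mock matrix standing in for the importance matrix of a fitted forest
mock_importance <- matrix(
  c(14.850652, 8.838706, 14.537241),
  ncol = 1,
  dimnames = list(c("ex_maxHR", "rest_bp", "ex_STdepression_dur"),
                  "MeanDecreaseGini")
)

gini_ranking <- importance_to_df(mock_importance, "MeanDecreaseGini")
gini_ranking
```

With a real model, the first argument would be the importance matrix returned by `randomForest::importance`.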
To sum up
In this post, we did a post-hoc analysis on a random forest to understand the model by using permutation and impurity variable importance ranking. In the next post, we will continue the post-hoc analysis on the random forest model using other techniques.