---
title: "Tidymodels"
bibliography: "biblio.bib"
link-citations: true
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Tidymodels and SHAP}
  %\VignetteEncoding{UTF-8}
  %\VignetteEngine{knitr::rmarkdown}
---

This vignette explains how to use {shapviz} with {tidymodels}.

XGBoost and LightGBM ship with super-fast native TreeSHAP implementations. Therefore, a SHAP analysis of these models works differently from the normal, model-agnostic case.

## Normal case

A model fitted with tidymodels has a `predict()` method that returns a data.frame of predictions. Therefore, model-agnostic SHAP (permutation SHAP or Kernel SHAP) works out of the box. However, it takes a little time.

```r
library(tidymodels)
library(kernelshap)
library(shapviz)

set.seed(10)

# Add log-transformed price and carat, then split into training/test data
splits <- diamonds |>
  transform(
    log_price = log(price),
    log_carat = log(carat)
  ) |>
  initial_split()
df_train <- training(splits)

# No preprocessing steps: the recipe only declares response and predictors
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut)

rf <- rand_forest(mode = "regression") |>
  set_engine("ranger")

rf_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(rf)

rf_fit <- rf_wf |>
  fit(df_train)

# SHAP analysis
xvars <- c("log_carat", "color", "clarity", "cut")
X_explain <- df_train[1:1000, xvars]  # Feature columns only, 1000 rows

# About 1.5 minutes on a laptop.
# With more than p = 8 features, use kernelshap() instead of permshap()
system.time(
  shap_values <- rf_fit |>
    permshap(X = X_explain) |>
    shapviz()
)
# saveRDS(shap_values, file = "shap_values.rds")
# shap_values <- readRDS("shap_values.rds")

shap_values |>
  sv_importance(kind = "beeswarm")

shap_values |>
  sv_dependence(xvars)
``` 
![](../man/figures/VIGNETTE-tidy-rf-imp.png)
![](../man/figures/VIGNETTE-tidy-rf-dep.png)

## XGBoost

When your tidymodels workflow wraps an XGBoost or LightGBM model, you will almost always want to use its native TreeSHAP implementation. In this case, you need to pass the underlying fit engine and the fully preprocessed explanation matrix `X_pred` to `shapviz()`.

We will show how to prepare the inputs for `shapviz()`, namely

- the underlying fit engine,
- `X_pred`, the matrix passed to XGBoost's `predict()`,
- and optionally `X`, the dataframe used for visualizations (to see original factor levels etc).

Since XGBoost offers SHAP interactions, we additionally show how to integrate them into the analysis.
Of course, you don't *have* to work with SHAP interactions, especially if your model has many predictors.

**Remark:** Don't use 1:m transforms such as one-hot encoding. They are usually unnecessary for tree-based models and make the workflow more complicated. If you can't avoid them, check the `collapse` argument in `shapviz()`.

```r
library(tidymodels)
library(shapviz)
library(patchwork)

set.seed(10)

splits <- diamonds |>
  transform(
    log_price = log(price),
    log_carat = log(carat)
  ) |>
  initial_split()
df_train <- training(splits)

# Integer-encode all ordered factors, so XGBoost gets a purely numeric matrix
# without any one-hot encoding
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut) |>
  step_integer(all_ordered())

# Should be tuned in practice
xgb_model <- boost_tree(mode = "regression", learn_rate = 0.1, trees = 100) |>
  set_engine("xgboost")

xgb_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(xgb_model)

xgb_fit <- xgb_wf |>
  fit(df_train)

# SHAP Analysis
df_explain <- df_train[1:1000, ]

# Reuse the recipe that was estimated inside fit() (instead of re-prepping it),
# so X_pred is encoded exactly like the data XGBoost was trained on
X_pred <- bake(  # Goes to xgboost:::predict.xgb.Booster()
  extract_recipe(xgb_fit),
  has_role("predictor"),
  new_data = df_explain,
  composition = "matrix"
)

# Every column of X_pred needs a counterpart in X (used for the plots)
stopifnot(colnames(X_pred) %in% colnames(df_explain))

shap_values <- extract_fit_engine(xgb_fit) |>
  shapviz(X_pred = X_pred, X = df_explain, interactions = TRUE)

# SHAP importance
shap_values |>
  sv_importance(show_numbers = TRUE) +
  ggtitle("SHAP importance")

# Absolute average SHAP interactions (off-diagonals already multiplied by 2)
shap_values |>
  sv_interaction(kind = "no")
#            log_carat     clarity       color         cut
# log_carat 0.87400688 0.067567245 0.032599394 0.024273852
# clarity   0.06756720 0.143393109 0.028236784 0.004910905
# color     0.03259941 0.028236796 0.095656042 0.004804729
# cut       0.02427382 0.004910904 0.004804732 0.031114735

# Usual dependence plot
xvars <- c("log_carat", "color", "clarity", "cut")

shap_values |>
  sv_dependence(xvars) &
  plot_annotation("SHAP dependence plots")  # patchwork magic

# SHAP interactions for carat
shap_values |>
  sv_dependence("log_carat", color_var = xvars, interactions = TRUE) &
  plot_annotation("SHAP interactions for carat")
```
![](../man/figures/VIGNETTE-tidy-xgb-imp.png)
![](../man/figures/VIGNETTE-tidy-xgb-dep.png)

![](../man/figures/VIGNETTE-tidy-xgb-inter.png)

## LightGBM

Regarding SHAP analysis and Tidymodels, LightGBM is slightly different from XGBoost:

- It requires {bonsai}.
- It internally converts factors to integers and treats them as LightGBM categorical features. You should avoid this for factors with a logical order, so don't forget to integer-encode such factors manually in a recipe. For illustration only, we treat "cut" as unordered and let LightGBM use its internal encoding.
- LightGBM does not offer SHAP interactions.

```r
library(tidymodels)
library(bonsai)
library(shapviz)

set.seed(10)

splits <- diamonds |>
  transform(
    log_price = log(price),
    log_carat = log(carat)
  ) |>
  initial_split()
df_train <- training(splits)

# Manually integer-encode factors with a logical order. "cut" is kept as a
# factor, so LightGBM treats it as categorical (for illustration only)
dia_recipe <- df_train |>
  recipe(log_price ~ log_carat + color + clarity + cut) |>
  step_integer(color, clarity)

# Should be tuned in practice
lgb_model <- boost_tree(mode = "regression", learn_rate = 0.1, trees = 100) |>
  set_engine("lightgbm")

lgb_wf <- workflow() |>
  add_recipe(dia_recipe) |>
  add_model(lgb_model)

lgb_fit <- lgb_wf |>
  fit(df_train)

# SHAP analysis
df_explain <- df_train[1:1000, ]

# Reuse the recipe estimated inside fit(), then convert the result to the
# numeric matrix LightGBM expects (remaining factors become integer codes).
# Note: prepare_df_lgbm() is not exported by {bonsai} (hence `:::`), so it is
# not part of bonsai's stable API.
X_pred <- bake(  # Goes to lightgbm:::predict.lgb.Booster()
  extract_recipe(lgb_fit),
  has_role("predictor"),
  new_data = df_explain
) |>
  bonsai:::prepare_df_lgbm()

head(X_pred, 2)
#       log_carat color clarity cut
# [1,]  0.3148107     5       5   3
# [2,] -0.5978370     2       3   4

# Every column of X_pred needs a counterpart in X (used for the plots)
stopifnot(colnames(X_pred) %in% colnames(df_explain))

shap_values <- extract_fit_engine(lgb_fit) |>
  shapviz(X_pred = X_pred, X = df_explain)

shap_values |>
  sv_importance(show_numbers = TRUE)

shap_values |>
  sv_dependence(c("log_carat", "color", "clarity", "cut"))
```
![](../man/figures/VIGNETTE-tidy-lgb-imp.png)
![](../man/figures/VIGNETTE-tidy-lgb-dep.png)

## Probabilistic classification

For probabilistic classification, the code is very similar to the regression examples above.

When several class probabilities are explained, `shapviz()` returns a list of "shapviz" objects (one per class). You can analyze them together, or select an individual class via `$name_of_interesting_class` or `[[`.

### Normal case

Simply pass `type = "prob"` to `kernelshap::kernelshap()` or `kernelshap::permshap()`:

```r
library(tidymodels)
library(kernelshap)
library(shapviz)
library(patchwork)

set.seed(1)

iris_recipe <- iris |>
  recipe(Species ~ .)

# Model specification (not yet fitted)
rf_model <- rand_forest(trees = 100) |>
  set_engine("ranger") |>
  set_mode("classification")

iris_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(rf_model)

rf_fit <- iris_wf |>
  fit(iris)

# SHAP analysis
X_explain <- iris[-5]  # Feature columns of <=2000 rows from the training data

# type = "prob" is passed on to predict(), so each class probability is
# explained separately
system.time(  # 2s
  shap_values <- permshap(rf_fit, X_explain, type = "prob") |>
    shapviz()
)
sv_importance(shap_values)

shap_values |>
  sv_dependence("Sepal.Length") +
  plot_layout(ncol = 1) +
  plot_annotation("SHAP dependence of one variable for all classes")

# Use $ to extract SHAP values for one class
shap_setosa <- shap_values$.pred_setosa

shap_setosa |>
  sv_dependence(colnames(X_explain)) +
  plot_annotation("SHAP dependence of all variables for one class")
```

![](../man/figures/VIGNETTE-tidy-class-normal-imp.png)

![](../man/figures/VIGNETTE-tidy-class-normal-dep1.png)

![](../man/figures/VIGNETTE-tidy-class-normal-dep2.png)

### XGBoost

For XGBoost and LightGBM, we again want to use their native TreeSHAP implementations.

We can slightly adapt the code from the regression example:

```r
library(tidymodels)
library(shapviz)
library(patchwork)

set.seed(1)

iris_recipe <- iris |>
  recipe(Species ~ .)

# For xgboost, verbose = 0 means silent. (verbose = -1 is LightGBM's
# convention; xgboost treats any non-zero value as "print progress".)
xgb_model <- boost_tree(learn_rate = 0.1, trees = 100) |>
  set_mode("classification") |>
  set_engine("xgboost", verbose = 0)

xgb_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(xgb_model)

xgb_fit <- xgb_wf |>
  fit(iris)

# SHAP analysis
df_explain <- iris  # Typically 1000 - 2000 rows from the training data

# Reuse the recipe estimated inside fit() so X_pred matches the training data
X_pred <- bake(  # Goes to xgboost:::predict.xgb.Booster()
  extract_recipe(xgb_fit),
  has_role("predictor"),
  new_data = df_explain,
  composition = "matrix"
)

# Every column of X_pred needs a counterpart in X (used for the plots)
stopifnot(colnames(X_pred) %in% colnames(df_explain))

# Name the per-class SHAP objects after the outcome levels
shap_values <- extract_fit_engine(xgb_fit) |>
  shapviz(X_pred = X_pred, X = df_explain) |>
  setNames(levels(iris$Species))

shap_values |>
  sv_importance()

shap_values |>
  sv_dependence(v = "Sepal.Length", color_var = "Sepal.Width") +
  plot_layout(ncol = 1, guides = "collect")
```

![](../man/figures/VIGNETTE-tidy-class-xgb-imp.png)

![](../man/figures/VIGNETTE-tidy-class-xgb-dep.png)

### LightGBM (binary probabilistic)

Let's complete this vignette by running a binary LightGBM model.

```r
library(tidymodels)
library(bonsai)
library(shapviz)
library(patchwork)

set.seed(1)

# Make factor with two levels: "no" (FALSE) and "yes" (TRUE)
iris$sl_large <- factor(
  iris$Sepal.Length > median(iris$Sepal.Length), labels = c("no", "yes")
)

# Factors with a logical order would be integer-encoded here, e.g. by appending
# |> step_integer(<ordinal factors>)
iris_recipe <- iris |>
  recipe(sl_large ~ Sepal.Width + Petal.Length + Petal.Width + Species)

lgb_model <- boost_tree(learn_rate = 0.1, trees = 100) |>
  set_mode("classification") |>
  set_engine("lightgbm", verbose = -1)

lgb_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(lgb_model)

lgb_fit <- lgb_wf |>
  fit(iris)

# SHAP analysis
df_explain <- iris  # Typically 1000 - 2000 rows from the training data

# Reuse the recipe estimated inside fit(), then convert the result to the
# numeric matrix LightGBM expects (the factor Species becomes integer codes).
# Note: prepare_df_lgbm() is not exported by {bonsai} (hence `:::`), so it is
# not part of bonsai's stable API.
X_pred <- bake(  # Goes to lightgbm:::predict.lgb.Booster()
  extract_recipe(lgb_fit),
  has_role("predictor"),
  new_data = df_explain
) |>
  bonsai:::prepare_df_lgbm()

# Every column of X_pred needs a counterpart in X (used for the plots)
stopifnot(colnames(X_pred) %in% colnames(df_explain))

shap_values <- extract_fit_engine(lgb_fit) |>
  shapviz(X_pred = X_pred, X = df_explain)

shap_values |>
  sv_importance()

shap_values |>
  sv_dependence("Species")
```

![](../man/figures/VIGNETTE-tidy-class-lgb-imp.png)

![](../man/figures/VIGNETTE-tidy-class-lgb-dep.png)