## ----setup, include = FALSE---------------------------------------------------
# Global knitr chunk options for this document: `collapse` and `comment`
# control how source and output are rendered (output lines prefixed by "#>").
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")


## ----load libraries-----------------------------------------------------------
library(ale)
library(dplyr)

## ----diamonds_print-----------------------------------------------------------
# Clean up some invalid entries: drop every row in which any of the
# dimension columns (x, y, z) is recorded as 0.
diamonds <- ggplot2::diamonds |>
  filter(x != 0, y != 0, z != 0) |>
  # Keep only the first row for each distinct combination of price, carat,
  # cut, color and clarity (all other columns are retained); see
  # https://lorentzen.ch/index.php/2021/04/16/a-curious-fact-on-the-diamonds-dataset/
  distinct(
    price, carat, cut, color, clarity,
    .keep_all = TRUE
  ) |>
  # Give the dimension and depth columns more descriptive names
  rename(
    x_length = x,
    y_width = y,
    z_depth = z,
    depth_pct = depth
  )

# Optional: sample 1000 rows so that the code executes faster.
# The sample is drawn from the cleaned `diamonds` (not the raw
# ggplot2::diamonds) so that it excludes the rows removed above and carries
# the renamed columns that the model formula refers to.
set.seed(0)
diamonds_sample <- diamonds[sample(nrow(diamonds), 1000), ]

summary(diamonds)

## ----diamonds_str-------------------------------------------------------------
# Show the compact structure of the cleaned dataset
diamonds |> str()

## ----diamonds_price-----------------------------------------------------------
# Summarize price, the outcome variable of the model fitted below
summary(diamonds[["price"]])

## ----train_gam----------------------------------------------------------------
# Fit a GAM that predicts diamond price with flexible curves.
# Each numeric predictor enters through a smooth term s(); the remaining
# variables (cut, color, clarity) are included as-is.
gam_diamonds <- mgcv::gam(
  formula = price ~
    s(carat) +
    s(depth_pct) +
    s(table) +
    s(x_length) +
    s(y_width) +
    s(z_depth) +
    cut +
    color +
    clarity,
  data = diamonds
)
summary(gam_diamonds)

## ----ale_simple---------------------------------------------------------------
# Simple ALE of the fitted GAM, without bootstrapping
ale_gam_diamonds <- ALE(model = gam_diamonds)

## ----create-plots-------------------------------------------------------------
# Create the plots from the ALE object and store them for later use
diamonds_plots <- ale_gam_diamonds |> plot()

## ----print-carat, fig.width=3.5, fig.height=4---------------------------------
# Print a single plot by retrieving it from the plots object by its name
diamonds_plots |> get('carat')

## ----print-ale_simple, fig.width=7, fig.height=11-----------------------------
# Print all the plots together, arranged in two columns
diamonds_plots |> plot(ncol = 2)

## ----diamonds_new-------------------------------------------------------------
# Bootstrapping is rather slow, so draw a smaller subset of 200 rows
# (sampled without replacement, with a fixed seed) for demonstration
set.seed(0)
new_rows <- sample(x = nrow(diamonds), size = 200, replace = FALSE)
diamonds_small_test <- diamonds[new_rows, ]

## ----ale_boot, fig.width=7, fig.height=11-------------------------------------

# ALE of the same model on the small subset, with bootstrapping.
# Normally boot_it should be set to at least 100, but just 10 here for a
# faster demonstration.
ale_gam_diamonds_boot <- ALE(
  model = gam_diamonds, data = diamonds_small_test, boot_it = 10
)

# Bootstrapping produces confidence intervals
print(plot(ale_gam_diamonds_boot), ncol = 2)

## ----ale_2D-------------------------------------------------------------------
# ALE for two-way (2D) interactions of the fitted GAM
ale_2D_gam_diamonds <- ALE(model = gam_diamonds, x_cols = list(d2 = TRUE))

## ----print-all-2D, fig.width=7, fig.height=7----------------------------------
# Create the plots for the 2D ALE object
diamonds_2D_plots <- plot(ale_2D_gam_diamonds)

# Select all 2D interactions that involve 'carat' and print them in two columns
print(
  subset(diamonds_2D_plots, list(d2_all = 'carat')),
  ncol = 2
)

## ----print-specific-ixn, fig.width=5, fig.height=3----------------------------
# Print one specific interaction plot, identified by a formula
diamonds_2D_plots |> get(~ carat:clarity)