## ----include = FALSE----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Global chunk options for the vignette. Every chunk is evaluated only when
# torch is installed; chunks that additionally need keras override `eval` in
# their own chunk header.
knitr::opts_chunk$set(
  eval = torch::torch_is_installed(),
  # output formatting
  collapse = TRUE,
  comment = "#>",
  size = "huge",
  # figure layout
  fig.align = "center",
  out.width = "95%"
)

## ----echo = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Fix the R random number generator so that all randomly initialised or
# trained objects below are reproducible
set.seed(1111)
# Set the session language (presumably so that printed messages are in
# English regardless of the build machine's locale)
Sys.setenv(LANG = "en_US.UTF-8")

## ----pressure, echo=FALSE, fig.cap = "**Figure 1:** Feature attribution methods"----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Overview figure of feature attribution methods
knitr::include_graphics(file.path("images", "feature_attribution.png"))

## ----echo=FALSE, fig.cap = "**Figure 2:** innsight package"-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Overview figure of the innsight package
knitr::include_graphics(file.path("images", "innsight_torch.png"))

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Step 0: Model creation
# model <- ... # this step is left to the user
# 
# # Step 1: Convert the model
# converter <- convert(model)
# converter <- Converter$new(model) # the same but without helper function
# 
# # Step 2: Apply selected method to your data
# result <- run_method(converter, data)
# result <- Method$new(converter, data) # the same but without helper function
# 
# # Step 3: Show and plot the results
# get_result(result) # get the result as an `array`, `data.frame` or `torch_tensor`
# plot(result) # for individual results (local)
# plot_global(result) # for summarized results (global)
# boxplot(result) # alias for `plot_global` for tabular and signal data

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Using the helper function `convert`
# converter <- convert(model, ...)
# # It simply passes all arguments to the initialization function of
# # the corresponding R6 class, i.e., it is equivalent to
# converter <- Converter$new(model, ...)

## -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
library(torch)
library(innsight)
# Seed torch so the randomly initialised weights below are reproducible
torch_manual_seed(123)

# A small dense classifier: 3 inputs -> 10 hidden ReLU units -> 2 softmax
# outputs (the last linear layer has no bias)
model <- nn_sequential(
  nn_linear(in_features = 3, out_features = 10),
  nn_relu(),
  nn_linear(in_features = 10, out_features = 2, bias = FALSE),
  nn_softmax(dim = 2)
)

# Convert the torch model; its input dimension is passed via `input_dim`
conv_dense <- convert(model, input_dim = 3)

# The same conversion, but additionally labelling the input features and the
# output classes (each given as a list of character vectors)
conv_dense_with_names <- convert(
  model,
  input_dim = 3,
  input_names = list(c("Price", "Weight", "Height")),
  output_names = list(c("Buy it!", "Don't buy it!"))
)

## ----eval = keras::is_keras_available() & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
library(keras)

# Build a small CNN in a single pipe: three convolutional layers with max
# pooling in between, then flatten and a dense softmax layer with 5 outputs.
# The input shape c(10, 10, 3) has the 3 channels in the last position.
model <- keras_model_sequential() %>%
  layer_conv_2d(
    filters = 4, kernel_size = c(5, 4), activation = "softplus",
    input_shape = c(10, 10, 3)
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2), strides = c(1, 1)) %>%
  layer_conv_2d(
    filters = 6, kernel_size = c(3, 3), padding = "same", activation = "relu"
  ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(
    filters = 4, kernel_size = c(2, 2), strides = c(2, 1), activation = "relu"
  ) %>%
  layer_flatten() %>%
  layer_dense(units = 5, activation = "softmax")

# Convert the model (unlike the torch example, no `input_dim` is passed)
conv_cnn <- convert(model)

## -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
library(neuralnet)
data(iris)

# Fit a neuralnet model predicting the species from the two petal
# measurements; `linear.output = FALSE` also applies the activation function
# to the output neurons
model <- neuralnet(
  formula = Species ~ Petal.Length + Petal.Width,
  data = iris,
  linear.output = FALSE
)

# Convert the model. Note: this deliberately overwrites `conv_dense` from the
# torch example; the method examples below use this two-feature model.
conv_dense <- convert(model)

## -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Define a model by hand as a named list: two named inputs, a dense ReLU
# layer with 5 units and a dense sigmoid layer with 1 unit. Weight matrices
# have one row per unit and one column per incoming value. The random draws
# happen in the order weight 1, bias 1, weight 2, bias 2.
model <- list(
  input_dim = 2,
  input_names = list(c("X1", "Feat2")),
  input_nodes = 1,
  output_nodes = 2,
  layers = list(
    # Layer 1: first layer (input_layers = 0), feeds into layer 2
    list(
      type = "Dense",
      weight = matrix(rnorm(10), nrow = 5, ncol = 2),
      bias = rnorm(5),
      activation_name = "relu",
      input_layers = 0,
      output_layers = 2
    ),
    # Layer 2: fed by layer 1, last layer (output_layers = -1)
    list(
      type = "Dense",
      weight = matrix(rnorm(5), nrow = 1, ncol = 5),
      bias = rnorm(1),
      activation_name = "sigmoid",
      input_layers = 1,
      output_layers = -1
    )
  )
)

# Convert the hand-written list model with the same `convert()` helper used
# for the torch, keras and neuralnet models above
converter <- convert(model)

## -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Print the converter object to inspect the converted model
converter

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# method <- Method$new(converter, data, # required arguments
#   channels_first = TRUE, # optional settings
#   output_idx = NULL, # .
#   ignore_last_act = TRUE, # .
#   ... # other args and method-specific args
# )

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# method <- run_method(converter, data, # required arguments
#   channels_first = TRUE, # optional settings
#   output_idx = NULL, # .
#   ignore_last_act = TRUE, # .
#   ... # other args and method-specific args
# )

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply method 'Gradient' for the dense network on the two petal
# measurements, first via the R6 class ...
grad_dense <- Gradient$new(conv_dense, iris[c("Petal.Length", "Petal.Width")])

# ... and equivalently via the helper function `run_grad`
grad_dense <- run_grad(conv_dense, iris[c("Petal.Length", "Petal.Width")])

# Apply method 'Gradient x Input' for the CNN on random channels-first input
# (10 samples, 3 channels, 10 x 10 pixels)
x <- torch_randn(c(10, 3, 10, 10))
grad_cnn <- run_grad(conv_cnn, x, times_input = TRUE)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply method 'SmoothGrad' for the dense network on the two petal
# measurements
smooth_dense <- run_smoothgrad(
  conv_dense, iris[c("Petal.Length", "Petal.Width")]
)

# Apply method 'SmoothGrad x Input' for the CNN on random channels-first
# input (10 samples, 3 channels, 10 x 10 pixels)
x <- torch_randn(c(10, 3, 10, 10))
smooth_cnn <- run_smoothgrad(conv_cnn, x, times_input = TRUE)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply method 'IntegratedGradient' for the dense network on the two petal
# measurements
intgrad_dense <- run_intgrad(
  conv_dense, iris[c("Petal.Length", "Petal.Width")]
)

# Apply method 'IntegratedGradient' for the CNN, using the average over all
# samples (dimension 1, kept as size 1) as a single baseline
x <- torch_randn(c(10, 3, 10, 10))
x_ref <- torch_mean(x, dim = 1, keepdim = TRUE)
intgrad_cnn <- run_intgrad(conv_cnn, x, x_ref = x_ref)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply method 'ExpectedGradient' for the dense network, using the petal
# measurements themselves as the reference data
expgrad_dense <- run_expgrad(
  conv_dense,
  iris[c("Petal.Length", "Petal.Width")],
  data_ref = iris[c("Petal.Length", "Petal.Width")]
)

# Apply method 'ExpectedGradient' for the CNN: 10 random inputs explained
# against 20 random reference samples
x <- torch_randn(c(10, 3, 10, 10))
data_ref <- torch_randn(c(20, 3, 10, 10))
expgrad_cnn <- run_expgrad(conv_cnn, x, data_ref = data_ref)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply method 'LRP' for the dense network on the two petal measurements
lrp_dense <- run_lrp(conv_dense, iris[c("Petal.Length", "Petal.Width")])

# Apply method 'LRP' for the CNN with the alpha-beta rule (rule parameter 1).
# This time the input is channels-last (10 samples, 10 x 10 pixels,
# 3 channels), which is signalled with `channels_first = FALSE`.
x <- torch_randn(c(10, 10, 10, 3))
lrp_cnn <- run_lrp(
  conv_cnn,
  x,
  rule_name = "alpha_beta",
  rule_param = 1,
  channels_first = FALSE
)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Reference value for the dense network: the column means of the two petal
# features, shaped as a 1 x 2 array (one reference point, two features)
x_ref <- array(colMeans(iris[c("Petal.Length", "Petal.Width")]), dim = c(1, 2))
# Apply method 'DeepLift' for the dense network with that reference value
deeplift_dense <- run_deeplift(
  conv_dense, iris[c("Petal.Length", "Petal.Width")], x_ref = x_ref
)

# Apply method 'DeepLift' for the CNN without `x_ref`, i.e. with the default
# zero baseline
x <- torch_randn(c(10, 3, 10, 10))
deeplift_cnn <- run_deeplift(conv_cnn, x)

## ----results='hide', message=FALSE, eval = keras::is_keras_available() & torch::torch_is_installed()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Apply the global variant of 'ConnectionWeights' for the dense network; it
# is called with the converted model only, without input data
connectweights_dense <- run_cw(conv_dense)

# Apply the local variant of 'ConnectionWeights' (`times_input = TRUE`) for
# the CNN; unlike the global variant, it requires input data
x <- torch_randn(c(10, 3, 10, 10))
connectweights_cnn <- run_cw(conv_cnn, x, times_input = TRUE)

## ----eval = keras::is_keras_available() & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Print the SmoothGrad result object of the CNN to inspect it
smooth_cnn

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Get the result with the class method
# method$get_result(type = "array")
# 
# # or use the S3 function
# get_result(method, type = "array")

## ----eval = keras::is_keras_available() & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Get the result of the Gradient method (make sure 'grad_dense' is defined!),
# either with the class method ...
result_array <- grad_dense$get_result()

# ... or with the S3 method
result_array <- get_result(grad_dense)

# Show the result for data points 1 and 71
result_array[c(1, 71), , ]

## ----eval = keras::is_keras_available() & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Get the LRP result as a data.frame (make sure 'lrp_cnn' is defined!),
# either with the class method ...
result_data.frame <- lrp_cnn$get_result(type = "data.frame")

# ... or with the S3 method
result_data.frame <- get_result(lrp_cnn, type = "data.frame")

# Show the first 5 rows
head(result_data.frame, n = 5)

## ----eval = keras::is_keras_available() & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Get the DeepLift result as a torch tensor (make sure 'deeplift_dense' is
# defined!), either with the class method ...
result_torch <- deeplift_dense$get_result(type = "torch_tensor")

# ... or with the S3 method
result_torch <- get_result(deeplift_dense, type = "torch_tensor")

# Show the result for data points 1 and 71
result_torch[c(1, 71), , ]

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Create a plot for single data points
# plot(method,
#   data_idx = 1, # the data point to be plotted
#   output_idx = NULL, # the indices of the output nodes/classes to be plotted
#   output_label = NULL, # the class labels to be plotted
#   aggr_channels = "sum",
#   as_plotly = FALSE, # create an interactive plot
#   ... # other arguments
# )
# 
# # Create a plot with summarized results
# plot_global(method,
#   output_idx = NULL, # the indices of the output nodes/classes to be plotted
#   output_label = NULL, # the class labels to be plotted
# #   ref_data_idx = NULL, # the index of a reference data point to be plotted
#   aggr_channels = "sum",
#   as_plotly = FALSE, # create an interactive plot
#   ... # other arguments
# )
# 
# # Alias for `plot_global` for tabular and signal data
# boxplot(...)

## ----eval = keras::is_keras_available() & torch::torch_is_installed(), fig.height=6, fig.width=9------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Plot the result of the first data point (the default) for the output
# classes 1, 2 and 3 of the dense network
plot(smooth_dense, output_idx = 1:3)
# Several data points can be plotted at once, here data points 1 and 144
plot(smooth_dense, data_idx = c(1, 144), output_idx = 1:3)
# Plot the LRP result of the CNN for the first data point and the first and
# fourth output classes, aggregating the channels by their Euclidean norm
plot(lrp_cnn, aggr_channels = "norm", output_idx = c(1, 4))

## ----eval = FALSE-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # Create a plotly plot for the first output
# plot(lrp_cnn, aggr_channels = "norm", output_idx = c(1), as_plotly = TRUE)

## ----fig.width = 8, fig.height=4, echo = FALSE, message=FALSE, eval=Sys.getenv("RENDER_PLOTLY", unset = 0) == 1---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # You can do the same with the plotly-based plots
# p <- plot(lrp_cnn, aggr_channels = "norm", output_idx = c(1), as_plotly = TRUE)
# plotly::config(print(p))

## ----eval = keras::is_keras_available() & torch::torch_is_installed(), fig.height=6, fig.width=9------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Create a boxplot for the first two output classes
plot_global(smooth_dense, output_idx = 1:2)
# Plot the first three output classes without preprocessing the results
# (`preprocess_FUN = identity` instead of the default `abs`) and show data
# point 55 as a reference
plot_global(
  smooth_dense,
  output_idx = 1:3,
  ref_data_idx = 55,
  preprocess_FUN = identity
)

## ----fig.height=6, fig.width=9, eval = FALSE----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # You can do the same with the plotly-based plots
# plot_global(smooth_dense,
#   output_idx = 1:3, preprocess_FUN = identity,
#   ref_data_idx = c(55), as_plotly = TRUE
# )

## ----fig.width = 8, fig.height=4, echo = FALSE, message=FALSE, eval=Sys.getenv("RENDER_PLOTLY", unset = 0) == 1 & torch::torch_is_installed()---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# # You can do the same with the plotly-based plots
# p <- plot_global(smooth_dense,
#   output_idx = 1:3, preprocess_FUN = identity,
#   ref_data_idx = c(55), as_plotly = TRUE
# )
# plotly::config(print(p))