---
title: "Multinomial Classification"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Multinomial Classification}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
use_saved_results <- TRUE

knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  echo = TRUE,
  eval = !use_saved_results,
  message = FALSE,
  warning = FALSE
)

if (use_saved_results) {
  results <- readRDS("vignette_mc.rds")
  pred <- results$pred
}
```

```{r, eval=TRUE}
library(dplyr); library(tidyr); library(purrr) # Data wrangling
library(ggplot2); library(stringr) # Plotting
library(tidyfit)   # Auto-ML modeling
```

Multinomial classification is possible in `tidyfit` using the methods powered by `glmnet`, `e1071` and `randomForest` (LASSO, Ridge, ElasticNet, AdaLASSO, SVM and Random Forest). Currently, none of the other methods support multinomial classification.^[Feature selection methods such as `relief` or `chisq` can be used with multinomial response variables. I may also add support for multinomial classification with `mboost` in the future.] When the response variable contains more than two classes, `classify` automatically uses a multinomial response for the above-mentioned methods.

Here's an example using the built-in `iris` dataset:

```{r, eval=TRUE}
data("iris")

# For reproducibility
set.seed(42)
ix_tst <- sample(1:nrow(iris), round(nrow(iris)*0.2))

data_trn <- iris[-ix_tst,]
data_tst <- iris[ix_tst,]

as_tibble(iris)
```
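Since `Species` has three levels, `classify` will switch to a multinomial response for the supported methods. A quick look at the response confirms this:

```{r, eval=TRUE}
# Species has more than two classes, so classify() uses a multinomial
# response for the methods listed above
levels(data_trn$Species)
```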

## Penalized classification algorithms to predict `Species`

The code chunk below fits the above-mentioned algorithms on the training split, using 10-fold cross-validation to select optimal penalties. We then obtain out-of-sample predictions using `predict`. Unlike binomial classification, the `fit` and `pred` objects contain a `class` column with separate coefficients and predictions for each class. The predictions sum to one across classes:

```{r}
fit <- data_trn %>% 
  classify(Species ~ ., 
           LASSO = m("lasso"), 
           Ridge = m("ridge"), 
           ElasticNet = m("enet"), 
           AdaLASSO = m("adalasso"),
           SVM = m("svm"),
           `Random Forest` = m("rf"),
           `Least Squares` = m("ridge", lambda = 1e-5), 
           .cv = "vfold_cv")

pred <- fit %>% 
  predict(data_tst)
```

Note that we can add (effectively) unregularized least squares estimates by setting the ridge penalty to zero or very close to it, as done with `lambda = 1e-5` above.
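As a quick sanity check, the sketch below verifies that the predicted class probabilities in `pred` sum to (approximately) one for each model and test observation, using the `model`, `class` and `prediction` columns returned by `predict`:

```{r}
# Sanity check (sketch): class probabilities should sum to ~1 for each
# model and test observation
pred %>% 
  group_by(model, class) %>% 
  mutate(row_n = row_number()) %>%            # observation index within each class
  group_by(model, row_n) %>% 
  summarise(total = sum(prediction), .groups = "drop") %>% 
  summarise(max_abs_deviation = max(abs(total - 1)))
```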

Next, we can use `yardstick` to calculate the multinomial log loss metric and compare the performance of the different models:

```{r, fig.width=7, fig.height=3, fig.align="center", eval=TRUE}
metrics <- pred %>% 
  group_by(model, class) %>% 
  mutate(row_n = row_number()) %>%  # observation index within each class
  spread(class, prediction) %>%     # one probability column per class
  group_by(model) %>% 
  yardstick::mn_log_loss(truth, setosa:virginica)

metrics %>% 
  mutate(model = str_wrap(model, 11)) %>% 
  ggplot(aes(model, .estimate)) +
  geom_col(fill = "darkblue") +
  theme_bw() +
  theme(axis.title.x = element_blank())
```

The least squares estimate performs worst, while the random forest (a nonlinear method) and the support vector machine (SVM) achieve the best results. The SVM is fitted with a linear kernel by default (pass `kernel = <chosen_kernel>` to `m("svm")` to use a different kernel).
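
Beyond the log loss, one could also compare hard classification accuracy. The sketch below (not part of the workflow above) derives the most probable class for each observation and passes it to `yardstick::accuracy`; it assumes that `truth` in `pred` is a factor containing the observed species:

```{r}
# Sketch: hard class predictions (argmax over class probabilities) and
# per-model accuracy
pred %>% 
  group_by(model, class) %>% 
  mutate(row_n = row_number()) %>%                      # observation index within each class
  group_by(model, row_n) %>% 
  slice_max(prediction, n = 1, with_ties = FALSE) %>%   # keep the most probable class
  ungroup() %>% 
  mutate(estimate = factor(class, levels = levels(truth))) %>% 
  group_by(model) %>% 
  yardstick::accuracy(truth, estimate)
```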