---
title: "Multinomial Classification"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Multinomial Classification}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{css, echo = F}
body {
  text-align: justify;
}

p {
  font-weight: normal;
  font-size: 16px;
  text-align: justify;
}

p.note {
  font-weight: normal;
  font-size: 12px;
  font-family: serif;
  text-align: justify;
}
```

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

```{r, eval = TRUE, echo = F}
# Only evaluate the remaining chunks if `ranger` is available
execute <- requireNamespace("ranger", quietly = TRUE)
if (!execute) {
  print("This vignette requires the `ranger` package to be installed for execution.")
}
```

<p>This example uses the PieGlyph package to show the predicted probabilities of the different classes in a multinomial classification problem. The same approach could be useful for any classification or clustering method that returns class probabilities, e.g. model-based clustering with `mclust` or random forests with `ranger`.</p>

### Load libraries
```{r setup, warning = F, message = FALSE, eval = execute}
# install.packages('ranger')
library(ranger)
library(ggplot2)
library(PieGlyph)
library(dplyr)
```

### Load data

<p>We use the `iris` dataset, which gives the sepal length, sepal width, petal length and petal width of 50 flowers from each of three species: <i>Iris setosa</i>, <i>Iris versicolor</i> and <i>Iris virginica</i>.</p>
```{r data,warning = F, message = FALSE, eval = execute}
head(iris)
```

### Classification model

<p> We use the random forest algorithm for classifying the samples into the three species according to the four measurements described above.</p>

```{r model, warning = F, eval = execute}
rf <- ranger(Species ~ Petal.Length + Petal.Width + 
                        Sepal.Length + Sepal.Width, 
             data=iris, probability=TRUE)
```

<p>We then obtain the predicted probability of each sample belonging to each species. Note that these predictions are made on the same data that was used to fit the model.</p>

```{r predictions, warning = F, eval = execute}
preds <- as.data.frame(predict(rf, iris)$predictions)
head(preds)
```
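
<p>As a quick sanity check, the class probabilities of each sample should sum to one. The sketch below illustrates the check on a small, hypothetical table (`toy_preds`) that stands in for `preds`; the values are made up for illustration and are not output from the model above.</p>

```r
# Hypothetical stand-in for `preds`: three samples, three classes
toy_preds <- data.frame(
  setosa     = c(0.98, 0.01, 0.00),
  versicolor = c(0.02, 0.90, 0.35),
  virginica  = c(0.00, 0.09, 0.65)
)

# Total probability assigned to each sample
row_totals <- rowSums(toy_preds)

# TRUE if every row sums to one (up to floating-point tolerance)
sums_to_one <- all(abs(row_totals - 1) < 1e-8)
sums_to_one
```

<p>The same two lines could be applied to `preds` directly, assuming its columns contain only the class probabilities.</p>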

<p>Combine the predicted probabilities with the original data for plotting.</p>
```{r plot_data, warning = F, eval = execute}
plot_data <- cbind(iris, preds)
head(plot_data)
```

<p>Add a column indicating whether or not each sample was classified correctly.</p>

```{r, eval = execute}
plot_data <- plot_data %>% 
    # Perform the operations row by row
    rowwise() %>% 
    # Select the species with the highest predicted probability as the classified species
    # (indexing names(preds) avoids relying on hard-coded column positions)
    mutate(Predicted = names(preds)[which.max(c(setosa, versicolor, virginica))]) %>% 
    # Check whether the selected species is the same as the observed species
    mutate(Classification = ifelse(Species == Predicted, 'Correct', 'Incorrect')) %>% 
    ungroup()
```
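
<p>The row-wise step above can also be written without `rowwise()`. The following is a minimal base R sketch of that alternative, using `max.col()` on a small hypothetical data frame (`toy`) whose values are invented for illustration. It also cross-tabulates the observed and predicted species, which gives a compact summary of where any misclassifications occur.</p>

```r
# Hypothetical data: observed species plus predicted class probabilities
toy <- data.frame(
  Species    = factor(c("setosa", "versicolor", "virginica", "virginica")),
  setosa     = c(0.97, 0.02, 0.00, 0.01),
  versicolor = c(0.03, 0.88, 0.30, 0.60),
  virginica  = c(0.00, 0.10, 0.70, 0.39)
)
class_cols <- c("setosa", "versicolor", "virginica")

# Column index of the largest probability in each row (first one wins on ties)
best <- max.col(as.matrix(toy[class_cols]), ties.method = "first")
toy$Predicted <- class_cols[best]

# Flag whether the predicted species matches the observed species
toy$Classification <- ifelse(toy$Species == toy$Predicted, "Correct", "Incorrect")

# Cross-tabulate observed against predicted species
confusion <- table(Observed = toy$Species, Predicted = toy$Predicted)
confusion
```

<p>Setting `ties.method = "first"` makes the result deterministic; the default for `max.col()` breaks ties at random.</p>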

### Create plot

<p>The plot shows a scatterplot of the sepal width against the sepal length for the samples in the `iris` dataset. The predicted probabilities of each sample belonging to each species are shown by the pie-chart glyphs. The border of each pie chart (its colour and linetype) shows whether or not the sample was classified correctly.</p>

```{r prob-plot, warning = F, fig.align='center', fig.height = 5, fig.width=7, eval = execute}
ggplot(data=plot_data,
   aes(x=Sepal.Length, y=Sepal.Width))+
   # Pie charts showing the predicted probabilities of the different species
   # Using the pie border to highlight whether the sample was classified correctly
   geom_pie_glyph(aes(linetype = Classification,  colour = Classification), 
                  slices = names(preds)) +
   # Colours for sectors of the pie-chart
   scale_fill_manual(values = c('#56B4E9', '#D55E00','#CC79A7'))+
   # Labels for axes and legend
   labs(y = 'Sepal Width', x = 'Sepal Length', fill = 'Prob (Species)')+
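   # Adjusting the border colours and linetypes
   # (named values keep the mapping fixed even if one category is absent)
   scale_linetype_manual(values = c(Correct = 1, Incorrect = 3))+
   scale_colour_manual(values = c(Correct = 'black', Incorrect = 'white'))+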
   # Theme of the plot
   theme_minimal()+
   theme(panel.grid = element_line(colour = 'darkgrey'),
         plot.background = element_rect(fill = 'grey', colour = NA))
```