---
title: "Introduction to parttree"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to parttree}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  warning = FALSE,
  out.width = "90%",
  # fig.width = 8,
  # dpi = 300,
  asp = 0.625
)
```

## Motivating example: Classifying penguin species

Start by loading the **parttree** package alongside **rpart**, which comes
bundled with the base R installation. For the basic examples that follow, 
I'll use the well-known
[Palmer Penguins](https://allisonhorst.github.io/palmerpenguins/) dataset to
demonstrate functionality. You can load this dataset via the parent package (as 
I have here), or import it directly as a CSV
[here](https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv).

```{r setup}
library(parttree)  # This package
library(rpart)     # For fitting decision trees

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
```

```{r thread_control, echo=-1}
data.table::setDTthreads(2)
```

Dataset in hand, let's say that we are interested in predicting penguin
_species_ as a function of 1) flipper length and 2) bill length. We could model
this as a simple decision tree:

```{r tree}
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
tree
```

Like most tree-based frameworks, **rpart** comes with a default `plot` method
for visualizing the resulting node splits.

```{r rpart_plot}
plot(tree, compress = TRUE)
text(tree, use.n = TRUE)
```

While this is okay, I don't feel that it provides much intuition about the
model's prediction on the _scale of the actual data_. In other words, what I'd
prefer to see is: How has our tree partitioned the original penguin data?

This is where **parttree** enters the fray. The package is named for its primary
workhorse function `parttree()`, which extracts all of the information needed
to produce a nice plot of our tree partitions alongside the original data.

```{r penguin_cl_plot}
ptree = parttree(tree)
plot(ptree)
```

_Et voilà!_ Now we can clearly see how our model has divided up the Cartesian
space of the data. Gentoo penguins typically have longer flippers than Chinstrap
or Adelie penguins, while the latter have the shortest bills.
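
Each shaded partition corresponds to the class that the model would predict for
any observation falling inside it. As a rough sketch of that link, we can ask
the tree to classify a single made-up penguin. The `new_penguin` measurements
below are hypothetical values that I've chosen purely for illustration; they
are not taken from the dataset.

```{r penguin_predict}
library(rpart)
data("penguins", package = "palmerpenguins")

## (Re)fit the tree from earlier so that this chunk runs on its own
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## A hypothetical penguin (illustrative values only)
new_penguin = data.frame(flipper_length_mm = 220, bill_length_mm = 48)

## The predicted class is the label of the partition that contains this point
pred = predict(tree, newdata = new_penguin, type = "class")
pred
```

Locating the point (220, 48) on the plot above should land you in the partition
with the same label.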

From the perspective of the end-user, the `ptree` parttree object is not all
that interesting in and of itself. It is simply a data frame that contains the basic
information needed for our plot (partition coordinates, etc.). You can think of
it as a helpful intermediate object on our way to the visualization of interest.

```{r ptree}
# See also `attr(ptree, "parttree")`
ptree
```
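
One way to convince yourself of what this data frame represents is to compare
its size against the fitted tree. The sketch below assumes that `parttree()`
returns one row per terminal node (leaf), which is my reading of the output
above rather than a documented guarantee. The leaf count itself comes from the
`frame` component of the **rpart** object, where terminal nodes are marked as
`"<leaf>"`.

```{r ptree_leaves}
library(parttree)
library(rpart)
data("penguins", package = "palmerpenguins")

## (Re)create the objects from earlier so that this chunk runs on its own
tree  = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
ptree = parttree(tree)

## Number of terminal nodes in the fitted rpart tree
n_leaves = sum(tree$frame$var == "<leaf>")
n_leaves

## Number of partitions (rows) in the parttree data frame
nrow(ptree)
```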

Speaking of visualization, under the hood `plot.parttree` calls the
powerful
[**tinyplot**](https://grantmcdermott.com/tinyplot/)
package. All of the latter's various customization arguments can be passed on to
our `parttree` plot to make it look a bit nicer. For example:

```{r penguin_cl_plot_custom}
plot(ptree, pch = 16, palette = "classic", alpha = 0.75, grid = TRUE)
```


### Continuous predictions

In addition to discrete classification problems, **parttree** also supports
regression trees with continuous dependent variables.

```{r penguin_reg_plot}
tree_cont = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

tree_cont |>
  parttree() |>
  plot(pch = 16, palette = "viridis")
```


## Supported model classes

Alongside the [**rpart**](https://CRAN.R-project.org/package=rpart) model
objects that we have been working with thus far, **parttree** also supports
decision trees created by the
[**partykit**](https://CRAN.R-project.org/package=partykit) package. Here we
see how the latter's `ctree` (conditional inference tree) algorithm yields a
slightly more sophisticated partitioning than the former's default.

```{r penguin_ctree_plot}
library(partykit)

ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins) |>
  parttree() |>
  plot(pch = 16, palette = "classic", alpha = 0.5)
```

**parttree** also supports a variety of "frontend" modes that call
`rpart::rpart()` as the underlying engine. This includes packages from both the
[**mlr3**](https://mlr3.mlr-org.com/) and
[**tidymodels**](https://www.tidymodels.org/)
([parsnip](https://parsnip.tidymodels.org/) 
or [workflows](https://workflows.tidymodels.org/))
ecosystems. Here is a quick demonstration using **parsnip**, where we'll also
pull in a different dataset just to change things up a little.

```{r titanic_plot}
set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)
## Now pass to parttree and plot
ti_tree |>
  parttree() |>
  plot(pch = 16, jitter = TRUE, palette = "dark", alpha = 0.7)
```

## ggplot2

The default `plot.parttree` method produces a base graphics plot. But we also
support **ggplot2** via a dedicated `geom_parttree()` function. Here we
demonstrate with our initial classification tree from earlier.

```{r penguin_cl_ggplot2}
library(ggplot2)
theme_set(theme_linedraw())

## re-using the tree model object from above...
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species)) +
  geom_parttree(data = tree, aes(fill=species), alpha = 0.1)
```

Compared to the "native" `plot.parttree` method, note that the **ggplot2**
workflow requires a few tweaks:

- We need to plot the original dataset as a separate layer (i.e., `geom_point()`).
- `geom_parttree()` accepts the tree object _itself_, not the result of `parttree()`.^[This is because `geom_parttree(data = <tree>)` calls `parttree(<tree>)` internally.] 

Continuous regression trees can also be drawn with `geom_parttree`. However, I 
recommend adjusting the plot fill aesthetic since your model will likely
partition the data into intervals that don't match up exactly with the raw data.
The easiest way to do this is by setting your colour and fill aesthetic together
as part of the same `scale_colour_*` call.

```{r penguin_reg_ggplot2}
## re-using the tree_cont model object from above...
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree_cont, aes(fill=body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) + 
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
```

### Gotcha: (gg)plot orientation

As we have already said, `geom_parttree()` calls the companion `parttree()` 
function internally, which coerces the **rpart** tree object into a data frame
that is easily understood by **ggplot2**. For example, consider our initial
"ptree" object from earlier. 

```{r ptree_redux}
# ptree = parttree(tree)
ptree
```
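
Because the coordinate columns line up with rectangle aesthetics, the
partitions can in principle be drawn by hand with `geom_rect()`. What follows is
only a rough sketch of that idea, not a description of how the package is
implemented. It assumes that the bounds are stored in columns named `xmin`,
`xmax`, `ymin` and `ymax`, and that the predicted class is stored in a column
named after the response (`species`). The `ptree_df` and `p_manual` names are
my own.

```{r ptree_geom_rect}
library(parttree)
library(rpart)
library(ggplot2)
data("penguins", package = "palmerpenguins")

## (Re)create the tree from earlier so that this chunk runs on its own
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Drop the extra class so that we are left with a plain data frame
ptree_df = as.data.frame(parttree(tree))

## Draw the raw data, then add the partitions manually as rectangles
p_manual = ggplot(penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species)) +
  geom_rect(
    data = ptree_df,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    inherit.aes = FALSE, ## don't inherit x and y from the main ggplot() call
    alpha = 0.1, col = "black"
  )
p_manual
```

Note that I had to match the x and y variables to the orientation of the
partition data myself, which is exactly the kind of detail that can go wrong.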

Again, the resulting data frame is designed to be amenable to a **ggplot2** geom
layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that
**ggplot2** recognizes. (Fun fact: `geom_parttree()` is really just a thin
wrapper around `geom_rect()`.) The goal of **parttree** is to abstract away
these kinds of details from the user, so that they can just specify
`geom_parttree()`&mdash;with a valid tree object as the data input&mdash;and be
done with it. However, while this generally works well, it can sometimes lead to
unexpected behaviour in terms of plot orientation. That's because it's hard to
guess ahead of time what the user will specify as the x and y variables (i.e.
axes) in their other **ggplot2** layers.^[The default `plot.parttree` method
doesn't have this problem because it assigns the x and y variables for both the
partitions and raw data points as part of the same function call.] To see what I
mean, let's redo our penguin plot from earlier, but this time switch the axes in
the main `ggplot()` call.

```{r penguin_cl_ggplot_mismatch}
## First, redo our first plot but this time switch the x and y variables
p3 = ggplot(
  data = penguins, 
  aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
  ) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
  )
```

As was the case here, this kind of orientation mismatch is normally (hopefully) 
pretty easy to recognize. To fix, we can use the `flip = TRUE` argument to 
flip the orientation of the `geom_parttree` layer.

```{r penguin_cl_ggplot_mismatch_flip}
p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flip = TRUE  ## Flip the orientation
  ) +
  labs(title = "That's better")
```
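
To get a feel for what flipping does to the underlying data, we can call
`parttree()` directly with and without `flip = TRUE` and compare the coordinate
columns. This is a sketch that assumes `parttree()` accepts the same `flip`
argument as `geom_parttree()`, and that the bounds are stored in columns named
`xmin`, `xmax`, `ymin` and `ymax`.

```{r ptree_flip}
library(parttree)
library(rpart)
data("penguins", package = "palmerpenguins")

## (Re)create the tree from earlier so that this chunk runs on its own
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

ptree_default = as.data.frame(parttree(tree))
ptree_flipped = as.data.frame(parttree(tree, flip = TRUE))

## The x bounds of the default orientation...
ptree_default[, c("xmin", "xmax", "ymin", "ymax")]

## ...should reappear as the y bounds after flipping (and vice versa)
ptree_flipped[, c("xmin", "xmax", "ymin", "ymax")]
```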