---
title: "XGBoost models"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{XGBoost models}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
if (requireNamespace("xgboost", quietly = TRUE)) {
  library(tidypredict)
  library(xgboost)
  library(dplyr)
  eval_code <- TRUE
} else {
  eval_code <- FALSE
}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = eval_code
)
```

| Function                                                      |Works|
|---------------------------------------------------------------|-----|
|`tidypredict_fit()`, `tidypredict_sql()`, `parse_model()`      |  ✔  |
|`tidypredict_to_column()`                                      |  ✔  |
|`tidypredict_test()`                                           |  ✔  |
|`tidypredict_interval()`, `tidypredict_sql_interval()`         |  ✗  |
|`parsnip`                                                      |  ✔  |

## `tidypredict_` functions

```{r}
library(xgboost)

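# Custom logistic objective, kept for reference. It is not passed to
# xgb.train() below, which uses the built-in "binary:logistic" objective.
logregobj <- function(preds, dtrain) {
  labels <- xgboost::getinfo(dtrain, "label")
  preds <- 1 / (1 + exp(-preds))
  grad <- preds - labels
  hess <- preds * (1 - preds)
  return(list(grad = grad, hess = hess))
}

# Predictors: every column except `am` (column 9), which is the label
xgb_bin_data <- xgboost::xgb.DMatrix(
  as.matrix(mtcars[, -9]),
  label = mtcars$am
)

# Fit a boosted model: 50 rounds of depth-2 trees
model <- xgboost::xgb.train(
  params = list(max_depth = 2, objective = "binary:logistic", base_score = 0.5),
  data = xgb_bin_data, nrounds = 50
)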
```

- Create the R formula
    ```{r}
tidypredict_fit(model)
    ```

- Add the prediction to the original table
    ```{r}
library(dplyr)

mtcars %>%
  tidypredict_to_column(model) %>%
  glimpse()
    ```

- Confirm that the `tidypredict` results match the model's `predict()` results. The `xg_df` argument expects the data as an `xgb.DMatrix` object.
    ```{r}
tidypredict_test(model, mtcars, xg_df = xgb_bin_data)
    ```
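
The support table also lists `tidypredict_sql()`. As a minimal sketch, assuming the `dbplyr` package is installed, the chunk below fits a deliberately small two-round model so the generated SQL stays readable, then translates it against a simulated connection. The names `small_data` and `small_model` are illustrative only and are not used elsewhere in this vignette.

```{r}
library(tidypredict)
library(xgboost)

# Illustrative two-round model; same data and parameters as above,
# but far fewer trees so the SQL output is short
small_data <- xgb.DMatrix(as.matrix(mtcars[, -9]), label = mtcars$am)
small_model <- xgb.train(
  params = list(max_depth = 2, objective = "binary:logistic", base_score = 0.5),
  data = small_data, nrounds = 2
)

# simulate_dbi() stands in for a real connection; no database is needed
tidypredict_sql(small_model, dbplyr::simulate_dbi())
```

With a real database, a live connection object would take the place of the simulated one, and the exact SQL produced may differ by backend.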

## parsnip

`parsnip` fitted models are also supported by `tidypredict`:
```{r}
library(parsnip)

p_model <- boost_tree(mode = "regression") %>%
  set_engine("xgboost") %>%
  fit(am ~ ., data = mtcars)
```

As before, `tidypredict_test()` compares the translated formula against the model's own predictions:


```{r}
tidypredict_test(p_model, mtcars, xg_df = xgb_bin_data)
```
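
The comparison that `tidypredict_test()` reports can also be sketched by hand. The chunk below is a rough illustration, not the package's internal procedure: it refits the same specification under the hypothetical name `chk_model`, evaluates the `tidypredict_fit()` formula with `dplyr`, and prints the largest absolute difference from the predictions returned by `predict()`.

```{r}
library(tidypredict)
library(parsnip)
library(dplyr)

# Hypothetical object name; same specification as the parsnip model above
chk_model <- boost_tree(mode = "regression") %>%
  set_engine("xgboost") %>%
  fit(am ~ ., data = mtcars)

# Evaluate the tidypredict formula inside dplyr
manual <- mtcars %>%
  mutate(fit = !!tidypredict_fit(chk_model)) %>%
  pull(fit)

# Native predictions from the fitted parsnip model
native <- predict(chk_model, new_data = mtcars)$.pred

# Largest absolute disagreement between the two sets of predictions
max(abs(manual - native))
```

A value close to zero would indicate that the two agree. The size of any remaining gap can depend on the installed `xgboost` version, so no expected value is shown here.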

## Parse model spec

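`parse_model()` extracts the fitted model into a spec, which `tidypredict` uses to build the formula. Here is the top-level structure of that spec: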
```{r}
pm <- parse_model(model)
str(pm, 2)
```

The first tree in the spec:

```{r}
str(pm$trees[1])
```
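
Beyond `str()`, the parsed spec can be explored with ordinary list tools. The sketch below assumes the spec behaves like a plain nested list with a `trees` element, as the output above suggests. It uses a small three-round model under the illustrative names `spec_data`, `spec_model`, and `spec` to keep the output compact.

```{r}
library(tidypredict)
library(xgboost)

# Illustrative three-round model, parsed into a spec
spec_data <- xgb.DMatrix(as.matrix(mtcars[, -9]), label = mtcars$am)
spec_model <- xgb.train(
  params = list(max_depth = 2, objective = "binary:logistic", base_score = 0.5),
  data = spec_data, nrounds = 3
)
spec <- parse_model(spec_model)

# Top-level components of the parsed spec
names(spec)

# Number of parsed trees; for a binary model this should follow
# the number of boosting rounds
length(spec$trees)

# Number of entries recorded for each tree
lengths(spec$trees)
```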