---
title: "Using the tidytreatment package with bartCause"
author: "Joshua J Bon"
date: "`r Sys.Date()`"
bibliography: vignette.bib
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Using the tidytreatment package with bartCause}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.dim = c(6, 4)
)

suppressPackageStartupMessages({
    library(bartCause)
    library(stan4bart)
    library(tidytreatment)
    library(dplyr)
    library(tidybayes)
    library(ggplot2)
  })
  
  # load pre-computed data and model
  sim <- suhillsim2_ranef

  
```

This vignette demonstrates an example workflow for heterogeneous treatment effect models using the `bartCause` package (together with `stan4bart`) for fitting Bayesian Additive Regression Trees and `tidytreatment` for investigating the output of such models. The `tidytreatment` package can also be used with `bartMachine` models; support for `bcf` is coming soon (see branch `bcf-hold` on GitHub).

## Simulate data

Below we load packages and simulate data using the scheme described by @Hill2013, with the addition of one categorical variable. It is implemented in the function `simulate_su_hill_data()`:

```{r load-data-print, echo = TRUE, eval = FALSE}

# load packages
library(bartCause)
library(stan4bart)
library(tidytreatment)
library(dplyr)
library(tidybayes)
library(ggplot2)

# set seed so vignette is reproducible
set.seed(101)

# simulate data
sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE,  omega = 0, add_categorical = TRUE,
                             n_subjects = 10, sd_subjects = 0.1,
                             coef_categorical_treatment = c(0,0,1),
                             coef_categorical_nontreatment = c(-1,0,-1)
)
  
```

Now we can take a look at some data summaries.

```{r data-summary, echo = TRUE, eval = TRUE}

# non-treated vs treated counts:
table(sim$data$z)

dat <- sim$data
# a selection of data
dat %>% select(y, z, c1, x1:x3) %>% head()

# repeated observation counts for subjects:
table(sim$data$subject_id)

```

## Fit the regression model

Run the model to be used to assess treatment effects. Here we use the `bartCause` package for causal inference with Bayesian additive regression trees (BART) [@hill2011]. For more on BART see @Chipman2010 and @sparapani2016. The package can be found on [CRAN](https://cran.r-project.org/package=bartCause).

We follow the procedure in @Hahn2020 (albeit without their causal forest model), in which a propensity score for being assigned to the treatment regime is estimated and then included in the treatment effect model, which improves estimation properties. `bartCause` can do this automatically. The procedure is roughly as follows:

1. Fit a 'variable selection' model (VS): Regress the response variable against all potential confounders (i.e. no treatment variable)
2. Select the subset of confounders from the VS model that are most associated with the response variable
3. Fit a 'propensity score' model (PS): A binary response model estimating the treatment assignment propensity score using only the variables selected in step 2
4. Fit the treatment effect model (TE): Using all potential confounders and the propensity score from step 3

Note: Using the automatic treatment assignment model in `bartCause` skips steps 1-2 and uses all potential confounders to estimate the PS model.
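
The four steps can be sketched end to end with ordinary regression models standing in for BART. This is only an illustration of the shape of the workflow, not what `bartCause` does internally: `lm()` and `glm()` replace the BART fits, and the toy data, the variable names and the `|t| > 2` selection rule are all assumptions made for this example.

```r
# Toy illustration of the VS -> PS -> TE workflow using base R only.
# lm() and glm() are stand-ins for the BART models used in this vignette.
set.seed(1)
n <- 500
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))

# treatment assignment depends on x1, so x1 is a confounder
toy$z <- rbinom(n, 1, plogis(0.8 * toy$x1))

# response depends on x1, x2 and the treatment (true effect = 2)
toy$y <- 1 + toy$x1 + 0.5 * toy$x2 + 2 * toy$z + rnorm(n)

# Step 1 (VS): regress the response on all potential confounders, no treatment
vs_fit <- lm(y ~ x1 + x2 + x3, data = toy)

# Step 2: keep covariates strongly associated with the response
# (|t| > 2 is an arbitrary rule chosen for this sketch)
t_vals <- summary(vs_fit)$coefficients[-1, "t value"]
selected <- names(t_vals)[abs(t_vals) > 2]

# Step 3 (PS): binary regression of treatment on the selected covariates
ps_fit <- glm(reformulate(selected, response = "z"),
              family = binomial, data = toy)
toy$ps <- fitted(ps_fit)

# Step 4 (TE): response on treatment, all confounders and the propensity score
te_fit <- lm(y ~ z + x1 + x2 + x3 + ps, data = toy)
ate_hat <- unname(coef(te_fit)["z"])
ate_hat
```

In this linear stand-in the treatment effect is a single coefficient; the BART-based model used in this vignette instead allows the effect to vary with the covariates.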

```{r run-bart, echo = TRUE, eval = TRUE}
  
# STEP 1 VS Model: Regress y ~ covariates
vs_bart <- stan4bart(y ~ bart(. - subject_id - z) + (1|subject_id), 
                             data = dat, iter = 5000, verbose = -1)

# STEP 2: Variable selection
  # select most important vars from y ~ covariates model
  # very simple selection mechanism. Should use cross-validation in practice
covar_ranking <- covariate_importance(vs_bart)
var_select <- covar_ranking %>% 
  filter(avg_inclusion > mean(avg_inclusion) - sd(avg_inclusion)) %>% # at minimum: within 1 sd of mean inclusion
  pull(variable)

# change categorical variables to just one variable
var_select <- unique(gsub("c1.[1-3]$","c1", var_select))

var_select
# includes all covariates

# STEP 3 PS Model: Regress z ~ covariates
# (var_select retained every covariate, so all of them are used here)
ps_bart <- stan4bart(z ~ bart(. - subject_id - y) + (1|subject_id), 
                             data = dat, iter = 5000, verbose = -1)

# extract the fitted propensity scores as a vector (passed to bartc below)
prop_score <- fitted(ps_bart)

# STEP 4 TE Model: Regress y ~ z + covariates + propensity score
te_bart <- bartc(response = y, treatment = z, confounders = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10,  
                 parametric = (1|subject_id), data = dat, method.trt = prop_score, 
                 iter = 5000, bart_args = list(keepTrees = TRUE))

#* The posterior samples are kept small to manage size on CRAN

```
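
To make the selection rule in step 2 concrete, here is a small base R sketch applied to a made-up inclusion table. The values are hypothetical and are not output from `covariate_importance()`; only the column names `variable` and `avg_inclusion` and the dummy naming pattern `c1.1`, `c1.2`, `c1.3` are taken from the code above.

```r
# Hypothetical inclusion summary; column names mirror those used in step 2
covar_toy <- data.frame(
  variable = c("x1", "x2", "x3", "c1.1", "c1.2", "c1.3"),
  avg_inclusion = c(0.30, 0.25, 0.02, 0.15, 0.12, 0.16),
  stringsAsFactors = FALSE
)

# keep variables whose inclusion exceeds (mean - 1 sd) of all inclusion values
cutoff <- mean(covar_toy$avg_inclusion) - sd(covar_toy$avg_inclusion)
kept <- covar_toy$variable[covar_toy$avg_inclusion > cutoff]

# collapse the dummy columns of the categorical variable back to one name
kept <- unique(gsub("c1.[1-3]$", "c1", kept))
kept
```

Here only `x3` falls below the cutoff. Because the rule only removes variables that are more than one standard deviation below the mean inclusion, it tends to retain most covariates, which is consistent with the result above where all covariates are kept.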

## Extract the posterior (tidy style)

Methods for extracting the posterior in a tidy format are included in `tidytreatment`.

```{r tidy-bart-fit, echo=TRUE, cache=FALSE}

# get model parameters (excluding BART parameters)
posterior_params <- tidy_draws(te_bart)

posterior_fitted <- epred_draws(te_bart, value = "fitted")

```

```{r tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE}

# Function to tidy predicted draws (adds predicted noise to fitted values)
posterior_pred <- predicted_draws(te_bart, value = "predicted")

```
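
For readers new to the tidy layout, the following base R sketch shows the general idea with a made-up matrix of draws: each row of the long data frame is one draw for one observation. The `.row` and `.draw` column names follow the `tidybayes` convention; the numbers are hypothetical, and the full set of columns returned by the `tidytreatment` methods is not reproduced here.

```r
# Hypothetical posterior draws: 4 draws (rows) for 3 observations (columns)
draws <- matrix(c(1.0, 1.2, 0.9, 1.1,
                  2.0, 2.1, 1.9, 2.2,
                  0.5, 0.4, 0.6, 0.7),
                nrow = 4, ncol = 3)

# Long ("tidy") layout: one row per draw-observation pair
tidy_toy <- data.frame(
  .row = rep(seq_len(ncol(draws)), each = nrow(draws)),
  .draw = rep(seq_len(nrow(draws)), times = ncol(draws)),
  fitted = as.vector(draws)
)

# Posterior mean of the fitted value for each observation
post_mean <- tapply(tidy_toy$fitted, tidy_toy$.row, mean)
post_mean
```

Keeping an observation index such as `.row` in the long format is what allows the draws to be joined back to the original data, as done in the plotting code that follows.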

## Use some plotting functions from the `tidybayes` package

Since `tidytreatment` follows the `tidybayes` output specifications, functions from `tidybayes` should work.

```{r plot-tidy-bart, echo=TRUE, cache=FALSE}

treatment_var_and_c1 <- 
  dat %>% 
  select(z,c1) %>%
  mutate(.row = 1:n(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() + 
  stat_halfeye(aes(x = z, y = fitted)) + 
  facet_wrap(~c1, labeller = as_labeller( function(x) paste("c1 =",x) ) ) +
  xlab("Treatment (z)") + ylab("Posterior fitted value") +
  theme_bw() + ggtitle("Effect of treatment with 'c1' on posterior fitted values")

```

## Calculate Treatment Effects

Posterior conditional (average) treatment effects can be calculated using the `treatment_effects()` function. This function finds the posterior values of
$$
  \tau(x_i) = \text{E}(y ~ \vert~ T = 1, X = x_i) - \text{E}(y ~ \vert~ T = 0, X = x_i) 
$$
for each unit of measurement, $i$ (e.g. subject), in the data sample, where $T$ is the treatment indicator (`z` in the simulated data) and $x_i$ are the covariates of unit $i$.
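
As a rough numerical illustration of this formula, the sketch below computes the difference draw by draw from two made-up matrices of posterior expected values, one under treatment and one under control. Everything here is hypothetical (the matrices, their dimensions and the names `mu0`, `mu1` and `tau`); it shows the arithmetic only and is not the internal implementation of `treatment_effects()`.

```r
set.seed(2)
n_draws <- 200
n_obs <- 5
x <- seq(-1, 1, length.out = n_obs)

# Hypothetical posterior draws of E(y | T = 0, X = x_i): rows are draws,
# columns are observations
mu0 <- matrix(rnorm(n_draws * n_obs, mean = rep(x, each = n_draws), sd = 0.1),
              nrow = n_draws)

# Hypothetical draws of E(y | T = 1, X = x_i), built so that the
# true effect for observation i is 1 + x_i
mu1 <- mu0 + matrix(rnorm(n_draws * n_obs,
                          mean = rep(1 + x, each = n_draws), sd = 0.1),
                    nrow = n_draws)

# tau(x_i) for every draw: an n_draws x n_obs matrix
tau <- mu1 - mu0

# posterior median effect per observation
tau_med <- apply(tau, 2, median)

# averaging over observations within each draw gives posterior draws of the
# sample average treatment effect, summarised by a 95% interval and median
ate_draws <- rowMeans(tau)
quantile(ate_draws, c(0.025, 0.5, 0.975))
```

Taking the difference within each draw, rather than differencing two separate posterior summaries, is what carries the posterior uncertainty through to the treatment effect.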

Some histogram summaries are presented below.


```{r post-treatment, eval = TRUE}

# sample based (using data from fit) conditional treatment effects, posterior draws
posterior_treat_eff <- treatment_effects(te_bart)

# check lines up with summary results...

```

```{r cates-hist, echo=TRUE, cache=FALSE, eval = TRUE}

# Histogram of treatment effect (all draws)
posterior_treat_eff %>% 
  ggplot() +
  geom_histogram(aes(x = icate), binwidth = 0.1, colour = "white") + 
  theme_bw() + ggtitle("Histogram of treatment effect (all draws)")

# Histogram of treatment effect (median for each subject)
posterior_treat_eff %>% summarise(cte_hat = median(icate)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") + 
  theme_bw() + ggtitle("Histogram of treatment effect (median for each subject)")

```


## References