## ----setup, include = FALSE---------------------------------------------------
# knitr chunk defaults: collapse source and output into one block, prefix
# printed output with "#>", and render figures at 6 x 4 inches.
knitr::opts_chunk$set(
  fig.dim = c(6, 4),
  comment = "#>",
  collapse = TRUE
)

# Attach the modelling, tidying and plotting packages without start-up noise.
# (Attach order is kept as-is because later packages may mask earlier ones.)
suppressPackageStartupMessages({
  library(bartCause)
  library(stan4bart)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

# Pre-computed simulated data; the code that generates an equivalent dataset
# is shown (but not run) in the next chunk.
# NOTE(review): `suhillsim2_ranef` is presumably exported by tidytreatment.
sim <- suhillsim2_ranef

  

## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
# 
# # load packages
# library(bartCause)
# library(stan4bart)
# library(tidytreatment)
# library(dplyr)
# library(tidybayes)
# library(ggplot2)
# 
# # set seed so vignette is reproducible
# set.seed(101)
# 
# # simulate data
# sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE,  omega = 0, add_categorical = TRUE,
#                              n_subjects = 10, sd_subjects = 0.1,
#                              coef_categorical_treatment = c(0,0,1),
#                              coef_categorical_nontreatment = c(-1,0,-1)
# )
# 

## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------

# Work with the simulated data frame directly from here on.
dat <- sim$data

# How many observations are non-treated vs treated.
table(dat$z)

# First few rows of the outcome, treatment, categorical covariate and x1-x3.
head(select(dat, y, z, c1, x1:x3))

# How many repeated observations each subject contributes.
table(dat$subject_id)


## ----run-bart, echo = TRUE, eval = TRUE---------------------------------------
  
# STEP 1 VS Model: regress y on the covariates (every column except the
# treatment z and the grouping id), with a random intercept per subject.
vs_bart <- stan4bart(
  y ~ bart(. - subject_id - z) + (1 | subject_id),
  data = dat, iter = 5000, verbose = -1
)

# STEP 2: Variable selection
# Keep covariates whose average inclusion in the y ~ covariates model is above
# (mean - 1 sd) of all inclusions. This is a deliberately simple rule;
# cross-validation should be used in practice.
covar_ranking <- covariate_importance(vs_bart)
var_select <- covar_ranking %>%
  filter(avg_inclusion > mean(avg_inclusion) - sd(avg_inclusion)) %>%
  pull(variable)

# Collapse the dummy columns of the categorical covariate (c1.1, c1.2, c1.3)
# back into the single variable c1. The dot is escaped so that only a literal
# "." separator is matched, not any character.
var_select <- unique(gsub("c1\\.[1-3]$", "c1", var_select))

var_select
# includes all covariates (for this simulated dataset)

# STEP 3 PS Model: regress z on the covariates with a subject random
# intercept. Because every covariate survived step 2, using all of them here
# is equivalent to using `var_select`.
ps_bart <- stan4bart(
  z ~ bart(. - subject_id - y) + (1 | subject_id),
  data = dat, iter = 5000, verbose = -1
)

# Fitted propensity scores, supplied to the TE model via `method.trt`.
prop_score <- fitted(ps_bart)

# STEP 4 TE Model: regress y on z + covariates + propensity score.
# c1 is included among the confounders: it was retained by the selection
# step, and the simulation gives c1 a treatment-effect coefficient. Without c1
# as an input, the BART surface cannot represent the c1-dependent effect that
# is plotted later.
te_bart <- bartc(
  response = y, treatment = z,
  confounders = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + c1,
  parametric = (1 | subject_id), data = dat, method.trt = prop_score,
  iter = 5000, bart_args = list(keepTrees = TRUE)
)

#* The posterior samples are kept small to manage size on CRAN
# NOTE(review): iter = 5000 above is not obviously "small"; confirm this note
# still applies to the settings used here.


## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------

# Posterior draws of the model parameters (the BART parameters are excluded).
posterior_params <- te_bart %>% tidy_draws()

# Posterior draws of the expected (fitted) response, stored in a column named
# "fitted"; rows are identified by `.row`, which later chunks join on.
posterior_fitted <- te_bart %>% epred_draws(value = "fitted")


## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
# 
# # Function to tidy predicted draws (adds predicted noise to fitted values)
# posterior_pred <- predicted_draws(te_bart, value = "predicted")
# 

## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------

# Treatment indicator (as a factor, for a discrete axis) and the categorical
# covariate c1, keyed by row position so they can be joined onto the draws.
# row_number() is used instead of 1:n(), which misbehaves on zero rows.
treatment_var_and_c1 <-
  dat %>%
  select(z, c1) %>%
  mutate(.row = row_number(), z = as.factor(z))

# Posterior distribution of the fitted values by treatment arm, one panel per
# level of c1. These are fitted (expected) values from epred_draws(), not
# predictive draws with added noise, so the y axis is labelled "fitted".
posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fitted)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") + ylab("Posterior fitted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")


## ----post-treatment, eval = TRUE----------------------------------------------

# Posterior draws of the conditional treatment effects, computed for the
# sample the model was fitted on (no new data supplied).
posterior_treat_eff <- te_bart %>% treatment_effects()

# Sanity check: these draws should agree with the treatment-effect estimates
# reported by the model summary.


## ----cates-hist, echo=TRUE, cache=FALSE, eval = TRUE--------------------------

# Histogram of every posterior draw of the treatment effect.
ggplot(posterior_treat_eff) +
  geom_histogram(aes(x = icate), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

# Histogram of the posterior median treatment effect, one value per group.
# NOTE(review): this relies on `posterior_treat_eff` arriving already grouped
# (presumably by `.row`, i.e. per observation) from treatment_effects().
# Otherwise summarise() collapses everything to a single median. The title
# says "subject", but that grouping would be per observation, not per
# subject_id; confirm.
ggplot(summarise(posterior_treat_eff, cte_hat = median(icate))) +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")