---
title: "Binary Treatment"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Binary Treatment}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
knitr::opts_chunk$set(fig.width = 7, fig.height = 4)
set.seed(123)
```

This vignette shows how to handle a binary treatment with CausalEGM.

```{r setup}
library(RcausalEGM)
```

```{r check_pkg}
if (!(reticulate::py_module_available('CausalEGM'))){
  cat("## Please install the CausalEGM package using the function: install_causalegm()")
  knitr::knit_exit()
}
```

## Data Generation

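Let's first generate a simulated dataset with a binary treatment. Here `v` holds 20 standard normal covariates, the treatment `x` is assigned with probability 0.5 when `v[, 1] > 0` and 0.3 otherwise, and treatment shifts the outcome `y` by `pmax(v[, 1], 0)`.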

```{r simulation}
n <- 500
p <- 20
v <- matrix(rnorm(n * p), n, p)
x <- rbinom(n, 1, 0.3 + 0.2 * (v[, 1] > 0))
y <- pmax(v[, 1], 0) * x + v[, 2] + pmin(v[, 3], 0) + rnorm(n)
```
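
Because the data are simulated, the true effects are known, which gives a reference point for the estimates later on. The chunk below is a minimal sketch: `true_ite` is a hypothetical helper written for this vignette, not part of RcausalEGM. It relies only on the outcome equation above, where treatment enters `y` through `pmax(v[, 1], 0) * x`, so the population ATE is the mean of `max(Z, 0)` for a standard normal `Z`.

```{r truth}
# Hypothetical helper (not part of RcausalEGM): the true ITE implied by the
# simulation, where treatment enters y only through pmax(v[, 1], 0) * x.
true_ite <- function(v) pmax(v[, 1], 0)

# Population ATE: E[max(Z, 0)] for Z ~ N(0, 1). The integrand is zero for
# z < 0, so it is enough to integrate z * dnorm(z) over [0, Inf).
ate_true <- integrate(function(z) z * dnorm(z), lower = 0, upper = Inf)$value
round(ate_true, 3)
```

Calling `mean(true_ite(v))` gives the corresponding sample value for the simulated covariates.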

Let's take a look at the simulated data.

```{r look}
oldpar <- par(mfrow=c(1,3))
# Sizes of the treated (T) and control (C) groups
slices <- c(sum(x == 1), sum(x == 0))
pct <- round(100 * slices / length(x), 2)
lbls <- paste0(c("T group: ", "C group: "), pct, "%")
pie(slices, labels = lbls, main = "Treatment groups")
hist(y, breaks=12, col="red",xlab="y values")
boxplot(v[,1:5],main="First five covariates", xlab="Covariate index", ylab="v values")
par(oldpar)
```

## Model training

Next, train a CausalEGM model. See `help(causalegm)` for detailed usage of the core function `causalegm`.

The arguments `x`, `y`, and `v` are required. Users can also specify `z_dims` as an integer list with four elements.

```{r train}
# help(causalegm)
model <- causalegm(x = x, y = y, v = v, n_iter = 2000)
```

## Treatment Effect Estimation

After training, the individual treatment effect (ITE) estimates are saved as a .txt file in the directory given by the `output_dir` argument of `causalegm`.

Alternatively, several key estimates, including the average treatment effect (ATE) and the ITEs, can be obtained directly from the trained model.

```{r extract}
ATE <- mean(model$causal_pre)
paste("The average treatment effect (ATE):", round(ATE, 3))
ITE <- model$causal_pre
boxplot(ITE, main="ITE distribution", ylab="Values")
```
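
With simulated data, the estimates can be checked against the known truth. The chunk below is a sketch of one way to do that: `ite_metrics` is a hypothetical helper, not part of RcausalEGM, and the numbers in the toy call are made up for illustration and are not model output.

```{r metrics}
# Hypothetical helper (not part of RcausalEGM): compare ITE estimates with
# the known ground truth from a simulation.
ite_metrics <- function(est, truth) {
  est <- as.numeric(est)
  truth <- as.numeric(truth)
  stopifnot(length(est) == length(truth))
  c(
    ate_error = mean(est) - mean(truth),  # signed error of the ATE estimate
    rmse = sqrt(mean((est - truth)^2))    # root mean squared error of the ITEs
  )
}

# Toy check with hand-made numbers (not model output)
ite_metrics(est = c(0.1, 0.9, 0.5), truth = c(0, 1, 0.5))
```

For this vignette the call would be `ite_metrics(model$causal_pre, pmax(v[, 1], 0))`, assuming `model$causal_pre` holds one estimate per row of `v` in the original order.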

## Predicting Counterfactual Outcome

Besides ATE and ITE estimation, the package can also predict counterfactual outcomes directly. Here we flip each unit's treatment and predict the outcome under the treatment it did not receive.

```{r cf}
x_cf <- 1-x
y_cf <- get_est(model, v, x_cf)
boxplot(y_cf, main="Counterfactual outcome", ylab="Values")
```

## Estimating Conditional Average Treatment Effect (CATE)

We demonstrate how to estimate the CATE on external data after training the model.

```{r cate}
n_test <- 100
v_test <- matrix(rnorm(n_test * p), n_test, p)
CATE <- get_est(model, v_test)
boxplot(CATE, main = "CATE", ylab = "Values")
```