---
title: "Estimating and Evaluating the Optimal Subgroup"
output:
  rmarkdown::html_vignette:
    fig_caption: true
    toc: true    
    toc_depth: 2
    mathjax: "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS_CHTML.js"
vignette: >
  %\VignetteIndexEntry{optimal_subgroup}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
bibliography: ref.bib  
---

```{r lib, message = FALSE}
library(polle)
library(data.table)
```

This vignette showcases how `policy_learn()` and `policy_eval()` can be
combined to estimate and evaluate the optimal subgroup in the single-stage
case. We refer to [@nordland2023policy] for the syntax and methodological context.

# Setup

From here on we consider the single-stage case with a binary action set
$\{0,1\}$. For a given threshold $\eta > 0$ we can formulate the optimal subgroup
function via the conditional average treatment effect (CATE/blip) as

\begin{align*}
d^\eta_0(v) = I\{B_0(v) > \eta\},
\end{align*}
where $B_0$ is the CATE defined as
\begin{align*}
B_0(v) = E\left[U^{(1)} - U^{(0)} \big | V = v \right].
\end{align*}

The average treatment effect in the optimal subgroup is now defined as
\begin{align*}
E\left[U^{(1)} - U^{(0)} \big | d^\eta_0(V) = 1 \right],
\end{align*}

which, under consistency, positivity, and randomization, is identified as

\begin{align*}
E\left[Z(1,g_0,Q_0)(O) - Z(0,g_0,Q_0)(O) \big | d^\eta_0(V) = 1 \right],
\end{align*}

where $Z(a,g,Q)(O)$ is the doubly robust score for treatment $a$ and

\begin{align*}
d^\eta_0(v) &= I\{B_0(v) > \eta\}\\
B_0(v) &= E\left[Z(1,g_0,Q_0)(O) - Z(0,g_0,Q_0)(O) \big | V = v \right]
\end{align*}
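
For intuition, $Z(a,g,Q)(O)$ is the usual augmented inverse probability weighted (AIPW) score
$Q(a, \cdot) + \frac{I\{A = a\}}{g(a \mid \cdot)}\{U - Q(a, \cdot)\}$. The following illustrative
helper is not part of `polle` (the function name and arguments are our own); it merely sketches how
the two scores and their contrast could be computed from fitted nuisance values:

```{r dr_score_sketch}
## Illustrative only: doubly robust (AIPW) scores for a binary action,
## given observed actions A, utilities U, fitted propensities g1 = g(1 | .),
## and fitted outcome regressions Q1 = Q(1, .) and Q0 = Q(0, .).
dr_scores <- function(A, U, g1, Q1, Q0) {
  Z1 <- Q1 + (A == 1) / g1 * (U - Q1)       # Z(1, g, Q)(O)
  Z0 <- Q0 + (A == 0) / (1 - g1) * (U - Q0) # Z(0, g, Q)(O)
  cbind(Z1 = Z1, Z0 = Z0, blip_score = Z1 - Z0)
}
```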


# Threshold policy learning

In `polle` the threshold policy $d^\eta_0$ can be estimated using `policy_learn()`
via the `threshold` argument, and the average treatment effect in the subgroup
can be estimated using `policy_eval()` by setting `target = "subgroup"`.

Here we consider an example using simulated data:

```{r}
par0 <- list(a = 1, b = 0, c = 3)
sim_d <- function(n, par=par0, potential_outcomes = FALSE) {
  W <- runif(n = n, min = -1, max = 1)
  L <- runif(n = n, min = -1, max = 1)
  A <- rbinom(n = n, size = 1, prob = 0.5)
  U1 <- W + L + (par$c*W + par$a*L + par$b) # U^1
  U0 <- W + L # U^0
  U <- A * U1 + (1 - A) * U0 + rnorm(n = n)
  out <- data.table(W = W, L = L, A = A, U = U)
  if (potential_outcomes == TRUE) {
    out$U0 <- U0
    out$U1 <- U1
  }
  return(out)
}
```

Note that in this simple case $U^{(1)} - U^{(0)} = cW + aL + b$.
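
A quick sanity check of this identity on a few simulated observations (the seed and sample size are arbitrary):

```{r blip_check}
## Check that U^(1) - U^(0) = c*W + a*L + b in the simulation:
set.seed(2)
check <- sim_d(n = 5, potential_outcomes = TRUE)
all.equal(check$U1 - check$U0,
          with(par0, c * check$W + a * check$L + b))
```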

```{r single_stage_data}
set.seed(1)
d <- sim_d(n = 200)
pd <- policy_data(
    d,
    action = "A",
    covariates = list("W", "L"),
    utility = "U"
)
```

We define a policy learner using `policy_learn()` with `type = "blip"`, a correctly
specified blip model, and a threshold of $\eta = 1$:

```{r}
pl1 <- policy_learn(
  type = "blip",
  control = control_blip(blip_models = q_glm(~ W + L)),
  threshold = 1
)
```

We then apply the policy learner using correctly specified nuisance models
and extract the corresponding policy actions, where $d_N(W,L) = 1$
identifies the estimated optimal subgroup for $\eta = 1$:

```{r}
po1 <- pl1(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L))
)
pf1 <- get_policy(po1)
pa <- pf1(pd)
```
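
The returned policy actions can be inspected directly, e.g.:

```{r}
head(pa, 5)
```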

In the following plot, the black line indicates the boundary of the true
optimal subgroup, and the points are colored by the estimated threshold
policy actions:

```{r pa_plot, echo = FALSE}
library("ggplot2")
plot_data <- data.table(d_N = factor(pa$d), W = d$W, L = d$L)
ggplot(plot_data) +
  geom_point(aes(x = W, y = L, color = d_N)) +
  geom_abline(slope = -3, intercept = 1) +
  theme_bw()
```

Similarly, we can use `type = "ptl"` to fit a policy tree with
a given threshold for not choosing the reference action
(the first action in the action set in alphabetical order):

```{r}
get_action_set(pd)[1] # reference action
pl1_ptl <- policy_learn(
    type = "ptl",
    control = control_ptl(policy_vars = c("W", "L")),
    threshold = 1
)
po1_ptl <- pl1_ptl(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L))
)
po1_ptl$ptl_objects
```

```{r pa_plot_ptl, echo = FALSE}
pf1_ptl <- get_policy(po1_ptl)
pa_ptl <- pf1_ptl(pd)
library("ggplot2")
plot_data <- data.table(d_N = factor(pa_ptl$d), W = d$W, L = d$L)
ggplot(plot_data) +
  geom_point(aes(x = W, y = L, color = d_N)) +
  geom_abline(slope = -3, intercept = 1) +
  theme_bw()
```

# Subgroup average treatment effect

The true subgroup average treatment effect is given by:

\begin{align*}
E[cW + aL + b | cW + aL + b \geq \eta ],
\end{align*}

which, for $\eta = 1$ and the parameters in `par0`, we can easily approximate via simulation:

```{r sate, cache = TRUE}
set.seed(1)
approx <- sim_d(n = 1e7, potential_outcomes = TRUE)
(sate <- with(approx, mean((U1 - U0)[(U1 - U0 >= 1)])))
rm(approx)
```

The subgroup average treatment effect associated with the learned optimal threshold policy
can be directly estimated using `policy_eval()` via the `target` argument:

```{r pe}
(pe <- policy_eval(
  policy_data = pd,
  policy_learn = pl1,
  target = "subgroup"
 ))
```
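
The point estimate can be compared with the simulated true subgroup average treatment
effect from above (here using the `coef()` method for `policy_eval` objects):

```{r}
## Compare the estimated subgroup ATE with the simulated truth:
coef(pe)
sate
```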

We can also estimate the subgroup average treatment effect for a set of thresholds at once:

```{r}
pl_set <- policy_learn(
  type = "blip",
  control = control_blip(blip_models = q_glm(~ W + L)),
  threshold = c(0, 1)
)

policy_eval(
  policy_data = pd,
  g_models = g_glm(~ 1),
  q_models = q_glm(~ A * (W + L)),
  policy_learn = pl_set,
  target = "subgroup"
)
```

## Asymptotics

The data adaptive target parameter

\begin{align*}
E\left[U^{(1)} - U^{(0)} \big | d_N(V) = 1 \right] = E\left[Z(1,g,Q)(O) - Z(0,g,Q)(O) \big | d_N(V) = 1 \right]
\end{align*}

is asymptotically normal with influence function

\begin{align*}
\frac{1}{P(d'(\cdot) = 1)} I\{d'(\cdot) = 1\}\left\{Z(1,g,Q)(O) - Z(0,g,Q)(O) - E[Z(1,g,Q)(O) - Z(0,g,Q)(O) | d'(\cdot) = 1]\right\},
\end{align*}

where $d'$ is the limiting policy of $d_N$. The fitted influence curve can be extracted using `IC()`:

```{r}
IC(pe) |> head()
```
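
As a rough check, a Wald-type standard error can be recomputed directly from the fitted
influence curve; the sketch below is our own and should approximately match the standard
error reported by `policy_eval()`:

```{r}
## Standard error based on the empirical variance of the influence curve:
ic <- as.matrix(IC(pe))
apply(ic, 2, sd) / sqrt(nrow(ic))
```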

# References