## ----setup, include=FALSE-----------------------------------------------------
# Global knitr chunk defaults: echo the source of every chunk and evaluate it.
knitr::opts_chunk$set(echo = TRUE, eval = TRUE)

## ----eval = FALSE-------------------------------------------------------------
# remotes::install_github("DoubleML/doubleml-for-r")

## ----message=FALSE, warning=FALSE---------------------------------------------
library(DoubleML)

## -----------------------------------------------------------------------------
library(DoubleML)

# Load the bonus data through DoubleML's fetch_bonus() helper, requesting a
# data.table, and preview its first rows.
df_bonus = DoubleML::fetch_bonus(return_type = "data.table")
head(df_bonus, n = 6L)

# Simulate a data set in which the coefficient on the treatment is `theta`.
# The seed fixes the random draws; the order of the rnorm() calls below
# (covariates, then treatment noise, then outcome noise) determines the
# generated values and must not be rearranged.
set.seed(3141)
n_obs = 500
n_vars = 100
theta = 3
# n_obs x n_vars matrix of standard-normal covariates
X = matrix(rnorm(n_obs * n_vars), n_obs, n_vars)
# treatment: first three covariates, each weighted by 5, plus noise
d = X[, seq_len(3)] %*% rep(5, 3) + rnorm(n_obs)
# outcome: theta times the treatment, the same covariate signal, plus noise
y = theta * d + X[, seq_len(3)] %*% rep(5, 3) + rnorm(n_obs)

## -----------------------------------------------------------------------------
# Data backend for the bonus data: tell DoubleMLData which columns of
# df_bonus are passed as y_col, d_cols and x_cols for the causal model
dml_data_bonus = DoubleMLData$new(
  data = df_bonus,
  y_col = "inuidur1",
  d_cols = "tg",
  x_cols = c(
    "female", "black", "othrace", "dep1", "dep2",
    "q2", "q3", "q4", "q5", "q6",
    "agelt35", "agegt54", "durable", "lusd", "husd"
  )
)
print(dml_data_bonus)

# Data backend for the simulated data, built through the matrix interface
dml_data_sim = double_ml_data_from_matrix(X = X, y = y, d = d)
dml_data_sim

## -----------------------------------------------------------------------------
library(mlr3)
library(mlr3learners)
# suppress messages from mlr3 package during fitting
lgr::get_logger("mlr3")$set_threshold("warn")

# Learners for the bonus data: one ranger learner, cloned so that ml_l and
# ml_m each get their own copy
learner = lrn(
  "regr.ranger",
  num.trees = 500,
  max.depth = 5,
  min.node.size = 2
)
ml_l_bonus = learner$clone()
ml_m_bonus = learner$clone()

# Learners for the simulated data: glmnet with a fixed lambda computed from
# n_vars and n_obs; `learner` is reassigned here and cloned again
learner = lrn(
  "regr.glmnet",
  lambda = sqrt(log(n_vars) / n_obs)
)
ml_l_sim = learner$clone()
ml_m_sim = learner$clone()

## -----------------------------------------------------------------------------
# Fix the seed before fitting; it is set once here and not reset before the
# second model
set.seed(3141)

# PLR model on the bonus data: construct, fit, print the result
obj_dml_plr_bonus = DoubleMLPLR$new(
  data = dml_data_bonus,
  ml_l = ml_l_bonus,
  ml_m = ml_m_bonus
)
obj_dml_plr_bonus$fit()
print(obj_dml_plr_bonus)

# Same model class on the simulated data
obj_dml_plr_sim = DoubleMLPLR$new(
  data = dml_data_sim,
  ml_l = ml_l_sim,
  ml_m = ml_m_sim
)
obj_dml_plr_sim$fit()
print(obj_dml_plr_sim)