---
title: "Using R6causal"
author: "Juha Karvanen"
date: "2024-03-14"
output: rmarkdown::pdf_document
vignette: >
  %\VignetteIndexEntry{using_R6causal}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

# Overview

The R package `R6causal` implements an R6 class called `SCM`. The class aims to simplify working with structural causal models. The missing data mechanism can be defined as a part of the structural model.  

The class contains methods for

* defining a structural causal model via functions, text or conditional probability tables
* printing basic information on the model
* plotting the graph for the model using packages `igraph` or `qgraph`
* simulating data from the model
* applying an intervention
* checking the identifiability of a query using the R packages `causaleffect` and `dosearch`
* defining the missing data mechanism 
* simulating incomplete data from the model according to the specified missing data mechanism 
* checking the identifiability in a missing data problem using the R package `dosearch`  
* checking the identifiability of a counterfactual query using the R package `cfid` 

In addition, there are functions for

* running experiments 
* counterfactual inference using simulation
* evaluating fairness of a prediction model

The class `ParallelWorld` inherits from `SCM` and defines a structural causal model that describes parallel worlds for counterfactual inference.

The class `LinearGaussianSCM` inherits from `SCM` and defines a structural causal model where all functions are linear and all background variables follow Gaussian distributions.

# Setup
```{r setup}
library(R6causal)
library(data.table)
library(stats)
data.table::setDTthreads(2)
```

# Defining the model

A structural causal model (SCM) for a backdoor situation can be defined as follows:

```{r definebackdoor}
backdoor <- SCM$new("backdoor",
  uflist = list(
    uz = function(n) {return(runif(n))},
    ux = function(n) {return(runif(n))},
    uy = function(n) {return(runif(n))}
  ),
  vflist = list(
    z = function(uz) {
      return(as.numeric(uz < 0.4))},
    x = function(ux, z) {
      return(as.numeric(ux < 0.2 + 0.5*z))},
    y = function(uy, z, x) {
      return(as.numeric(uy < 0.1 + 0.4*z + 0.4*x))}
  )
)

```

A shortcut notation for this is

```{r definebackdoor_text}
backdoor_text <- SCM$new("backdoor",
  uflist = list(
    uz = "n : runif(n)", 
    ux = "n : runif(n)", 
    uy = "n : runif(n)" 
  ),
  vflist = list(
    z = "uz : as.numeric(uz < 0.4)", 
    x = "ux, z : as.numeric(ux < 0.2 + 0.5*z)",
    y = "uy, z, x : as.numeric(uy < 0.1 + 0.4*z + 0.4*x)" 
  )
)

```


Alternatively, the functions of the SCM can be specified via conditional probability tables:

```{r definebackdooralt}
backdoor_condprob <- SCM$new("backdoor",
  uflist = list(
    uz = function(n) {return(runif(n))},
    ux = function(n) {return(runif(n))},
    uy = function(n) {return(runif(n))}
  ),
  vflist = list(
    z = function(uz) {
      return( generate_condprob( ycondx = data.table(z = c(0,1), 
                                                     prob = c(0.6,0.4)), 
                               x = data.table(uz = uz), 
                               Umerge_expr = "uz"))},
    x = function(ux, z) {
      return( generate_condprob( ycondx = data.table(x = c(0,1,0,1), 
                                                     z = c(0,0,1,1), 
                                                     prob = c(0.8,0.2,0.3,0.7)),
                                             x = data.table(z = z, ux = ux), 
                                             Umerge_expr = "ux"))},
    y = function(uy, z, x) {
      return( generate_condprob( ycondx = data.table(y= rep(c(0,1), 4), 
                                                     z = c(0,0,1,1,0,0,1,1), 
                                                     x = c(0,0,0,0,1,1,1,1), 
                                                     prob = c(0.9,0.1,0.5,0.5,
                                                              0.5,0.5,0.1,0.9)),
                                             x = data.table(z = z, x = x, uy = uy), 
                                             Umerge_expr = "uy"))}
  )
)
```
It is possible to mix the styles and define some elements of a function list as functions, some as text and
some as conditional probability tables.
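
The probabilities in the tables above follow from the threshold form of the structural equations. As a small check in base R (the object names `cpt_y` and `p_y1` are only illustrative and not part of `R6causal`), the chunk below rebuilds the table for `y` from $P(y = 1 \mid z, x) = 0.1 + 0.4z + 0.4x$ and confirms that the probabilities sum to one for each parent configuration.

```{r cpt_check}
# Illustrative only: rebuild the conditional probability table of y
# from the threshold equation y = 1(uy < 0.1 + 0.4*z + 0.4*x).
# expand.grid varies y fastest, which matches the row order used above.
cpt_y <- expand.grid(y = c(0, 1), z = c(0, 1), x = c(0, 1))
p_y1 <- 0.1 + 0.4 * cpt_y$z + 0.4 * cpt_y$x   # P(y = 1 | z, x)
cpt_y$prob <- ifelse(cpt_y$y == 1, p_y1, 1 - p_y1)
cpt_y
# Within each (z, x) configuration the probabilities should sum to one
cpt_y_sums <- aggregate(prob ~ z + x, data = cpt_y, FUN = sum)
cpt_y_sums
```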

# Defining a linear Gaussian SCM

A linear Gaussian SCM can be defined by giving the coefficients for the structural equations:
```{r lineargaussianbackdoor}
lgbackdoor <- LinearGaussianSCM$new("Linear Gaussian Backdoor",
                                    linear_gaussian = list(
                                      uflist = list(ux = function(n) {rnorm(n)},
                                                    uy = function(n) {rnorm(n)},
                                                    uz = function(n) {rnorm(n)}),
                                      vnames = c("x","y","z"),
                                      vcoefmatrix = matrix(c(0,0.4,0,0,0,0,0.6,0.8,0),3,3),
                                      ucoefvector = c(1,1,1),
                                      ccoefvector = c(0,0,0)))
print(lgbackdoor)
```
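
To see what these coefficients imply, the covariance matrix of the observed variables can be worked out in base R. This is a sketch under an assumption that is not stated above: entry `[i, j]` of the coefficient matrix is read as the coefficient of variable `j` in the equation of variable `i`, which makes the example a backdoor model (`z -> x`, `z -> y`, `x -> y`). The names `B`, `A` and `implied_cov` are illustrative.

```{r lg_implied_cov}
# Assumed convention: row = target variable, column = parent variable.
vn <- c("x", "y", "z")
B <- matrix(c(0, 0.4, 0, 0, 0, 0, 0.6, 0.8, 0), 3, 3, dimnames = list(vn, vn))
# With unit coefficients for the background variables and zero constants,
# v = B v + u with independent standard normal u, so v = (I - B)^{-1} u
# and Cov(v) = (I - B)^{-1} (I - B)^{-T}.
A <- solve(diag(3) - B)
implied_cov <- A %*% t(A)
dimnames(implied_cov) <- list(vn, vn)
round(implied_cov, 4)
```

Under this reading, the sample covariance of a large simulated dataset from `lgbackdoor` should be close to `implied_cov`.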

It is also possible to generate the underlying DAG and the coefficients randomly:
```{r randomlineargaussian}
randomlg <- LinearGaussianSCM$new("Random Linear Gaussian",
                                  random_linear_gaussian = list(
                                  nv = 6, 
                                  edgeprob=0.5, 
                                  vcoefdistr = function(n) {rnorm(n)}, 
                                  ccoefdistr = function(n) {rnorm(n)}, 
                                  ucoefdistr = function(n) {rnorm(n)}))
print(randomlg)
```






\newpage
# Printing the model

The print method presents the basic information on the model:
```{r printbackdoor}
backdoor
```


\newpage
# Plotting the graph

The plotting method of the package `igraph` is used by default. If `qgraph` is available, its plotting method can be used as well. The argument `subset` controls which variables are plotted. Plotting parameters are passed to the plotting method.

```{r plotbackdoor}
backdoor$plot(vertex.size = 25) # with package 'igraph'
backdoor$plot(subset = "v") # only observed variables
if (requireNamespace("qgraph", quietly = TRUE)) backdoor$plot(method = "qgraph") 
# alternative look with package 'qgraph'
```

# Simulating data

Calling the method `simulate()` creates or updates the data table `simdata`.
```{r simulatebackdoor}
backdoor$simulate(10)
backdoor$simdata
backdoor$simulate(8)
backdoor$simdata
backdoor_text$simulate(20)
backdoor_condprob$simulate(30)
```
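
Conceptually, simulating from an SCM means drawing the background variables first and then evaluating the structural equations in topological order. The chunk below is a hand-written base R sketch of that idea for the backdoor model. It is not the implementation used by `simulate()`, and all object names ending in `_m` are illustrative. The simulated means are compared with the exact values $P(z=1) = 0.4$, $P(x=1) = 0.6 \cdot 0.2 + 0.4 \cdot 0.7 = 0.4$ and $P(y=1) = 0.1 + 0.4 \cdot 0.4 + 0.4 \cdot 0.4 = 0.42$.

```{r manual_simulation}
# Hand-written sketch (not the package implementation)
set.seed(1)
n_m <- 100000
# Step 1: draw the background variables
uz_m <- runif(n_m)
ux_m <- runif(n_m)
uy_m <- runif(n_m)
# Step 2: evaluate the structural equations in the order z, x, y
z_m <- as.numeric(uz_m < 0.4)
x_m <- as.numeric(ux_m < 0.2 + 0.5 * z_m)
y_m <- as.numeric(uy_m < 0.1 + 0.4 * z_m + 0.4 * x_m)
# Simulated means next to the exact probabilities
rbind(simulated = c(z = mean(z_m), x = mean(x_m), y = mean(y_m)),
      exact     = c(z = 0.4, x = 0.4, y = 0.42))
```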

# Applying an intervention
In an intervention, the structural equation of the target variable is changed.
```{r interventionbackdoor}
backdoor_x1 <- backdoor$clone()  # making a copy
backdoor_x1$intervene("x",1) # applying the intervention
# to see that arrows incoming to x are cut
if (requireNamespace("qgraph", quietly = TRUE)) backdoor_x1$plot(method = "qgraph")
backdoor_x1$simulate(10) # simulating from the intervened model
backdoor_x1$simdata
```

## An intervention can redefine a structural equation

```{r interventionbackdoor2}
backdoor_yz <- backdoor$clone()  # making a copy
backdoor_yz$intervene("y", 
  function(uy, z) {return(as.numeric(uy < 0.1 + 0.8*z ))}) # making y a function of z only
# to see that arrow x -> y is cut
if (requireNamespace("qgraph", quietly = TRUE)) backdoor_yz$plot(method = "qgraph")
```

# Running an experiment (set of interventions)
The function `run_experiment` applies a set of interventions, simulates data and collects the results.
```{r experimentbackdoor}
backdoor_experiment <- run_experiment(backdoor, 
                                      intervene = list(x = c(0,1)), 
                                      response = "y", 
                                      n = 10000)
str(backdoor_experiment)
colMeans(backdoor_experiment$response_list$y)
```

# Applying the ID algorithm, Do-search and cfid
There are direct plugins to the R packages `causaleffect`, `dosearch` and `cfid` that can be used to solve identifiability problems.
```{r IDbackdoor}
backdoor$causal.effect(y = "y", x = "x")
backdoor$dosearch(data = "p(x,y,z)", query = "p(y|do(x))")
backdoor$cfid(gamma = cfid::conj(cfid::cf("Y",0), cfid::cf("X",0, c(Z=1))) ) 
```
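
For the backdoor graph, the causal effect of `x` on `y` is identified by the backdoor adjustment formula $P(y \mid do(x)) = \sum_z P(y \mid x, z) P(z)$. As an illustration, the chunk below evaluates this formula exactly, using probabilities read off the structural equations of `backdoor`. The helper names (`p_y1_given`, `p_do`, `p_obs_x1`) are illustrative and not part of `R6causal`.

```{r backdoor_adjustment}
# Exact evaluation of P(y = 1 | do(x)) = sum_z P(y = 1 | x, z) P(z)
p_z <- c("0" = 0.6, "1" = 0.4)                         # P(z = 0), P(z = 1)
p_y1_given <- function(x, z) 0.1 + 0.4 * z + 0.4 * x   # P(y = 1 | x, z)
p_do <- sapply(c(x0 = 0, x1 = 1), function(x) {
  sum(p_y1_given(x, c(0, 1)) * p_z)
})
p_do
# Observational P(y = 1 | x = 1) for comparison: z is weighted by P(z | x = 1)
p_x1_given_z <- c(0.2, 0.7)          # P(x = 1 | z = 0), P(x = 1 | z = 1)
p_z_given_x1 <- p_x1_given_z * p_z / sum(p_x1_given_z * p_z)
p_obs_x1 <- sum(p_y1_given(1, c(0, 1)) * p_z_given_x1)
p_obs_x1
```

The interventional probabilities (0.26 and 0.66) differ from the observational conditional probability (0.78) because `z` confounds the relation between `x` and `y`.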

# Counterfactual inference (a simple case)

Let us assume that the intervention do(X = 0) was applied and the response Y = 0 was recorded.
What is the probability that, in this situation, the intervention do(X = 1) would have led to
the response Y = 1? We estimate this probability by means of simulation.

```{r counterfactualbackdoor}
cfdata <- counterfactual(backdoor, situation = list(do = list(target = "x", ifunction = 0), 
                                                    condition = data.table( x = 0, y = 0)), 
                         target = "x", ifunction = 1, n = 100000, 
                         method = "rejection")
mean(cfdata$y)
```
The result differs from $P(Y = 1 \mid do(X = 1))$, which is estimated by simulating from the intervened model:
```{r counterfactualcomparisonbackdoor}
backdoor_x1$simulate(100000)
mean(backdoor_x1$simdata$y)
```
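
Both simulated values can be compared with exact calculations. The chunk below is a base R sketch, with illustrative object names, that uses the fact that the factual and the counterfactual worlds share the same background variables: under do(X = 0), the observation Y = 0 means `uy >= 0.1 + 0.4*z`, and under do(X = 1), Y = 1 means `uy < 0.5 + 0.4*z`.

```{r counterfactual_exact}
# Exact P(Y_{x=1} = 1 | Y_{x=0} = 0) for the backdoor model
p_z_cf <- c(0.6, 0.4)    # P(z = 0), P(z = 1)
z_cf <- c(0, 1)
# P(Y_{x=0} = 0) = sum_z P(z) * P(uy >= 0.1 + 0.4*z)
p_condition <- sum(p_z_cf * (1 - (0.1 + 0.4 * z_cf)))
# P(0.1 + 0.4*z <= uy < 0.5 + 0.4*z): an interval of length 0.4 for both z
p_joint <- sum(p_z_cf * 0.4)
cf_exact <- c(counterfactual = p_joint / p_condition,
              interventional = sum(p_z_cf * (0.5 + 0.4 * z_cf)))
cf_exact
```

The counterfactual probability is $0.4 / 0.74 \approx 0.54$, whereas the interventional probability is 0.66.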

# Counterfactual inference (parallel worlds)

Parallel world graphs (a generalization of a twin graph) are used for counterfactual inference with several counterfactual interventions. The package implements the class `ParallelWorld`, which inherits from the class `SCM`. A `ParallelWorld` object is created from an `SCM` object by specifying the interventions for each world. By default, the variables of the parallel worlds are named with the suffixes "_1", "_2", ...

In the example below, we have the original world (variables `x`, `z`, `y`) and its two variants. In variant 1 (variables `x_1`, `z_1`, `y_1`), the value of `x` (variable `x_1` in the object) is set to 0. In variant 2 (variables `x_2`, `z_2`, `y_2`), the value of `x` (variable `x_2` in the object) is set to 0 and the value of `z` (variable `z_2` in the object) is set to 1.
```{r parallelworldbackdoor}
backdoor_parallel <- ParallelWorld$new(
                         backdoor,
                        dolist=list(
                             list(target = "x", 
                                 ifunction = 0),
                             list(target = list("z","x"), 
                                  ifunction = list(1,0))
                         )
 )
backdoor_parallel
if (requireNamespace("qgraph", quietly = TRUE)) backdoor_parallel$plot(method = "qgraph") 
```

Counterfactual data can be simulated with the function `counterfactual`. In the example below, we know that the variable `y` took the value 0 in the original world as well as in variants 1 and 2. We are interested in the counterfactual distribution of `y` if `x` had been set to 1.
```{r parallelworldcounterfactual}
cfdata <- counterfactual(backdoor_parallel,
                         situation = list(
                            do = NULL,
                            condition = data.table::data.table( y = 0, y_1 = 0, y_2 = 0)),
                         target = "x",
                         ifunction = 1,
                         n = 100000,
                         method = "rejection")
mean(cfdata$y)
```
The printed value is a simulation-based estimate of the counterfactual probability $P(Y_{x=1} = 1 \mid Y = 0, Y_{x=0} = 0, Y_{z=1,x=0} = 0)$.

An alternative way of answering the same question defines the case of interest as one of the parallel worlds (here variant 3).
```{r parallelworldbackdoor2}
backdoor_parallel2 <- ParallelWorld$new(
                         backdoor,
                        dolist=list(
                             list(target = "x", 
                                 ifunction = 0),
                             list(target = list("z","x"), 
                                  ifunction = list(1,0)),
                              list(target = "x", 
                                 ifunction = 1)
                         )
 )
cfdata <- counterfactual(backdoor_parallel2,
                         situation = list(
                            do = NULL,
                            condition = data.table::data.table( y = 0, y_1 = 0, y_2 = 0)),
                         n = 100000, 
                         method = "rejection")
mean(cfdata$y_3)
```
The printed value is a simulation-based estimate of the same counterfactual probability $P(Y_{x=1} = 1 \mid Y = 0, Y_{x=0} = 0, Y_{z=1,x=0} = 0)$.
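
The idea behind rejection sampling in parallel worlds can be sketched in base R. The chunk below is a hand-written illustration and not the implementation used by `counterfactual`; all object names are illustrative. The same background variables are used in every world, and only the draws that agree with the observed condition are kept.

```{r parallel_rejection_sketch}
# Hand-written rejection sampling sketch (not the package implementation)
set.seed(2)
n_pw <- 200000
# Background variables shared by all worlds
uz_pw <- runif(n_pw)
ux_pw <- runif(n_pw)
uy_pw <- runif(n_pw)
f_y <- function(uy, z, x) as.numeric(uy < 0.1 + 0.4 * z + 0.4 * x)
z_pw <- as.numeric(uz_pw < 0.4)
x_pw <- as.numeric(ux_pw < 0.2 + 0.5 * z_pw)
y_orig <- f_y(uy_pw, z_pw, x_pw)   # original world
y_v1 <- f_y(uy_pw, z_pw, 0)        # variant 1: do(x = 0)
y_v2 <- f_y(uy_pw, 1, 0)           # variant 2: do(z = 1, x = 0)
y_cf <- f_y(uy_pw, z_pw, 1)        # world of interest: do(x = 1)
# Keep only the draws where y is 0 in the original world and in both variants
keep <- y_orig == 0 & y_v1 == 0 & y_v2 == 0
pw_estimate <- mean(y_cf[keep])
pw_estimate
```

A hand calculation for this model gives $0.048 / 0.388 \approx 0.124$, so the sketch and the estimates above should all be close to this value.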

# A model with a missing data mechanism 
The missing data mechanism is defined in a similar manner as the other variables.
```{r definebackdoor_md}
backdoor_md <- SCM$new("backdoor_md",
                       uflist = list(
                         uz = "n : runif(n)",
                         ux = "n : runif(n)",
                         uy = "n : runif(n)",
                         urz = "n : runif(n)",
                         urx = "n : runif(n)",
                         ury = "n : runif(n)"
                       ),
                       vflist = list(
                         z = "uz : as.numeric(uz < 0.4)",
                         x = "ux, z : as.numeric(ux < 0.2 + 0.5*z)",
                         y = "uy, z, x : as.numeric(uy < 0.1 + 0.4*z + 0.4*x)"
                       ),
                       rflist = list(
                         z = "urz : as.numeric( urz < 0.9)",
                         x = "urx, z : as.numeric( (urx + z)/2 < 0.9)",
                         y = "ury, z : as.numeric( (ury + z)/2 < 0.9)"
                       ),
                       rprefix = "r_"
)
```
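
The response probabilities implied by `rflist` can be worked out by hand. The base R sketch below assumes that the value 1 of a response indicator means that the variable is observed; the object names are illustrative. Because `urx` is uniform on (0, 1), the condition `(urx + z)/2 < 0.9` is equivalent to `urx < 1.8 - z`, which always holds when `z = 0` and holds with probability 0.8 when `z = 1`.

```{r missingness_probs}
# Exact probabilities that each response indicator equals 1
# (assumption: indicator value 1 means "observed")
p_z1 <- 0.4                                     # P(z = 1)
p_rz <- 0.9                                     # P(urz < 0.9)
# P(r_x = 1 | z = 0) and P(r_x = 1 | z = 1), capped at 1
p_rx_given_z <- pmin(1, 1.8 - c(0, 1))
# Marginal probability; the equation for r_y has the same form
p_rx <- sum(c(1 - p_z1, p_z1) * p_rx_given_z)
resp_prob <- c(r_z = p_rz, r_x = p_rx, r_y = p_rx)
resp_prob
```

Under this reading, about 10% of the values of `z` and about 8% of the values of `x` and `y` are expected to be missing in the simulated incomplete data.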

# Plotting the graph for a model with a missing data mechanism

```{r plotbackdoor_md}
backdoor_md$plot(vertex.size = 25, edge.arrow.size=0.5) # with package 'igraph'
backdoor_md$plot(subset = "v") # only observed variables
if (requireNamespace("qgraph", quietly = TRUE)) backdoor_md$plot(method = "qgraph") 
# alternative look with package 'qgraph'
```

# Simulating incomplete data

By default, both complete and incomplete data are simulated. The incomplete dataset is stored in `simdata_obs`.
```{r backdoor_md_simulate}
backdoor_md$simulate(100)
summary(backdoor_md$simdata)
summary(backdoor_md$simdata_obs)
```

By using the argument `fixedvars` one can keep the complete data unchanged and re-simulate the missing data mechanism. 
```{r backdoor_md_simulate2}
backdoor_md$simulate(100, fixedvars = c("x","y","z","ux","uy","uz"))
summary(backdoor_md$simdata)
summary(backdoor_md$simdata_obs)
```

## Applying Do-search to a missing data problem

```{r backdoor_md_dosearch}
backdoor_md$dosearch(data = "p(x*,y*,z*,r_x,r_y,r_z)", query = "p(y|do(x))")
```
The problem is automatically recognized as a missing data problem when `rflist` is not `NULL`.