---
title: "Distributed Stratified Cox Regression using Homomorphic Computation"
author: "Balasubramanian Narasimhan"
date: '`r Sys.Date()`'
bibliography: homomorphing.bib
output:
  html_document:
    theme: cerulean
    toc: yes
    toc_depth: 2
vignette: >
  %\VignetteIndexEntry{Distributed Stratified Cox Regression using Homomorphic Computation}
  %\VignetteEngine{knitr::rmarkdown}
  \usepackage[utf8]{inputenc}
---

```{r echo=FALSE}
### get knitr just the way we like it

knitr::opts_chunk$set(
  message = FALSE,
  warning = FALSE,
  error = FALSE,
  tidy = FALSE,
  cache = FALSE
)
```

## Introduction

It is only a short way from the toy MLE example to a more useful
example using Cox regression.

But first, we need the `survival` package and the `homomorpheR` package.

```{r}
if (!require("survival")) {
    stop("this vignette requires the survival package")
}
library(homomorpheR)
```

We generate some simulated data for the purpose of this example. We
will have three sites with patient data (of sizes 1000, 500 and
1500 respectively), each containing

- `sex` (0, 1) for male/female
- `age` between 40 and 70
- a biomarker `bm`
- a `time` to some event of interest
- an indicator `event` which is 1 if an event was observed and 0
otherwise.

It is common to fit stratified models using sites as strata since the
patient characteristics usually differ from site to site. So the
baseline hazards (`lambdaT`) are different for each site but they
share common coefficients (`beta.1`, `beta.2` and `beta.3` for `age`,
`sex` and `bm` respectively) for the model. See [@survival-book] by Therneau
and Grambsch for details. So our model for each site $i$ is

$$
S(t, age, sex, bm) =
[S_0^i(t)]^{\exp(\beta_1 age + \beta_2 sex + \beta_3 bm)}
$$
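
Equivalently, and only as a restatement of the same proportional
hazards model, the hazard at site $i$ can be written as below. The
symbol $h_0^i$ for the site-specific baseline hazard is notation
introduced here for illustration.

$$
h_i(t \mid age, sex, bm) =
h_0^i(t) \exp(\beta_1 age + \beta_2 sex + \beta_3 bm),
\qquad
S_0^i(t) = \exp\left(-\int_0^t h_0^i(u)\, du\right)
$$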


```{r}
sampleSize <- c(n1 = 1000, n2 = 500, n3 = 1500)

set.seed(12345)

beta.1 <- -.015; beta.2 <- .2; beta.3 <- .001;

lambdaT <- c(5, 4, 3)
lambdaC <- 2

coxData <- lapply(seq_along(sampleSize),
                  function(i) {
                      sex <- sample(c(0, 1), size = sampleSize[i], replace = TRUE)
                      age <- sample(40:70, size = sampleSize[i], replace = TRUE)
                      bm <- rnorm(sampleSize[i])
                      trueTime <- rweibull(sampleSize[i],
                                           shape = 1,
                                           scale = lambdaT[i] * exp(beta.1 * age + beta.2 * sex + beta.3 * bm ))
                      censoringTime <- rweibull(sampleSize[i],
                                                shape = 1,
                                                scale = lambdaC)
                      time <- pmin(trueTime, censoringTime)
                      event <- (time == trueTime)
                      data.frame(stratum = i,
                                 sex = sex,
                                 age = age,
                                 bm = bm,
                                 time = time,
                                 event = event)
                  })
```

So here is a summary of the data for the three sites.

###  Site 1
```{r}
str(coxData[[1]])
```

###  Site 2
```{r}
str(coxData[[2]])
```

###  Site 3
```{r}
str(coxData[[3]])
```

## Aggregated fit

If the data were all aggregated in one place, it would be very simple to
fit the model. Below, we row-bind the data from the three sites.

```{r}
aggModel <- coxph(formula = Surv(time, event) ~ sex +
                                age + bm + strata(stratum),
                            data = do.call(rbind, coxData))
aggModel
```

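Here `age` and `sex` are significant, but `bm` is not. The estimates
$\hat{\beta}$, in the order `sex`, `age`, `bm`, are `(-0.180, .020, .007)`.
The signs of the first two are opposite to those of `beta.2` and
`beta.1` because the simulation applies the coefficients to the Weibull
scale, so a larger linear predictor means longer survival, that is, a
lower hazard.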

We can also print out the value of the (partial) log-likelihood at the
MLE.

```{r}
aggModel$loglik
```

The first is the value at the parameter value `(0, 0, 0)` and the last
is the value at the MLE.


## Distributed Computation

Assume now that the data `coxData` is distributed among three sites,
none of which wants to share its actual data with the others or even
with a master computation process. They wish to keep their data secret
but are willing, together, to provide the sum of their local negative
log-likelihoods. They need to do this in such a way that the master
process cannot attribute any part of the sum to an individual site.

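Because each site is its own stratum, the overall (partial)
log-likelihood function $l(\lambda)$ for the entire data, where
$\lambda$ denotes the parameter vector (the coefficients $\beta$ of our
Cox model), is the sum of the log-likelihoods at each site: $l(\lambda) =
l_1(\lambda)+l_2(\lambda)+l_3(\lambda).$ How can this sum be
computed while maintaining privacy?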
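
The additivity over sites can be checked numerically. The sketch below
is not part of the original analysis: it uses hypothetical toy data
(`toy`, two sites, one covariate `x`) and evaluates the partial
log-likelihood at a fixed, arbitrary `beta` by asking `coxph` to take
zero iterations.

```r
library(survival)

set.seed(2024)
## Hypothetical toy data: two sites with different baseline hazards
n <- c(150, 100)
toy <- do.call(rbind, lapply(seq_along(n), function(i) {
    x <- rnorm(n[i])
    data.frame(stratum = i,
               x = x,
               time = rexp(n[i], rate = i * exp(0.5 * x)),
               event = rbinom(n[i], 1, 0.8))
}))

## Evaluate the partial log-likelihood at a fixed beta without iterating
noIter <- coxph.control(iter.max = 0)
beta <- 0.3

## Stratified fit on the pooled data
pooled <- coxph(Surv(time, event) ~ x + strata(stratum), data = toy,
                init = beta, control = noIter)$loglik[1]

## Unstratified fit on each site's data separately
perSite <- sapply(split(toy, toy$stratum), function(d) {
    coxph(Surv(time, event) ~ x, data = d,
          init = beta, control = noIter)$loglik[1]
})

## TRUE if the pooled value equals the sum of the per-site values
all.equal(pooled, sum(perSite))
```

This is why each site can compute its contribution with an unstratified
model on its own data.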

Assuming that every site including the master has access to a
homomorphic computation library such as `homomorpheR`, the likelihood
can be computed in a privacy-preserving manner using the following
scheme. We use $E(x)$ and $D(x)$ to denote the encrypted and decrypted
values of $x$ respectively.

0. Master generates a public/private key pair. Master distributes the
   public key to all sites. (The private key is not distributed and
   kept only by the master.)
1. Master generates a random offset $r$ to obfuscate the initial
   likelihood.
2. Master sends $E(r)$ and a guess $\lambda_0$ to site 1. Note that
   $\lambda$ is not encrypted.
3. Site 1 computes $l_1 = l(\lambda_0, y_1)$, the local likelihood for
   local data $y_1$ using parameter $\lambda_0$. It then sends on
   $\lambda_0$ and $E(r) + E(l_1)$ to site 2.
4. Site 2 computes $l_2 = l(\lambda_0, y_2)$, the local likelihood for
   local data $y_2$ using parameter $\lambda_0$. It then sends on
   $\lambda_0$ and $E(r) + E(l_1) + E(l_2)$ to site 3.
5. Site 3 computes $l_3 = l(\lambda_0, y_3)$, the local likelihood for
   local data $y_3$ using parameter $\lambda_0$. It then sends on
   $E(r) + E(l_1) + E(l_2) + E(l_3)$ back to master.
6. Master retrieves $E(r) + E(l_1) + E(l_2) + E(l_3)$ which, due to
   the homomorphism, is exactly $E(r+l_1+l_2+l_3) = E(r+l).$ So the
   master computes $D(E(r+l)) - r$ to obtain the value of the overall
   likelihood at $\lambda_0$.
7. Master updates $\lambda_0$ with a new guess $\lambda_1$ and repeats
   steps 1-6. This process is iterated to convergence. For added
   security, even steps 0-6 can be repeated, at additional
   computational cost.

This is pictorially shown below.

![](assets/round_robin.png)
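
The round-robin steps above can be sketched in a few lines. This is a
minimal illustration, assuming the `homomorpheR` calls used elsewhere in
this vignette (`PaillierKeyPair`, `encrypt`, `add`, `decrypt`,
`random.bigz`) and hypothetical integer-valued site contributions
`siteValues`; real log-likelihoods are not integers, which the
implementation section addresses.

```r
library(homomorpheR)

## Step 0: master creates the key pair; only the public key is shared
toyKeys <- PaillierKeyPair$new(1024)
toyPub  <- toyKeys$pubkey
toyPriv <- toyKeys$getPrivateKey()

## Hypothetical integer contributions l_1, l_2, l_3 from the three sites
siteValues <- c(12, 7, 23)

## Steps 1-2: master draws a random offset r and sends out E(r)
r <- random.bigz(nBits = 256)
running <- toyPub$encrypt(r)

## Steps 3-5: each site adds its encrypted contribution and forwards
for (l in siteValues) {
    running <- toyPub$add(running, toyPub$encrypt(gmp::as.bigz(l)))
}

## Step 6: master decrypts and removes the offset
total <- as.double(toyPriv$decrypt(running) - r)
total
```

No site ever sees another site's value in the clear, and the master
only sees the offset-free total.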

## Implementation

The scheme above assumes that encryption and decryption operate on
real numbers, which is not actually the case. Instead, we use rational
approximations with a large denominator, $2^{256}$, say. In the
future, of course, a proper library needs to be built with rigorous
algorithms guaranteeing precision and overflow/underflow detection.
For now, this is just an ad hoc implementation.

Also, since we only use the additive homomorphic property, a
partially homomorphic scheme such as the Paillier cryptosystem will
be sufficient for our computations.
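
For reference, the additive property in the textbook form of the
Paillier scheme is shown below, with public key $(n, g)$ and random
values $r_1, r_2$ coprime to $n$. This is the standard description and
not a statement about how `homomorpheR` parameterizes the scheme
internally. The notation $E(a) + E(b)$ used earlier is shorthand for
the ciphertext product.

$$
E(m) = g^{m} r^{n} \bmod n^2,
\qquad
E(m_1)\, E(m_2) \bmod n^2 = g^{m_1 + m_2} (r_1 r_2)^{n} \bmod n^2
$$

so that $D(E(m_1) E(m_2) \bmod n^2) = m_1 + m_2 \bmod n$.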

We define a class to encapsulate our sites that will compute the
negative partial log-likelihood on site data given a parameter
$\beta$. Note how the `addNLLAndForward` method takes care to split
the result into an integer and a fractional part while performing the
arithmetic operations. (The latter is approximated by a rational
number.)
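
The split into integer and fractional parts can be tried without any
encryption. The sketch below is illustrative only: `splitValue`,
`joinValue` and `toyDen` are hypothetical names, and the part-by-part
sums stand in for the homomorphic additions performed on ciphertexts.

```r
library(gmp)

toyDen <- as.bigq(2)^256   # common denominator, as in the text

## Split a real number into an integer part and the numerator of its
## fractional part over toyDen
splitValue <- function(x) {
    int <- floor(x)
    frac <- x - int
    list(int = as.bigz(int),
         fracnum = numerator(as.bigq(frac) * toyDen))
}

## Recombine the two parts into a double
joinValue <- function(parts) {
    as.double(parts$int) + as.double(as.bigq(parts$fracnum) / toyDen)
}

a <- splitValue(1234.5678)
b <- splitValue(87.25)

## Add part by part, as the sites do on the encrypted values
total <- list(int = a$int + b$int, fracnum = a$fracnum + b$fracnum)
joinValue(total)
```

Both parts are plain big integers, which is what an additive scheme
over the integers can handle.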

In the code below, we exploit, for expository purposes, a feature of
`coxph`: with `iter.max = 0` in its control parameter and `init` set
to a given $\beta$, it evaluates the partial log-likelihood at that
$\beta$ without iterating.
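
A minimal sketch of this evaluation trick, on hypothetical toy data
(`toyData`, one covariate `x`) rather than the vignette's data, and
with `plAt` as an illustrative helper name:

```r
library(survival)

set.seed(99)
## Hypothetical toy data with a single covariate
toyN <- 200
toyX <- rnorm(toyN)
toyData <- data.frame(x = toyX,
                      time = rexp(toyN, rate = exp(0.4 * toyX)),
                      event = rbinom(toyN, 1, 0.8))

## Partial log-likelihood at a fixed beta: start at beta, take no steps
plAt <- function(beta) {
    coxph(Surv(time, event) ~ x, data = toyData,
          init = beta, control = coxph.control(iter.max = 0))$loglik[1]
}

## An ordinary fit for comparison
usual <- coxph(Surv(time, event) ~ x, data = toyData)

## Value at beta = 0 next to the null log-likelihood of the usual fit
c(atZero = plAt(0), null = usual$loglik[1])
## Value at the fitted coefficient next to the maximized log-likelihood
c(atMLE = plAt(coef(usual)), fitted = usual$loglik[2])
```

The two pairs should agree, which is what lets an external optimizer
drive `coxph` as a pure likelihood evaluator.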

```{r}
Site <- R6::R6Class("Site",
                    private = list(
                        ## name of the site
                        name = NA,
                        ## only master has this, NA for workers
                        privkey = NA,
                        ## local data
                        data = NA,
                        ## The next site in the communication: NA for master
                        nextSite = NA,
                        ## is this the master site?
                        iAmMaster = FALSE,
                        ## intermediate result variable
                        intermediateResult = NA,
                        ## Control variable for cox regression
                        cph.control = NA
                    ),
                    public = list(
                        count = NA,
                        ## Common denominator for approximate real arithmetic
                        den = NA,
                        ## The public key; everyone has this
                        pubkey = NA,
                        initialize = function(name, data, den) {
                            private$name <- name
                            private$data <- data
                            self$den <- den
                            private$cph.control <- replace(coxph.control(), "iter.max", 0)
                        },
                        setPublicKey = function(pubkey) {
                            self$pubkey <- pubkey
                        },
                        setPrivateKey = function(privkey) {
                            private$privkey <- privkey
                        },
                        ## Make me master
                        makeMeMaster = function() {
                            private$iAmMaster <- TRUE
                        },
                        ## add neg log lik and forward to next site
                        addNLLAndForward = function(beta, enc.offset) {
                            if (private$iAmMaster) {
                                ## We are master, so don't forward
                                ## Just store intermediate result and return
                                private$intermediateResult <- enc.offset
                            } else {
                                ## We are workers, so add and forward
                                ## add negative log likelihood and forward result to next site
                                ## Note that offset is encrypted
                                nllValue <- self$nLL(beta)
                                result.int <- floor(nllValue)
                                result.frac <- nllValue - result.int
                                result.fracnum <- gmp::as.bigq(gmp::numerator(gmp::as.bigq(result.frac) * self$den))
                                pubkey <- self$pubkey
                                enc.result.int <- pubkey$encrypt(result.int)
                                enc.result.fracnum <- pubkey$encrypt(result.fracnum)
                                result <- list(int = pubkey$add(enc.result.int, enc.offset$int),
                                               frac = pubkey$add(enc.result.fracnum, enc.offset$frac))
                                private$nextSite$addNLLAndForward(beta, enc.offset = result)
                            }
                            ## Return a TRUE result for now.
                            TRUE
                        },
                        ## Set the next site in the communication graph
                        setNextSite = function(nextSite) {
                            private$nextSite <- nextSite
                        },
                        ## The negative log likelihood
                        nLL = function(beta) {
                            if (private$iAmMaster) {
                                ## We're master, so need to get result from sites
                                ## 1. Generate a random offset and encrypt it
                                pubkey <- self$pubkey
                                offset <- list(int = random.bigz(nBits = 256),
                                               frac = random.bigz(nBits = 256))
                                enc.offset <- list(int = pubkey$encrypt(offset$int),
                                                   frac = pubkey$encrypt(offset$frac))
                                ## 2. Send off to next site
                                throwaway <- private$nextSite$addNLLAndForward(beta, enc.offset)
                                ## 3. When the call returns, the result will be in
                                ##    the field intermediateResult, so decrypt that.
                                sum <- private$intermediateResult
                                privkey <- private$privkey
                                intResult <- as.double(privkey$decrypt(sum$int) - offset$int)
                                fracResult <- as.double(gmp::as.bigq(privkey$decrypt(sum$frac) - offset$frac) / self$den)
                                intResult + fracResult
                            } else {
                                ## We're worker, so compute local negative log likelihood
                                tryCatch({
                                    m <- coxph(formula = Surv(time, event) ~ sex + age + bm,
                                                         data = private$data,
                                                         init = beta,
                                                         control = private$cph.control)
                                    -(m$loglik[1])
                                },
                                error = function(e) NA)
                            }
                        })
                    )
```

We are now ready to use our sites in the computation.

### 1. Generate public and private key pair

We also choose a denominator for all our rational approximations.

```{r}
keys <- PaillierKeyPair$new(1024) ## Generate new public and private key.
den <- gmp::as.bigq(2)^256  #Our denominator for rational approximations
```

### 2. Create sites

```{r}
site1 <- Site$new(name = "Site 1", data = coxData[[1]], den = den)
site2 <- Site$new(name = "Site 2", data = coxData[[2]], den = den)
site3 <- Site$new(name = "Site 3", data = coxData[[3]], den = den)
```
The master process is also a site but has no data, so it has to be
designated as such.

```{r}
## Master has no data!
master <- Site$new(name = "Master", data = c(), den = den)
master$makeMeMaster()
```

### 3. Distribute public key to sites

```{r}
site1$setPublicKey(keys$pubkey)
site2$setPublicKey(keys$pubkey)
site3$setPublicKey(keys$pubkey)
master$setPublicKey(keys$pubkey)
```

Only the master has the private key for decryption.

```{r}
master$setPrivateKey(keys$getPrivateKey())
```


### 4. Define the communication graph

The master always sends to the first site; each site then forwards its
result to the next, and the last site returns the result to the master.

```{r}
master$setNextSite(site1)
site1$setNextSite(site2)
site2$setNextSite(site3)
site3$setNextSite(master)
```

### 5. Perform the likelihood estimation

```{r}
library(stats4)
## Argument order must match the coxph formula in the Site class: sex, age, bm
nll <- function(sex, age, bm) master$nLL(c(sex, age, bm))
fit <- mle(nll, start = list(sex = 0, age = 0, bm = 0))
```

### 6. Compare the results

The summary will show the results.

```{r}
summary(fit)
```

Note how the estimated coefficients and standard errors closely match
the full model summary below.

```{r}
summary(aggModel)
```

And the log likelihood of the distributed homomorphic fit also
matches that of the model on aggregated data:

```{r}
cat(sprintf("logLik(MLE fit): %f, logLik(Agg. fit): %f.\n", logLik(fit), aggModel$loglik[2]))
```
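
The estimation step can also be tried without encryption. The sketch
below is illustrative only: it uses hypothetical toy data (`toyDf`,
covariates `x1` and `x2`), and `toyNll`, `b1` and `b2` are made-up
names. It shows `stats4::mle` minimizing a negative partial
log-likelihood that is evaluated by `coxph` with zero iterations, then
compares the result to a direct `coxph` fit.

```r
library(survival)
library(stats4)

set.seed(7)
## Hypothetical toy data with two standardized covariates
toyN <- 300
toyDf <- data.frame(x1 = rnorm(toyN), x2 = rnorm(toyN))
toyDf$time <- rexp(toyN, rate = exp(0.5 * toyDf$x1 - 0.3 * toyDf$x2))
toyDf$event <- rbinom(toyN, 1, 0.8)

## Negative partial log-likelihood, evaluated by coxph at a fixed beta.
## The order (b1, b2) matches the order of terms in the formula.
toyNll <- function(b1, b2) {
    m <- coxph(Surv(time, event) ~ x1 + x2, data = toyDf,
               init = c(b1, b2), control = coxph.control(iter.max = 0))
    -m$loglik[1]
}

## Fit once through the generic optimizer and once directly
viaMle <- mle(toyNll, start = list(b1 = 0, b2 = 0))
direct <- coxph(Surv(time, event) ~ x1 + x2, data = toyDf)

rbind(mle = unname(coef(viaMle)), coxph = unname(coef(direct)))
```

The two rows should agree to within the optimizer's tolerance. Keeping
the argument order aligned with the formula is what makes the
coefficient labels meaningful.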

## References