## ----echo=FALSE, message = FALSE----------------------------------------------
### get knitr just the way we like it

knitr::opts_chunk$set(
  message = FALSE,
  warning = FALSE,
  error = FALSE,
  tidy = FALSE,
  cache = FALSE
  )
library(homomorpheR)
## Check every hard dependency up front so the vignette fails with a readable
## message (and before anything is attached) when a package is missing.
## require() would only return FALSE, so requireNamespace() + library() is used.
if (!requireNamespace("magrittr", quietly = TRUE)) {
  stop("This vignette requires the magrittr package!", call. = FALSE)
}
if (!requireNamespace("dplyr", quietly = TRUE)) {
  stop("This vignette requires the dplyr package!", call. = FALSE)
}
## Attach both so %>% and filter() can be used unqualified below.
library(magrittr, quietly = TRUE)
library(dplyr, quietly = TRUE)
                                                       

## -----------------------------------------------------------------------------
set.seed(130)
sample_size  <- c(60, 15, 25)
## Simulate one data frame per site. Patient ids are numbered consecutively
## across sites (site 1 holds P0001..P0060, site 2 P0061..P0075, ...), so no id
## appears at more than one site.
query_data  <- local({
    ## Cumulative row counts give the first and last id held by each site;
    ## head(bounds, -1) works for any number of sites, not just three.
    bounds <- c(0, cumsum(sample_size))
    start <- head(bounds, -1) + 1
    end <- bounds[-1]
    id_list <- Map(seq, from = start, to = end)
    lapply(seq_along(sample_size),
           function(i) {
               n <- sample_size[i]
               ## Zero-padded so ids contain no embedded spaces
               id <- sprintf("P%04d", id_list[[i]])
               ## Draw order (sex, then age, then bm) is kept fixed so that
               ## set.seed(130) reproduces the same data.
               sex <- sample(c("F", "M"), size = n, replace = TRUE)
               age <- sample(40:70, size = n, replace = TRUE)
               bm <- rnorm(n)
               data.frame(id = id, sex = sex, age = age, bm = bm,
                          stringsAsFactors = FALSE)
           })
})


## -----------------------------------------------------------------------------
str(query_data[[1]])

## -----------------------------------------------------------------------------
str(query_data[[2]])

## -----------------------------------------------------------------------------
str(query_data[[3]])

## -----------------------------------------------------------------------------
## Reference answer: the count computed directly on the pooled (unprotected)
## data. dplyr::filter() combines comma-separated conditions with `&`.
nrow(
    filter(
        do.call(rbind, query_data),
        age < 50,
        sex == 'F',
        bm < 0.2
    )
)

## -----------------------------------------------------------------------------
Site <-
    R6::R6Class(
            "Site",
            private = list(
                ## name of the site
                name = NA,
                ## local data frame; only an encrypted count ever leaves the site
                data = NA,
                ## Cached list(int1, int2) of encrypted shares. Both
                ## non-cooperating parties must receive shares masked by the
                ## SAME random offset, so the pair is computed once and reused
                ## until the public key or the filter condition changes.
                result_cache = NULL,
                ## filter condition: a string of R code evaluated against the
                ## local data
                filterCondition = NA,
                ## Compute (or fetch from the cache) the two encrypted shares
                ## of the local count: int1 = E(count - offset) and
                ## int2 = E(count + offset). Their homomorphic sum is
                ## E(2 * count), while a single share does not reveal the count.
                local_query_count = function() {
                    result <- private$result_cache
                    if (is.null(result)) {
                        pubkey <- self$pubkey
                        if (identical(pubkey, NA)) {
                            stop("Site '", private$name, "': no public key has been set.",
                                 call. = FALSE)
                        }
                        if (!is.character(private$filterCondition) ||
                            length(private$filterCondition) != 1L) {
                            stop("Site '", private$name, "': no filter condition has been set.",
                                 call. = FALSE)
                        }
                        ## Random 256-bit mask, shared by both shares
                        offset <- gmp::random.bigz(nBits = 256)
                        ## Turn the condition string into an expression directly
                        ## (no eval(parse())). NOTE(review): it is still arbitrary
                        ## R code from the master and is evaluated on the data.
                        filter_expr <- rlang::parse_expr(private$filterCondition)
                        count <- nrow(dplyr::filter(private$data, !!filter_expr))
                        result <- list(
                            int1 = pubkey$encrypt(count - offset),
                            int2 = pubkey$encrypt(count + offset)
                        )
                        private$result_cache <- result
                    }
                    result
                }
            ),
            public = list(
                ## Common denominator for approximate real arithmetic
                den = NA,
                ## The master's public key; everyone has this
                pubkey = NA,
                initialize = function(name, data) {
                    private$name <- name
                    private$data <- data
                },
                setPublicKey = function(pubkey) {
                    ## Ciphertexts cached under a different key are useless
                    if (!identical(self$pubkey, pubkey)) {
                        private$result_cache <- NULL
                    }
                    self$pubkey <- pubkey
                },
                setDenominator = function(den) {
                    self$den <- den
                },
                setFilterCondition = function(filterCondition) {
                    ## A different query invalidates the cached shares.
                    ## Re-sending the same condition keeps them, so both
                    ## parties still get shares built from one offset.
                    if (!identical(private$filterCondition, filterCondition)) {
                        private$result_cache <- NULL
                    }
                    private$filterCondition <- filterCondition
                },
                ## Return this site's encrypted share for non-cooperating
                ## party `party`: int1 for party 1, int2 for any other value.
                query_count = function(party) {
                    result <- private$local_query_count()
                    if (party == 1) result$int1 else result$int2
                }
            )
        )


## -----------------------------------------------------------------------------

NCParty <-
    R6::R6Class(
            "NCParty",
            private = list(
                ## name of the party
                name = NA,
                ## NC party number; passed to Site$query_count() to select
                ## which of the two shares (1 -> int1, otherwise int2) it receives
                number = NA,
                ## filter condition
                filterCondition = NA,
                ## The master (not used in this vignette)
                master = NA,
                ## The sites registered through addSite()
                sites = list(),
                ## Call site[[method]](value) on every registered site
                for_each_site = function(method, value) {
                    for (site in private$sites) {
                        site[[method]](value)
                    }
                    invisible(NULL)
                }
            ),
            public = list(
                ## The master's public key; everyone has this
                pubkey = NA,
                ## The denominator for rational arithmetic
                den = NA,
                initialize = function(name, number) {
                    private$name <- name
                    private$number <- number
                },
                ## Store the key and forward it to this party's own sites
                ## (previously this iterated over a global `sites` variable).
                setPublicKey = function(pubkey) {
                    self$pubkey <- pubkey
                    private$for_each_site("setPublicKey", pubkey)
                },
                setDenominator = function(den) {
                    self$den <- den
                    private$for_each_site("setDenominator", den)
                },
                setFilterCondition = function(filterCondition) {
                    private$filterCondition <- filterCondition
                    private$for_each_site("setFilterCondition", filterCondition)
                },
                ## Register a site. A site added after the master has already
                ## configured this party immediately receives the current
                ## settings, so it does not miss the earlier propagation.
                addSite = function(site) {
                    private$sites <- c(private$sites, list(site))
                    if (!identical(self$pubkey, NA)) {
                        site$setPublicKey(self$pubkey)
                    }
                    if (!identical(self$den, NA)) {
                        site$setDenominator(self$den)
                    }
                    if (!identical(private$filterCondition, NA)) {
                        site$setFilterCondition(private$filterCondition)
                    }
                    invisible(private$sites)
                },
                ## Homomorphically add this party's encrypted shares from all
                ## registered sites. Starting from E(0) means a party with no
                ## sites contributes an encryption of zero.
                query_count = function() {
                    pubkey <- self$pubkey
                    shares <- lapply(private$sites,
                                     function(site) site$query_count(private$number))
                    Reduce(pubkey$add, shares, pubkey$encrypt(0))
                }
            )
        )


## -----------------------------------------------------------------------------
Master  <-
    R6::R6Class(
            "Master",
            private = list(
                ## name of the master
                name = NA,
                ## Paillier key pair; the private key is only used in
                ## query_count() to decrypt the final sum
                keys = NA,
                ## Non cooperating party 1
                nc_party_1 = NA,
                ## Non cooperating party 2
                nc_party_2 = NA,
                ## filter condition (a string of R code) sent to every site
                filterCondition = NA,
                ## Hand a non-cooperating party the public key, the denominator
                ## and the filter condition; the party forwards them to its sites.
                configure_party = function(party) {
                    party$setPublicKey(private$keys$pubkey)
                    party$setDenominator(self$den)
                    party$setFilterCondition(private$filterCondition)
                }
            ),
            public = list(
                ## Denominator for rational arithmetic
                den  = NA,
                initialize = function(name, filterCondition) {
                    private$name <- name
                    ## Generate a new public/private key pair (1024-bit modulus)
                    private$keys <- PaillierKeyPair$new(1024)
                    ## Our denominator for rational approximations: 2^256
                    self$den <- gmp::as.bigq(2)^256
                    private$filterCondition <- filterCondition
                },
                setNCParty1 = function(site) {
                    private$nc_party_1 <- site
                    private$configure_party(site)
                },
                setNCParty2 = function(site) {
                    private$nc_party_2 <- site
                    private$configure_party(site)
                },
                ## Run the count query and return the total as an integer.
                ## Each site contributes E(c - r) to party 1 and E(c + r) to
                ## party 2, so the decrypted sum of both parties is 2 * count.
                query_count = function() {
                    if (identical(private$nc_party_1, NA) ||
                        identical(private$nc_party_2, NA)) {
                        stop("Master '", private$name,
                             "': both non-cooperating parties must be set before querying.",
                             call. = FALSE)
                    }
                    pubkey <- private$keys$pubkey
                    privkey <- private$keys$getPrivateKey()
                    share1 <- private$nc_party_1$query_count()
                    share2 <- private$nc_party_2$query_count()
                    enc_sum <- pubkey$add(share1, share2)
                    doubled_count <- as.integer(privkey$decrypt(enc_sum))
                    ## The masks cancel and each count appears twice; integer
                    ## division keeps the result an integer for "%d" formatting.
                    doubled_count %/% 2L
                }
            )
        )


## -----------------------------------------------------------------------------
## One Site object per simulated data set, created in order 1, 2, 3 and
## collected in a named list; the individual objects are also kept as
## site1/site2/site3.
sites <- lapply(
    c(site1 = 1L, site2 = 2L, site3 = 3L),
    function(k) Site$new(name = paste("Site", k), data = query_data[[k]])
)
site1 <- sites$site1
site2 <- sites$site2
site3 <- sites$site3

## -----------------------------------------------------------------------------
## The two non-cooperating parties; the number selects which share each site
## hands over (copy-paste fix: the second party was also named "NCP1").
ncp1  <- NCParty$new("NCP1", 1)
ncp2  <- NCParty$new("NCP2", 2)


## -----------------------------------------------------------------------------
## Every site is registered with both non-cooperating parties
for (s in sites) {
    lapply(list(ncp1, ncp2), function(party) party$addSite(s))
}

## -----------------------------------------------------------------------------
## The master owns the key pair and the query (as a string of R code)
master <- Master$new(
    name = "Master",
    filterCondition = "age < 50 & sex == 'F' & bm < 0.2"
)

## -----------------------------------------------------------------------------
## Wiring the parties pushes the key, denominator and filter to all sites
master$setNCParty1(ncp1)
master$setNCParty2(ncp2)


## -----------------------------------------------------------------------------
writeLines(sprintf("Query Count is %d", master$query_count()))