--- title: "Distributed Stratified Cox Regression using Homomorphic Computation" author: "Balasubramanian Narasimhan" date: '`r Sys.Date()`' bibliography: homomorphing.bib output: html_document: theme: cerulean toc: yes toc_depth: 2 vignette: > %\VignetteIndexEntry{Distributed Stratified Cox Regression using Homomorphic Computation} %\VignetteEngine{knitr::rmarkdown} \usepackage[utf8]{inputenc} --- ```{r echo=FALSE} ### get knitr just the way we like it knitr::opts_chunk$set( message = FALSE, warning = FALSE, error = FALSE, tidy = FALSE, cache = FALSE ) ``` ## Introduction It is only a short way from the toy MLE example to a more useful example using Cox regression. But first, we need the `survival` package and the `homomopheR` package. ```{r} if (!require("survival")) { stop("this vignette requires the survival package") } library(homomorpheR) ``` We generate some simulated data for the purpose of this example. We will have three sites each with patient data (sizes 1000, 500 and 1500) respectively, containing - `sex` (0, 1) for male/female - `age` between 40 and 70 - a biomarker `bm` - a `time` to some event of interest - an indicator `event` which is 1 if an event was observed and 0 otherwise. It is common to fit stratified models using sites as strata since the patient characteristics usually differ from site to site. So the baseline hazards (`lambdaT`) are different for each site but they share common coefficients (`beta.1`, `beta.2` and `beta.3` for `age`, `sex` and `bm` respy.) for the model. See [@survival-book] by Therneau and Grambsch for details. So our model for each site $i$ is $$ S(t, age, sex, bm) = [S_0^i(t)]^{\exp(\beta_1 age + \beta_2 sex + \beta_3 bm)} $$ ```{r} sampleSize <- c(n1 = 1000, n2 = 500, n3 = 1500) set.seed(12345) beta.1 <- -.015; beta.2 <- .2; beta.3 <- .001; lambdaT <- c(5, 4, 3) lambdaC <- 2 coxData <- lapply(seq_along(sampleSize), function(i) { sex <- sample(c(0, 1), size = sampleSize[i], replace = TRUE) age <- sample(40:70, size = sampleSize[i], replace = TRUE) bm <- rnorm(sampleSize[i]) trueTime <- rweibull(sampleSize[i], shape = 1, scale = lambdaT[i] * exp(beta.1 * age + beta.2 * sex + beta.3 * bm )) censoringTime <- rweibull(sampleSize[i], shape = 1, scale = lambdaC) time <- pmin(trueTime, censoringTime) event <- (time == trueTime) data.frame(stratum = i, sex = sex, age = age, bm = bm, time = time, event = event) }) ``` So here is a summary of the data for the three sites. ### Site 1 ```{r} str(coxData[[1]]) ``` ### Site 2 ```{r} str(coxData[[2]]) ``` ### Site 3 ```{r} str(coxData[[3]]) ``` # # Aggregated fit If the data were all aggregated in one place, it would very simple to fit the model. Below, we row-bind the data from the three sites. ```{r} aggModel <- coxph(formula = Surv(time, event) ~ sex + age + bm + strata(stratum), data = do.call(rbind, coxData)) aggModel ``` Here `age` and `sex` are significant, but `bm` is not. The estimates $\hat{\beta}$ are `(-0.180, .020, .007)`. We can also print out the value of the (partial) log-likelihood at the MLE. ```{r} aggModel$loglik ``` The first is the value at the parameter value `(0, 0, 0)` and the last is the value at the MLE. ## Distributed Computation Assume now that the data `coxData` is distributed between three sites none of whom want to share actual data among each other or even with a master computation process. They wish to keep their data secret but are willing, together, to provide the sum of their local negative log-likelihoods. 
## Distributed Computation

Assume now that the data `coxData` are distributed among three sites, none of whom want to share actual data with each other or even with a master computation process. They wish to keep their data secret but are willing, together, to provide the sum of their local negative log-likelihoods. They need to do this in such a way that the master process cannot associate any likelihood contribution with the site that produced it.

The overall log-likelihood $l(\lambda)$ for the entire data is therefore the sum of the log-likelihoods at each site:

$$
l(\lambda) = l_1(\lambda) + l_2(\lambda) + l_3(\lambda).
$$

How can this likelihood be computed while maintaining privacy? Assuming that every site, including the master, has access to a homomorphic computation library such as `homomorpheR`, the likelihood can be computed in a privacy-preserving manner using the following scheme. We use $E(x)$ and $D(x)$ to denote the encrypted and decrypted values of $x$ respectively.

0. Master generates a public/private key pair. Master distributes the public key to all sites. (The private key is not distributed and is kept only by the master.)
1. Master generates a random offset $r$ to obfuscate the initial likelihood.
2. Master sends $E(r)$ and a guess $\lambda_0$ to site 1. Note that $\lambda_0$ is not encrypted.
3. Site 1 computes $l_1 = l(\lambda_0, y_1)$, the local likelihood for local data $y_1$ using parameter $\lambda_0$. It then sends on $\lambda_0$ and $E(r) + E(l_1)$ to site 2.
4. Site 2 computes $l_2 = l(\lambda_0, y_2)$, the local likelihood for local data $y_2$ using parameter $\lambda_0$. It then sends on $\lambda_0$ and $E(r) + E(l_1) + E(l_2)$ to site 3.
5. Site 3 computes $l_3 = l(\lambda_0, y_3)$, the local likelihood for local data $y_3$ using parameter $\lambda_0$. It then sends on $E(r) + E(l_1) + E(l_2) + E(l_3)$ back to master.
6. Master retrieves $E(r) + E(l_1) + E(l_2) + E(l_3)$ which, due to the homomorphism, is exactly $E(r + l_1 + l_2 + l_3) = E(r + l)$. So the master computes $D(E(r + l)) - r$ to obtain the value of the overall likelihood at $\lambda_0$.
7. Master updates $\lambda_0$ with a new guess $\lambda_1$ and repeats steps 1-6.

This process is iterated to convergence. For added security, even steps 0-6 can be repeated, at additional computational cost.

This is pictorially shown below.

![](assets/round_robin.png)

## Implementation

The scheme above assumes that encryption and decryption can happen with real numbers, which is not the actual situation. Instead, we use rational approximations with a large denominator, say $2^{256}$. In the future, of course, an actual library should be built with rigorous algorithms guaranteeing precision and overflow/underflow detection; for now, this is just an ad hoc implementation. Also, since we are only using the additive homomorphic property, a partially homomorphic scheme such as the Paillier encryption system is sufficient for our computations.
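To make the rational-approximation idea concrete, here is a minimal sketch of encrypting two real numbers piecewise, adding them under encryption, and recovering the sum. The helper `encodeReal`, the throwaway key pair `keys0`, the denominator `den0` and the example values are all made up for this illustration; the `addNLLAndForward` method below performs the same arithmetic inline.

```{r}
## Illustrative sketch only: a real x is split into an integer part and the
## numerator of a rational approximation (denominator den) to its fractional
## part; both pieces are encrypted separately.
keys0 <- PaillierKeyPair$new(1024)  ## throwaway key pair for this sketch
den0 <- gmp::as.bigq(2)^256         ## denominator for rational approximations

encodeReal <- function(x, den, pubkey) {
    int <- floor(x)
    frac <- x - int
    fracnum <- gmp::as.bigq(gmp::numerator(gmp::as.bigq(frac) * den))
    list(int = pubkey$encrypt(int), frac = pubkey$encrypt(fracnum))
}

e1 <- encodeReal(3.25, den0, keys0$pubkey)
e2 <- encodeReal(1.50, den0, keys0$pubkey)

## Add the two encrypted reals piecewise, then decrypt and reassemble.
esum <- list(int = keys0$pubkey$add(e1$int, e2$int),
             frac = keys0$pubkey$add(e1$frac, e2$frac))
privkey0 <- keys0$getPrivateKey()
as.double(privkey0$decrypt(esum$int)) +
    as.double(gmp::as.bigq(privkey0$decrypt(esum$frac)) / den0)  ## 4.75
```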
We define a class to encapsulate our sites that will compute the partial log-likelihood on site data given a parameter $\beta$. Note how the `addNLLAndForward` method takes care to split the result into an integer and a fractional part while performing the arithmetic operations. (The latter is approximated by a rational number.) In the code below, we exploit, for expository purposes, a feature of `coxph`: a control parameter can be passed to evaluate the partial likelihood at a given $\beta$ value.

```{r}
Site <- R6::R6Class(
    "Site",
    private = list(
        ## name of the site
        name = NA,
        ## only master has this, NA for workers
        privkey = NA,
        ## local data
        data = NA,
        ## The next site in the communication: NA for master
        nextSite = NA,
        ## is this the master site?
        iAmMaster = FALSE,
        ## intermediate result variable
        intermediateResult = NA,
        ## Control variable for cox regression
        cph.control = NA
    ),
    public = list(
        count = NA,
        ## Common denominator for approximate real arithmetic
        den = NA,
        ## The public key; everyone has this
        pubkey = NA,
        initialize = function(name, data, den) {
            private$name <- name
            private$data <- data
            self$den <- den
            private$cph.control <- replace(coxph.control(), "iter.max", 0)
        },
        setPublicKey = function(pubkey) {
            self$pubkey <- pubkey
        },
        setPrivateKey = function(privkey) {
            private$privkey <- privkey
        },
        ## Make me master
        makeMeMaster = function() {
            private$iAmMaster <- TRUE
        },
        ## add neg log lik and forward to next site
        addNLLAndForward = function(beta, enc.offset) {
            if (private$iAmMaster) {
                ## We are master, so don't forward;
                ## just store the intermediate result and return
                private$intermediateResult <- enc.offset
            } else {
                ## We are a worker, so add our negative log-likelihood
                ## and forward the result to the next site.
                ## Note that the offset is encrypted.
                nllValue <- self$nLL(beta)
                result.int <- floor(nllValue)
                result.frac <- nllValue - result.int
                result.fracnum <- gmp::as.bigq(gmp::numerator(gmp::as.bigq(result.frac) * self$den))
                pubkey <- self$pubkey
                enc.result.int <- pubkey$encrypt(result.int)
                enc.result.fracnum <- pubkey$encrypt(result.fracnum)
                result <- list(int = pubkey$add(enc.result.int, enc.offset$int),
                               frac = pubkey$add(enc.result.fracnum, enc.offset$frac))
                private$nextSite$addNLLAndForward(beta, enc.offset = result)
            }
            ## Return a TRUE result for now.
            TRUE
        },
        ## Set the next site in the communication graph
        setNextSite = function(nextSite) {
            private$nextSite <- nextSite
        },
        ## The negative log likelihood
        nLL = function(beta) {
            if (private$iAmMaster) {
                ## We're master, so need to get the result from the sites.
                ## 1. Generate a random offset and encrypt it
                pubkey <- self$pubkey
                offset <- list(int = random.bigz(nBits = 256),
                               frac = random.bigz(nBits = 256))
                enc.offset <- list(int = pubkey$encrypt(offset$int),
                                   frac = pubkey$encrypt(offset$frac))
                ## 2. Send off to next site
                throwaway <- private$nextSite$addNLLAndForward(beta, enc.offset)
                ## 3. When the call returns, the result will be in
                ## the field intermediateResult, so decrypt that.
                sum <- private$intermediateResult
                privkey <- private$privkey
                intResult <- as.double(privkey$decrypt(sum$int) - offset$int)
                fracResult <- as.double(gmp::as.bigq(privkey$decrypt(sum$frac) - offset$frac) / self$den)
                intResult + fracResult
            } else {
                ## We're a worker, so compute the local negative log-likelihood
                tryCatch({
                    m <- coxph(formula = Surv(time, event) ~ sex + age + bm,
                               data = private$data,
                               init = beta,
                               control = private$cph.control)
                    -(m$loglik[1])
                }, error = function(e) NA)
            }
        }
    )
)
```

We are now ready to use our sites in the computation.

### 1. Generate public and private key pair

We also choose a denominator for all our rational approximations.

```{r}
keys <- PaillierKeyPair$new(1024)  ## Generate new public and private key pair
den <- gmp::as.bigq(2)^256         ## Our denominator for rational approximations
```

### 2. Create sites

```{r}
site1 <- Site$new(name = "Site 1", data = coxData[[1]], den = den)
site2 <- Site$new(name = "Site 2", data = coxData[[2]], den = den)
site3 <- Site$new(name = "Site 3", data = coxData[[3]], den = den)
```

The master process is also a site, but one with no data, so it has to be designated as such.

```{r}
## Master has no data!
master <- Site$new(name = "Master", data = c(), den = den)
master$makeMeMaster()
```

### 3. Distribute the public key to the sites

```{r}
site1$setPublicKey(keys$pubkey)
site2$setPublicKey(keys$pubkey)
site3$setPublicKey(keys$pubkey)
master$setPublicKey(keys$pubkey)
```

Only the master has the private key for decryption.

```{r}
master$setPrivateKey(keys$getPrivateKey())
```

### 4. Define the communication graph

The master always sends to the first site, and the other sites forward their results in turn, with the last site returning to the master.

```{r}
master$setNextSite(site1)
site1$setNextSite(site2)
site2$setNextSite(site3)
site3$setNextSite(master)
```
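Before running the full optimization, we can sanity-check a single round trip of the protocol. (This check is ours, added for illustration, and is not part of the original flow.) The distributed negative log-likelihood at $\beta = (0, 0, 0)$ should match the negated log-likelihood of the aggregated model at its initial value, up to the rational-approximation error.

```{r}
## One round trip through the ring: master encrypts a random offset, each
## worker adds its encrypted NLL, and master decrypts and removes the offset.
c(distributed = master$nLL(c(0, 0, 0)), aggregated = -aggModel$loglik[1])
```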
### 5. Perform the likelihood estimation

Note that the arguments to `nll` are listed in the same order as the coefficients of `coxph`, i.e., the order in which they appear in the model formula.

```{r}
library(stats4)
nll <- function(sex, age, bm) master$nLL(c(sex, age, bm))
fit <- mle(nll, start = list(sex = 0, age = 0, bm = 0))
```

### 6. Compare the results

The summary will show the results.

```{r}
summary(fit)
```

Note how the estimated coefficients and standard errors closely match those in the full model summary below.

```{r}
summary(aggModel)
```

And the log-likelihood of the distributed homomorphic fit also matches that of the model on the aggregated data:

```{r}
cat(sprintf("logLik(MLE fit): %f, logLik(Agg. fit): %f.\n",
            logLik(fit), aggModel$loglik[2]))
```

## References