It is only a short way from the toy MLE example to a more useful example using Cox regression. But first, we need the survival package and the homomorpheR package.
if (!require("survival")) {
    stop("this vignette requires the survival package")
}
library(homomorpheR)
We generate some simulated data for the purpose of this example. We will have three sites, each with patient data (of sizes 1000, 500 and 1500 respectively), containing

- sex (0, 1) for male/female
- age between 40 and 70
- bm, a biomarker
- time to some event of interest
- event, which is 1 if an event was observed and 0 otherwise.

It is common to fit stratified models using sites as strata, since patient characteristics usually differ from site to site. So the baseline hazards (lambdaT) are different for each site, but the sites share common coefficients (beta.1, beta.2 and beta.3 for age, sex and bm respectively) in the model. See Therneau and Grambsch (2000) for details. So our model for each site i is
S(t, age, sex, bm) = [S0i(t)]^exp(β1·age + β2·sex + β3·bm)
sampleSize <- c(n1 = 1000, n2 = 500, n3 = 1500)
set.seed(12345)
beta.1 <- -.015; beta.2 <- .2; beta.3 <- .001;
lambdaT <- c(5, 4, 3)
lambdaC <- 2
coxData <- lapply(seq_along(sampleSize),
                  function(i) {
                      sex <- sample(c(0, 1), size = sampleSize[i], replace = TRUE)
                      age <- sample(40:70, size = sampleSize[i], replace = TRUE)
                      bm <- rnorm(sampleSize[i])
                      trueTime <- rweibull(sampleSize[i],
                                           shape = 1,
                                           scale = lambdaT[i] * exp(beta.1 * age + beta.2 * sex + beta.3 * bm))
                      censoringTime <- rweibull(sampleSize[i],
                                                shape = 1,
                                                scale = lambdaC)
                      time <- pmin(trueTime, censoringTime)
                      event <- (time == trueTime)
                      data.frame(stratum = i,
                                 sex = sex,
                                 age = age,
                                 bm = bm,
                                 time = time,
                                 event = event)
                  })
So here is a look at the structure of the data for the three sites.
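The output below can be reproduced, for example, by calling str() on each site's data frame:

invisible(lapply(coxData, str))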
## 'data.frame': 1000 obs. of 6 variables:
## $ stratum: int 1 1 1 1 1 1 1 1 1 1 ...
## $ sex : num 1 0 1 1 1 1 1 0 0 1 ...
## $ age : int 47 69 70 47 41 51 59 45 43 69 ...
## $ bm : num -0.516 -1.375 1.01 0.454 0.275 ...
## $ time : num 1.37 0.95 2.35 2.48 1.93 ...
## $ event : logi FALSE TRUE TRUE TRUE FALSE FALSE ...
## 'data.frame': 500 obs. of 6 variables:
## $ stratum: int 2 2 2 2 2 2 2 2 2 2 ...
## $ sex : num 0 1 0 1 1 1 0 1 1 1 ...
## $ age : int 54 63 53 70 40 57 48 54 63 47 ...
## $ bm : num -0.3243 0.2531 0.0464 0.8149 -0.1921 ...
## $ time : num 1.10483 0.34804 0.01602 0.68249 0.00157 ...
## $ event : logi FALSE FALSE TRUE TRUE FALSE TRUE ...
## 'data.frame': 1500 obs. of 6 variables:
## $ stratum: int 3 3 3 3 3 3 3 3 3 3 ...
## $ sex : num 1 0 0 1 1 1 0 1 0 1 ...
## $ age : int 55 70 49 60 44 42 58 62 61 68 ...
## $ bm : num -0.9554 0.8138 0.0425 -1.2272 0.3244 ...
## $ time : num 0.0733 1.9869 2.2946 0.1231 1.0602 ...
## $ event : logi TRUE FALSE FALSE TRUE FALSE FALSE ...
If the data were all aggregated in one place, it would be very simple to fit the model. Below, we row-bind the data from the three sites.
aggModel <- coxph(formula = Surv(time, event) ~ sex + age + bm + strata(stratum),
                  data = do.call(rbind, coxData))
aggModel
## Call:
## coxph(formula = Surv(time, event) ~ sex + age + bm + strata(stratum),
## data = do.call(rbind, coxData))
##
## coef exp(coef) se(coef) z p
## sex -0.160493 0.851723 0.050627 -3.170 0.00152
## age 0.010057 1.010108 0.002835 3.547 0.00039
## bm -0.005989 0.994029 0.025208 -0.238 0.81222
##
## Likelihood ratio test=22.82 on 3 df, p=4.413e-05
## n= 3000, number of events= 1575
Here age and sex are significant, but bm is not. The estimated coefficients β̂ are -0.160 for sex, 0.010 for age and -0.006 for bm.
We can also print out the value of the (partial) log-likelihood, both at the starting values and at the MLE.
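These values are stored in the loglik component of the fitted coxph object, so they can be printed with, for example:

aggModel$loglik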
## [1] -9534.495 -9523.087
The first is the value at the parameter value (0, 0, 0)
and the last is the value at the MLE.
Assume now that the data coxData is distributed across three sites, none of which wants to share actual data with the others or even with a master computation process. They wish to keep their data secret but are willing, together, to provide the master with the sum of their local negative log-likelihoods. They need to do this in such a way that the master process cannot attribute any individual contribution to the likelihood to a particular site.
The overall (partial) log-likelihood l(β) for the entire data is therefore the sum of the log-likelihoods at each site: l(β) = l1(β) + l2(β) + l3(β). How can this likelihood be computed while maintaining privacy?
Assuming that every site, including the master, has access to a homomorphic computation library such as homomorpheR, the likelihood can be computed in a privacy-preserving manner using the following scheme, where E(x) and D(x) denote the encrypted and decrypted values of x respectively.

1. The master generates a random offset r, encrypts it, and sends E(r) to site 1.
2. Site 1 computes its local negative log-likelihood l1(β), encrypts it, homomorphically adds it to E(r) to obtain E(r + l1(β)), and forwards the result to site 2.
3. Site 2 adds its encrypted contribution to obtain E(r + l1(β) + l2(β)) and forwards the result to site 3.
4. Site 3 does the same and returns E(r + l1(β) + l2(β) + l3(β)) to the master.
5. The master decrypts the sum and subtracts r, recovering l1(β) + l2(β) + l3(β) without being able to attribute any individual term to a site.
The scheme above assumes that encryption and decryption can happen with real numbers, which is not the actual situation. Instead, we use rational approximations with a large denominator, say 2^256. In the future, of course, an actual library needs to be built with rigorous algorithms guaranteeing precision and overflow/underflow detection. For now, this is just an ad hoc implementation.
Also, since we only use the additive homomorphic property, a partially homomorphic scheme such as the Paillier encryption system is sufficient for our computations.
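As a quick illustration of the additive property we rely on, here is a minimal sketch using homomorpheR's Paillier implementation (the 1024-bit modulus is an arbitrary illustrative choice):

keyPair <- PaillierKeyPair$new(modulusBits = 1024)   ## public/private key pair
pubkey <- keyPair$pubkey
privkey <- keyPair$getPrivateKey()
## Adding two ciphertexts and decrypting recovers the sum of the plaintexts
privkey$decrypt(pubkey$add(pubkey$encrypt(20), pubkey$encrypt(22)))  ## 42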
We define a class to encapsulate our sites that will compute the partial log-likelihood on site data given a parameter β. Note how the addNLLAndForward method takes care to split the result into an integer and a fractional part while performing the arithmetic operations. (The latter is approximated by a rational number.)
In the code below, we exploit, for expository purposes, a feature of coxph: a control parameter can be passed to evaluate the partial likelihood at a given β value.
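For example, applied directly to the first site's data, the trick looks like this (a standalone sketch; coxph may warn about non-convergence, which is expected with zero iterations):

beta0 <- c(0, 0, 0)   ## any candidate value of beta
m1 <- coxph(Surv(time, event) ~ sex + age + bm,
            data = coxData[[1]],
            init = beta0,
            control = replace(coxph.control(), "iter.max", 0))
m1$loglik[1]          ## partial log likelihood at beta0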
Site <- R6::R6Class("Site",
    private = list(
        ## name of the site
        name = NA,
        ## only master has this, NA for workers
        privkey = NA,
        ## local data
        data = NA,
        ## the next site in the communication; NA for master
        nextSite = NA,
        ## is this the master site?
        iAmMaster = FALSE,
        ## intermediate result variable
        intermediateResult = NA,
        ## control variable for cox regression
        cph.control = NA
    ),
    public = list(
        count = NA,
        ## common denominator for approximate real arithmetic
        den = NA,
        ## the public key; everyone has this
        pubkey = NA,
        initialize = function(name, data, den) {
            private$name <- name
            private$data <- data
            self$den <- den
            private$cph.control <- replace(coxph.control(), "iter.max", 0)
        },
        setPublicKey = function(pubkey) {
            self$pubkey <- pubkey
        },
        setPrivateKey = function(privkey) {
            private$privkey <- privkey
        },
        ## make me master
        makeMeMaster = function() {
            private$iAmMaster <- TRUE
        },
        ## add negative log likelihood and forward to next site
        addNLLAndForward = function(beta, enc.offset) {
            if (private$iAmMaster) {
                ## We are the master, so don't forward;
                ## just store the intermediate result and return
                private$intermediateResult <- enc.offset
            } else {
                ## We are a worker, so add our negative log likelihood to the
                ## (encrypted) offset and forward the result to the next site
                nllValue <- self$nLL(beta)
                result.int <- floor(nllValue)
                result.frac <- nllValue - result.int
                ## scale the fractional part by the common denominator so that
                ## it becomes an integer numerator suitable for encryption
                result.fracnum <- gmp::as.bigq(gmp::numerator(gmp::as.bigq(result.frac) * self$den))
                pubkey <- self$pubkey
                enc.result.int <- pubkey$encrypt(result.int)
                enc.result.fracnum <- pubkey$encrypt(result.fracnum)
                result <- list(int = pubkey$add(enc.result.int, enc.offset$int),
                               frac = pubkey$add(enc.result.fracnum, enc.offset$frac))
                private$nextSite$addNLLAndForward(beta, enc.offset = result)
            }
            ## Return a TRUE result for now.
            TRUE
        },
        ## set the next site in the communication graph
        setNextSite = function(nextSite) {
            private$nextSite <- nextSite
        },
        ## the negative log likelihood
        nLL = function(beta) {
            if (private$iAmMaster) {
                ## We're the master, so we need to get the result from the sites:
                ## 1. generate a random offset and encrypt it
                pubkey <- self$pubkey
                offset <- list(int = random.bigz(nBits = 256),
                               frac = random.bigz(nBits = 256))
                enc.offset <- list(int = pubkey$encrypt(offset$int),
                                   frac = pubkey$encrypt(offset$frac))
                ## 2. send off to the next site
                throwaway <- private$nextSite$addNLLAndForward(beta, enc.offset)
                ## 3. when the call returns, the result will be in the field
                ##    intermediateResult, so decrypt that and remove the offset
                sum <- private$intermediateResult
                privkey <- private$privkey
                intResult <- as.double(privkey$decrypt(sum$int) - offset$int)
                fracResult <- as.double(gmp::as.bigq(privkey$decrypt(sum$frac) - offset$frac) / self$den)
                intResult + fracResult
            } else {
                ## We're a worker, so compute the local negative log likelihood
                tryCatch({
                    m <- coxph(formula = Surv(time, event) ~ sex + age + bm,
                               data = private$data,
                               init = beta,
                               control = private$cph.control)
                    -(m$loglik[1])
                },
                error = function(e) NA)
            }
        })
)
We are now ready to use our sites in the computation. We also need a Paillier key pair and a common denominator for all our rational approximations.
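The objects keys and den used below are assumed to be set up along the following lines (the 1024-bit modulus and the denominator 2^256 are illustrative choices):

keys <- PaillierKeyPair$new(modulusBits = 1024)  ## Paillier public/private key pair
den <- gmp::as.bigq(2)^256                       ## common denominator for rational approximations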
site1 <- Site$new(name = "Site 1", data = coxData[[1]], den = den)
site2 <- Site$new(name = "Site 2", data = coxData[[2]], den = den)
site3 <- Site$new(name = "Site 3", data = coxData[[3]], den = den)
The master process is also a site, but one with no data, and it has to be designated as such.
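A sketch of that designation (the name and the empty data argument are placeholders, since the master holds no data):

master <- Site$new(name = "Master", data = c(), den = den)
master$makeMeMaster()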
site1$setPublicKey(keys$pubkey)
site2$setPublicKey(keys$pubkey)
site3$setPublicKey(keys$pubkey)
master$setPublicKey(keys$pubkey)
Only the master has the private key for decryption.
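For example, assuming the key pair exposes its private key via getPrivateKey():

master$setPrivateKey(keys$getPrivateKey())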
The master will always send to the first site, and then the others forward results in turn, with the last site returning the result to the master.
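The communication ring could be set up as follows, using the setNextSite method defined above:

master$setNextSite(site1)
site1$setNextSite(site2)
site2$setNextSite(site3)
site3$setNextSite(master)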
The summary will show the results.
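The fit itself is not shown above; a sketch of how it could be produced with stats4::mle, where nll simply forwards a candidate parameter vector to the master's nLL method (the argument order follows the call displayed in the output below):

library(stats4)
nll <- function(age, sex, bm) master$nLL(c(age, sex, bm))
fit <- mle(minuslogl = nll, start = list(age = 0, sex = 0, bm = 0))
summary(fit)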
## Maximum likelihood estimation
##
## Call:
## mle(minuslogl = nll, start = list(age = 0, sex = 0, bm = 0))
##
## Coefficients:
## Estimate Std. Error
## age -0.160493329 0.050626611
## sex 0.010057265 0.002835374
## bm -0.005988214 0.025208370
##
## -2 log L: 19046.17
Note how the estimated coefficients and standard errors closely match the full model summary below.
## Call:
## coxph(formula = Surv(time, event) ~ sex + age + bm + strata(stratum),
## data = do.call(rbind, coxData))
##
## n= 3000, number of events= 1575
##
## coef exp(coef) se(coef) z Pr(>|z|)
## sex -0.160493 0.851723 0.050627 -3.170 0.00152 **
## age 0.010057 1.010108 0.002835 3.547 0.00039 ***
## bm -0.005989 0.994029 0.025208 -0.238 0.81222
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## exp(coef) exp(-coef) lower .95 upper .95
## sex 0.8517 1.174 0.7713 0.9406
## age 1.0101 0.990 1.0045 1.0157
## bm 0.9940 1.006 0.9461 1.0444
##
## Concordance= 0.536 (se = 0.009 )
## Likelihood ratio test= 22.82 on 3 df, p=4e-05
## Wald test = 22.81 on 3 df, p=4e-05
## Score (logrank) test = 22.85 on 3 df, p=4e-05
And the log likelihood of the distributed homomorphic fit also matches that of the model on aggregated data:
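One way this comparison could be printed, assuming the mle fit object above is named fit:

cat(sprintf("logLik(MLE fit): %f, logLik(Agg. fit): %f.\n",
            as.numeric(logLik(fit)), aggModel$loglik[2]))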
## logLik(MLE fit): -9523.087001, logLik(Agg. fit): -9523.087001.