## ----include = FALSE----------------------------------------------------------
#Global knitr chunk options: figure size 7 by (4/6)*7, "#>" as the prefix for 
#printed output, and collapsed source/output blocks.
knitr::opts_chunk$set(
  fig.width = 7,
  fig.height = (4/6)*7,
  comment = "#>",
  collapse = TRUE
)

## ----setup--------------------------------------------------------------------
library(QAEnsemble)

## ----eval=TRUE----------------------------------------------------------------
library(svMisc)
library(coda)

## ----eval=TRUE----------------------------------------------------------------
#Fix the random number generator seed so that the results are reproducible
set.seed(10)

## ----eval=TRUE----------------------------------------------------------------
#True Weibull shape and scale parameters
a_shape <- 8
sigma_scale <- 1100

#End points (in months) of 110 consecutive 12-month intervals
time_use <- 12*seq_len(110)

#True Weibull survival curve, S(t) = exp(-(t/sigma_scale)^a_shape)
plot(
  x = time_use,
  y = exp(-1*((time_use/sigma_scale)^a_shape)),
  type = "l",
  col = "red",
  xlab = "Months (t)",
  ylab = "Survival probability (S)"
)

## ----eval=TRUE----------------------------------------------------------------

#Input for the 'create_life_table_info' function:

#interval_time_input: the end time (in months) of each time interval, assumed 
#to be increasing. Interval t covers (interval_time_input[t-1], 
#interval_time_input[t]]; the first interval starts at month 0 and covers 
#[0, interval_time_input[1]].

#num_ran_samples_input: the number of random samples from the Weibull 
#distribution

#data_weibull_input: the Weibull distribution mortality time samples (in 
#months). Missing values are not counted as deaths.

#Output of 'create_life_table_info' is a list containing two data frames. 
#The first data frame ('life_table_df') contains the life table information: 
#time intervals, the number of people at risk in the interval, number of 
#deaths that occur in the interval, risk of death in the interval, probability 
#of surviving in the interval, and the survival probabilities. The second data 
#frame ('interval_censored_df') contains the interval-censored information in 
#a format easy to use in the likelihood function: interval-censored mortality 
#time observation number, interval left bound, and interval right bound. The 
#observations are listed in interval order (not in sample order). Samples 
#that do not fall in any interval are not assigned bounds, so the last rows 
#keep NA bounds in that case.

create_life_table_info <- function(interval_time_input, num_ran_samples_input, data_weibull_input)
{
  num_intervals <- length(interval_time_input)
  
  #Interval t is (lower_bound[t], upper_bound[t]]; the first lower bound is 0
  upper_bound <- interval_time_input
  lower_bound <- c(0, interval_time_input)[seq_len(num_intervals)]
  
  num_at_risk <- numeric(num_intervals)
  num_deaths <- numeric(num_intervals)
  risk_death <- numeric(num_intervals)
  prob_surv <- numeric(num_intervals)
  s_values <- numeric(num_intervals)
  
  for(t in seq_len(num_intervals))
  {
    if(t == 1)
    {
      #Everybody is at risk at month 0, and a death at exactly month 0 counts
      remaining <- num_ran_samples_input
      after_start <- data_weibull_input >= lower_bound[t]
      s_previous <- 1
    }
    else
    {
      remaining <- num_at_risk[t-1] - num_deaths[t-1]
      after_start <- data_weibull_input > lower_bound[t]
      s_previous <- s_values[t-1]
    }
    
    if(t == 1 || remaining > 0)
    {
      num_at_risk[t] <- remaining
      num_deaths[t] <- sum(after_start & data_weibull_input <= upper_bound[t], na.rm = TRUE)
      risk_death[t] <- num_deaths[t]/num_at_risk[t]
    }
    else
    {
      #Nobody is left at risk: no further deaths and the survival stays at 0
      num_at_risk[t] <- 0
      num_deaths[t] <- 0
      risk_death[t] <- 1
    }
    
    prob_surv[t] <- 1 - risk_death[t]
    s_values[t] <- s_previous*prob_surv[t]
  }
  
  #Interval labels in months and years, e.g. "0-12 mo (0-1 yr)"
  if(num_intervals > 0)
  {
    interval_labels <- paste0(lower_bound, "-", upper_bound, " mo", " (", lower_bound/12, "-", upper_bound/12, " yr)")
  }
  else
  {
    interval_labels <- character(0)
  }
  
  #One (left bound, right bound) pair per death, listed in interval order
  event_left <- rep(lower_bound, times = num_deaths)
  event_right <- rep(upper_bound, times = num_deaths)
  
  left_endpoint <- rep(NA_real_, num_ran_samples_input)
  right_endpoint <- rep(NA_real_, num_ran_samples_input)
  left_endpoint[seq_along(event_left)] <- event_left
  right_endpoint[seq_along(event_right)] <- event_right
  
  life_table_df <- data.frame(interval_labels, num_at_risk, num_deaths, risk_death, prob_surv, s_values, stringsAsFactors = FALSE)
  
  names(life_table_df) <- c(
    "Time interval",
    "Number of people at risk in the interval",
    "Number of deaths that occur in the interval",
    "Risk of death in the interval",
    "Probability of surviving in the interval",
    "Survival probability"
  )
  
  interval_censored_df <- data.frame(seq_len(num_ran_samples_input), left_endpoint, right_endpoint)
  
  names(interval_censored_df) <- c(
    "Interval-censored mortality time observation number",
    "Interval left bound",
    "Interval right bound"
  )
  
  return(list("life_table_df" = life_table_df, "interval_censored_df" = interval_censored_df))
}


## ----eval=TRUE----------------------------------------------------------------

#Number of simulated mortality times
num_ran_samples <- 50

#Draw the mortality times from the Weibull distribution with the true 
#parameters
data_weibull <- rweibull(n = num_ran_samples, shape = a_shape, scale = sigma_scale)

#Build the life table and the interval-censored data from the samples
info_use <- create_life_table_info(
  interval_time_input = time_use,
  num_ran_samples_input = num_ran_samples,
  data_weibull_input = data_weibull
)

knitr::kable(info_use$life_table_df, format = "pipe", padding = 0, caption = "Life table data")

## ----eval=TRUE----------------------------------------------------------------
knitr::kable(info_use$interval_censored_df, format = "pipe", padding = 0, caption = "Interval-censored data")

## ----eval=TRUE----------------------------------------------------------------
#Step plot of the survival probability column of the life table
plot(
  x = time_use,
  y = info_use$life_table_df[["Survival probability"]],
  type = "s",
  xlab = "Months (t)",
  ylab = "Survival probability (S)"
)


## ----eval=TRUE----------------------------------------------------------------

#Log prior density of the parameter vector 'param' (a_shape, sigma_scale): 
#independent uniform priors on [1e-2, 1e2] for a_shape and on [1, 1e4] for 
#sigma_scale. Returns -Inf when either parameter is outside its range.
logp <- function(param)
{
  shape_log_prior <- dunif(param[1], min = 1e-2, max = 1e2, log = TRUE)
  scale_log_prior <- dunif(param[2], min = 1, max = 1e4, log = TRUE)

  shape_log_prior + scale_log_prior
}

#Log likelihood of the interval-censored mortality times for the parameter 
#vector 'param' (a_shape, sigma_scale). Reads the interval bounds from the 
#global 'info_use$interval_censored_df' and the number of observations from 
#the global 'num_ran_samples'. An observation with bounds (l, r] contributes 
#log(S(l) - S(r)), where S(t) = exp(-(t/sigma_scale)^a_shape) is the Weibull 
#survival function. Returns -Inf when the sum is NA or infinite; this includes 
#the case where an observation has NA bounds.
logl <- function(param)
{
  a_shape_use <- param[1]
  sigma_scale_use <- param[2]
  
  #Use the first 'num_ran_samples' observations
  obs_index <- seq_len(num_ran_samples)
  interval_l_bound <- info_use$interval_censored_df[[2]][obs_index]
  interval_r_bound <- info_use$interval_censored_df[[3]][obs_index]
  
  #Survival probabilities at both ends of every interval, computed for all 
  #observations at once (this function is evaluated at every sampler step)
  surv_at_l_bound <- exp(-1*((interval_l_bound/sigma_scale_use)^a_shape_use))
  surv_at_r_bound <- exp(-1*((interval_r_bound/sigma_scale_use)^a_shape_use))
  
  logl_val <- sum(log(surv_at_l_bound - surv_at_r_bound))
  
  if(is.na(logl_val) || is.infinite(logl_val))
  {
    logl_val <- -Inf
  }

  return(logl_val)
}

#Bundle the log prior and log likelihood functions for the ensemble sampler
logfuns <- list(logp = logp, logl = logl)

## ----eval=TRUE----------------------------------------------------------------
#Two parameters (a_shape and sigma_scale) and twice as many chains
num_par <- 2
num_chains <- 2*num_par

#Column j of theta0 holds the initial parameter values of chain j
theta0 <- matrix(0, nrow = num_par, ncol = num_chains)

temp_val <- 0
j <- 0

#For every chain, draw initial values uniformly from the prior ranges and 
#redraw until the unnormalized log posterior is finite
while(j < num_chains)
{
  repeat
  {
    initial <- c(runif(1, 1e-2, 1e2), runif(1, 1, 1e4))
    temp_val <- logl(initial) + logp(initial)

    if(!(is.na(temp_val) || is.infinite(temp_val)))
    {
      break
    }
  }

  j <- j + 1

  message(paste('j:', j))

  theta0[, j] <- initial
}

## ----eval=TRUE----------------------------------------------------------------
#Number of iterations per chain and the thinning value passed to the sampler
num_chain_iterations <- 1e5
thin_val_par <- 10

#Presumably the total number of samples kept over all chains after thinning
num_chains*floor(num_chain_iterations/thin_val_par)

## ----eval=TRUE----------------------------------------------------------------

#Run the ensemble sampler from the initial states in theta0
Weibull_Surv_Quad_result <- ensemblealg(
  theta0,
  logfuns,
  T_iter = num_chain_iterations,
  Thin_val = thin_val_par,
  ReturnCODA = TRUE,
  ShowProgress = TRUE
)


## ----eval=TRUE----------------------------------------------------------------
#Pull the individual components out of the sampler result
my_mcmc_list_coda <- Weibull_Surv_Quad_result$mcmc_list_coda
my_samples <- Weibull_Surv_Quad_result$theta_sample
my_log_prior <- Weibull_Surv_Quad_result$log_prior_sample
my_log_like <- Weibull_Surv_Quad_result$log_like_sample

## ----eval=TRUE----------------------------------------------------------------
#Label the two parameters in the coda object and plot the chains
varnames(my_mcmc_list_coda) <- c("a_shape", "sigma_scale")

plot(my_mcmc_list_coda)

## ----eval=TRUE----------------------------------------------------------------
#Indices of the first 25% of the stored samples, which are removed as burn-in 
#(assumes the last dimension of each object indexes the stored iterations)
burn_in_index <- c(1:floor((num_chain_iterations/thin_val_par)*0.25))

my_samples_burn_in <- my_samples[, , -burn_in_index]

my_log_prior_burn_in <- my_log_prior[, -burn_in_index]

my_log_like_burn_in <- my_log_like[, -burn_in_index]

## ----eval=TRUE----------------------------------------------------------------
#Convergence diagnostic on the samples that remain after burn-in
diagnostic_result <- psrfdiagnostic(my_samples_burn_in, 0.05)

diagnostic_result$p_s_r_f_vec

## ----eval=TRUE----------------------------------------------------------------
#Unnormalized log posterior (log prior + log likelihood) of every sample kept 
#after burn-in, flattened to a vector
log_un_post_vec <- as.vector(my_log_prior_burn_in + my_log_like_burn_in)
#Samples of the first (a_shape) and second (sigma_scale) parameter, flattened 
#the same way
k1 <- as.vector(my_samples_burn_in[1, , ])
k2 <- as.vector(my_samples_burn_in[2, , ])

## ----eval=TRUE----------------------------------------------------------------
#a_shape posterior median and 95% credible interval
median(k1)
hpdparameter(k1, 0.05)
#(The 95% CI is narrower than the prior distribution and it contains the true 
#value of a_shape, 8)

#sigma_scale posterior median and 95% credible interval
median(k2)
hpdparameter(k2, 0.05)
#(The 95% CI is narrower than the prior distribution and it contains the true 
#value of sigma_scale, 1100)

#a_shape and sigma_scale maximum posterior (the sample with the largest 
#unnormalized log posterior)
max_post_index <- which.max(log_un_post_vec)
k1[max_post_index]
k2[max_post_index]

## ----eval=TRUE----------------------------------------------------------------
plot(x = k1, y = exp(log_un_post_vec), xlab = "a_shape", ylab = "unnormalized posterior density")

## ----eval=TRUE----------------------------------------------------------------
plot(x = k2, y = exp(log_un_post_vec), xlab = "sigma_scale", ylab = "unnormalized posterior density")

## ----eval=TRUE----------------------------------------------------------------
#Number of posterior samples used for the posterior predictive check
num_samples <- 10000

#Row l holds the replicated mortality times / survival probabilities generated 
#from the l-th posterior sample used
w_predict <- matrix(NA, nrow = num_samples, ncol = num_ran_samples)
s_predict <- matrix(NA, nrow = num_samples, ncol = length(time_use))

#Discrepancies of the observed and of the replicated survival probabilities, 
#and the indicator that the replicated discrepancy is the larger one
discrepancy_s <- matrix(NA, nrow = num_samples, ncol = 1)
discrepancy_s_pred <- matrix(NA, nrow = num_samples, ncol = 1)
ind_pred_exceeds_s <- matrix(NA, nrow = num_samples, ncol = 1)

#Survival probabilities of the observed life table (the same in every pass)
s_data <- t(info_use$life_table_df[[6]])

#The last num_samples entries of the flattened posterior sample are used
sample_offset <- length(log_un_post_vec) - num_samples

for (i in (sample_offset + 1):length(log_un_post_vec))
{
  #Row index, running from 1 to num_samples
  l <- i - sample_offset
  
  #Replicated data set drawn with the i-th posterior parameter values
  data_weibull_temp <- rweibull(num_ran_samples, shape = k1[i], scale = k2[i])
  w_predict[l, ] <- data_weibull_temp
  
  info_use_predict <- create_life_table_info(time_use, num_ran_samples, data_weibull_temp)
  
  s_predict[l, ] <- info_use_predict$life_table_df[[6]]
  
  #Weibull survival curve for the i-th posterior parameter values
  my_s_mean <- exp(-1*((time_use/k2[i])^k1[i]))
  
  #Sum of squared differences from that curve
  discrepancy_s[l, 1] <- sum((s_data - my_s_mean)^2)
  discrepancy_s_pred[l, 1] <- sum((s_predict[l, ] - my_s_mean)^2)
  
  ind_pred_exceeds_s[l, 1] <- if(discrepancy_s_pred[l, 1] > discrepancy_s[l, 1]) 1 else 0

  svMisc::progress(l, num_samples, progress.bar = TRUE, init = (l == 1))
}

## ----eval=TRUE----------------------------------------------------------------
#Fraction of the posterior samples for which the replicated-data discrepancy 
#exceeds the observed-data discrepancy
Bayes_p_value <- sum(ind_pred_exceeds_s)/length(ind_pred_exceeds_s)

Bayes_p_value

## ----eval=TRUE----------------------------------------------------------------
#Lower (2.5%) and upper (97.5%) quantiles of the replicated survival 
#probabilities at every time point
s_predict_lower <- matrix(NA, nrow = 1, ncol = length(time_use))
s_predict_upper <- matrix(NA, nrow = 1, ncol = length(time_use))

for(j in 1:length(time_use))
{
  s_predict_bounds <- quantile(s_predict[, j], probs = c(0.025, 0.975))
  
  s_predict_lower[j] <- s_predict_bounds[1]
  s_predict_upper[j] <- s_predict_bounds[2]
}

## ----eval=TRUE----------------------------------------------------------------
#Life table data (black steps), true survival curve (red), and the mean and 
#95% band of the replicated survival probabilities (blue)
plot(
  x = time_use,
  y = info_use$life_table_df[[6]],
  type = "s",
  ylim = c(0, 1),
  main = "Weibull survival curve (a = 8, sigma = 1100) Example",
  cex.main = 1,
  xlab = "Months (t)",
  ylab = "Survival probability (S)"
)
lines(time_use, exp(-1*((time_use/sigma_scale)^a_shape)), type = "l", lty = 1, col = "red")
lines(time_use, colMeans(s_predict), type = "l", lty = 1, col = "blue")
lines(time_use, s_predict_lower, type = "l", lty = 2, col = "blue")
lines(time_use, s_predict_upper, type = "l", lty = 2, col = "blue")

legend(
  "bottomleft",
  legend = c("Data", "True model", "Posterior predictive mean", "95% prediction interval"),
  lty = c(1, 1, 1, 2),
  col = c("black", "red", "blue", "blue"),
  pch = rep(NA, 4)
)