## ----include = FALSE----------------------------------------------------------
# Chunk-wide knitr options: collapse source and output, prefix output with "#>"
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

# Stop early when scoringutils is older than the 2.0.0 release this
# vignette requires
scoringutils_too_old <- utils::packageVersion("scoringutils") < "2.0.0"
if (scoringutils_too_old) {
  stop("The 'scoringutils' package version 2.0.0 or higher is required. Please update it.")
}

# Stop early when ggplot2 is not installed
has_ggplot2 <- requireNamespace("ggplot2", quietly = TRUE)
if (!has_ggplot2) {
  stop("The vignette requires the 'ggplot2' package. Please install it to build the vignette.")
}

## ----setup, warning=FALSE, message=FALSE--------------------------------------
library(NobBS)
library(dplyr)
library(ggplot2)
library(scoringutils)
data(mpoxdat)

# Moving-window length passed to NobBS() and the date the nowcast is run on
win_size <- 28
now <- as.Date("2022-08-15")
test_dates <- seq(now, now, by = 1)

# Keep only records diagnosed on or before the "now" date
current_data <- mpoxdat[mpoxdat$dx_date <= now, ]

# Baseline nowcast without a day-of-week (DoW) covariate
nowcast_without_dow <- NobBS(
  current_data, now,
  units = "1 day",
  onset_date = "dx_date", report_date = "dx_report_date",
  moving_window = win_size,
  specs = list(nAdapt = 2000)
)

# Same nowcast with the DoW covariate enabled. There is deliberately no
# trailing comma after the last argument: a trailing comma would pass an
# extra empty argument to NobBS().
nowcast_with_dow <- NobBS(
  current_data, now,
  units = "1 day",
  onset_date = "dx_date", report_date = "dx_report_date",
  moving_window = win_size,
  specs = list(nAdapt = 2000),
  add_dow_cov = TRUE
)

# Store each result in a one-element list (same list-of-nowcasts layout that
# the later chunks index with [[t]])
nowcasts_without_dow <- list(nowcast_without_dow)
nowcasts_with_dow <- list(nowcast_with_dow)


## ----plot_estimates, fig.width=8, fig.height=4--------------------------------

# Plot two nowcasts against the observed case counts using base graphics.
#
# Arguments:
#   nowcast1, nowcast2  nowcast results; each must have an `estimates` element
#                       with columns onset_date, estimate, q_0.025 and q_0.975.
#                       nowcast1 is drawn in blue and labelled as the fit
#                       without DoW effect, nowcast2 in red as the fit with it.
#   cases_per_date      table with columns dx_date and count (drawn as points).
#   now                 Date of the nowcast; used only in the plot title.
#
# Stops with an error when either `estimates` element is NULL. The last call
# is legend(), so its value is what the function returns.
plot_estimates <- function(nowcast1, nowcast2, cases_per_date, now) {
  # Refuse to plot when either nowcast lacks its estimates
  if (is.null(nowcast1$estimates) || is.null(nowcast2$estimates)) {
    stop("Nowcast estimates are missing.")
  }

  est_plain <- nowcast1$estimates
  est_dow <- nowcast2$estimates
  observed <- cases_per_date$count

  # Semi-transparent fills for the two 95% prediction-interval bands
  band_plain <- rgb(0, 0, 1, 0.2)
  band_dow <- rgb(1, 0, 0, 0.2)

  # Draw one interval band as a closed polygon: lower bounds left to right,
  # then upper bounds right to left
  shade_interval <- function(dates, lo, hi, fill) {
    polygon(c(dates, rev(dates)), c(lo, rev(hi)), col = fill, border = NA)
  }

  # The y-axis has to fit both interval bands and the observed counts
  y_low <- min(est_plain$q_0.025, est_dow$q_0.025, observed, na.rm = TRUE)
  y_high <- max(est_plain$q_0.975, est_dow$q_0.975, observed, na.rm = TRUE)

  # Point estimates: first nowcast sets up the axes, second is overlaid
  plot(
    est_plain$onset_date, est_plain$estimate,
    type = "l", col = "blue", lwd = 2,
    ylim = c(y_low, y_high),
    xlab = "Onset Date", ylab = "Cases",
    main = paste0("Incidence Estimates for ", weekdays(now), " ", now)
  )
  lines(est_dow$onset_date, est_dow$estimate, col = "red", lwd = 2)

  # 95% prediction-interval bands for both nowcasts
  shade_interval(
    est_plain$onset_date, est_plain$q_0.025, est_plain$q_0.975, band_plain
  )
  shade_interval(
    est_dow$onset_date, est_dow$q_0.025, est_dow$q_0.975, band_dow
  )

  # Observed case counts
  points(cases_per_date$dx_date, observed, col = "black", pch = 20)

  # Legend entries: two lines, two filled squares for the bands, one point
  legend(
    "topleft",
    legend = c(
      "Estimates without DoW effect",
      "Estimates with DoW effect",
      "95% PI (No DoW Effect)",
      "95% PI (With DoW Effect)",
      "Eventual cases"
    ),
    col = c("blue", "red", band_plain, band_dow, "black"),
    lty = c(1, 1, NA, NA, NA),
    lwd = c(2, 2, NA, NA, NA),
    pch = c(NA, NA, 15, 15, 20),
    pt.cex = c(NA, NA, 1.5, 1.5, 1),
    cex = 0.9
  )
}

# Count cases per diagnosis date among the records diagnosed up to `now`
current_data <- mpoxdat[mpoxdat$dx_date <= now, ]
cases_per_date <- summarize(group_by(current_data, dx_date), count = n())

# Compare the two nowcasts (without / with DoW effect) against those counts
plot_estimates(
  nowcast_without_dow,
  nowcast_with_dow,
  cases_per_date,
  now
)


## ----gammas_analysis, fig.width=8, fig.height=4-------------------------------

# Weekday labels in display order; Sunday is left out because it is the
# reference day the other effects are compared to
weekdays_order <- c(
  "Mon", "Tue", "Wed",
  "Thu", "Fri", "Sat"
)

# Summarise the day-of-week (DoW) effects of a nowcast.
#
# Reads the posterior samples named "gamma[1]" ... "gamma[6]" from
# `nowcast$params.post` (six parameters because Sunday is the reference),
# exponentiates them, and returns a list with three unnamed numeric vectors of
# length 6:
#   means  mean of the exponentiated samples
#   lower  2.5% quantile of the exponentiated samples
#   upper  97.5% quantile of the exponentiated samples
extract_dow_effect <- function(nowcast) {
  # Mean and 95% interval bounds for the i-th gamma parameter
  summarise_gamma <- function(i) {
    draws <- exp(nowcast$params.post[[paste0("gamma[", i, "]")]])
    c(mean(draws), quantile(draws, c(0.025, 0.975), names = FALSE))
  }

  # 3 x 6 matrix: rows are mean / lower / upper, columns are gamma[1..6]
  stats <- vapply(1:6, summarise_gamma, numeric(3))

  list(means = stats[1, ], lower = stats[2, ], upper = stats[3, ])
}

# Posterior DoW effects (mean and 95% interval) from the nowcast fitted with
# the DoW covariate
dow_effects <- extract_dow_effect(nowcast_with_dow)

# One row per non-reference weekday; the factor levels fix the x-axis order
dow_df <- data.frame(
  Day = factor(weekdays_order, levels = weekdays_order),
  Mean = dow_effects$means,
  Lower = dow_effects$lower,
  Upper = dow_effects$upper
)

# paste() already separates its arguments with a space, so the literal must
# not end in one (otherwise the title contains a double space)
plot_title <- paste("Estimates of DoW effect for nowcast performed on", now)

# Point estimate with a 95% interval error bar for each weekday
p <- ggplot(dow_df, aes(x = Day, y = Mean)) +
  geom_point(size = 3, color = "blue") +
  geom_errorbar(aes(ymin = Lower, ymax = Upper), width = 0.2, color = "black") +
  labs(title = plot_title, y = "DoW Effect on Expected Cases\n(Compared to Sunday)\n", x = "") +
  theme_minimal()

print(dow_df)
print(p)

## ----employing_priors, warning=FALSE, message=FALSE, fig.width=8, fig.height=4----

# Turn the posterior DoW summaries into Normal priors (mean and precision) for
# the gamma coefficients. `dow_effects` holds summaries of exp(gamma), while
# the priors apply to gamma itself, so both the centre and the interval bounds
# are mapped back with log() before use. A 95% interval of a Normal
# distribution spans about 2 * 1.96 standard deviations, which gives the SD;
# the precision is 1 / SD^2.
prior_mean <- log(dow_effects$means)
prior_sd <- (log(dow_effects$upper) - log(dow_effects$lower)) / (2 * 1.96)
prior_prec <- 1 / (prior_sd^2)

# Nowcast dates to evaluate and the (shorter) moving window used for them
test_dates <- seq(as.Date("2022-08-16"), as.Date("2022-08-19"), by = 1)
win_size <- 14

# Preallocate one slot per nowcast date
nowcasts_without_dow <- vector("list", length(test_dates)) # without DoW effect
nowcasts_with_dow <- vector("list", length(test_dates))    # with DoW effect

# Run both nowcasts for each "now" date
for (t in seq_along(test_dates)) {
  now <- test_dates[t]

  # Keep only records diagnosed on or before the current "now" date
  current_data <- mpoxdat[mpoxdat$dx_date <= now, ]

  # Baseline nowcast. No trailing comma after the last argument: it would
  # pass an extra empty argument to NobBS().
  nowcasts_without_dow[[t]] <- NobBS(
    current_data, now,
    units = "1 day",
    onset_date = "dx_date", report_date = "dx_report_date",
    moving_window = win_size
  )

  # Nowcast with the DoW covariate, using the priors derived from the
  # earlier nowcast for the gamma coefficients
  nowcasts_with_dow[[t]] <- NobBS(
    current_data, now,
    units = "1 day",
    onset_date = "dx_date", report_date = "dx_report_date",
    moving_window = win_size,
    specs = list(gamma.mean.prior = prior_mean, gamma.prec.prior = prior_prec),
    add_dow_cov = TRUE
  )
}

# Draw one comparison plot per "now" date, showing the nowcast without the
# DoW effect next to the one with it
for (t in seq_along(test_dates)) {
  now <- test_dates[t]

  # Case counts per diagnosis date among the records diagnosed up to `now`
  current_data <- mpoxdat[mpoxdat$dx_date <= now, ]
  cases_per_date <- summarize(group_by(current_data, dx_date), count = n())

  nowcast1 <- nowcasts_without_dow[[t]]
  nowcast2 <- nowcasts_with_dow[[t]]
  plot_estimates(nowcast1, nowcast2, cases_per_date, now)
}

## ----wis_calculation, warning=FALSE, message=FALSE, fig.width=8, fig.height=4----

# Quantile levels to score and the matching column names in the nowcast
# estimates (e.g. "q_0.025")
quantiles <- c(0.025, 0.25, 0.5, 0.75, 0.975)
q_len <- length(quantiles)
q_cols <- paste0("q_", quantiles)

# Empty, typed template of the long-format scoring table: one row per
# (onset date, nowcast date, model, quantile level). Note that `onset_week`
# holds the daily onset date here.
data <- data.frame(onset_week = as.Date(character()),
                   now = as.Date(character()),
                   horizon = numeric(),
                   quantile_level = numeric(),
                   predicted = numeric(),
                   observed = numeric(),
                   model = character())

# Case counts per diagnosis date from the full data set; these are the
# observed values the nowcasts are scored against
cases_per_date <- mpoxdat %>%
  group_by(dx_date) %>%
  summarize(count = n())

# Offsets (in days) from the nowcast date at which estimates are scored
horizons <- c(-5, -4, -3, -2, -1, 0)

# Collect the per-(date, horizon) pieces in a preallocated list and bind them
# once at the end instead of growing the data frame inside the loop
score_rows <- vector("list", length(test_dates) * length(horizons))
row_id <- 0L

for (t in seq_along(test_dates)) {
  now <- test_dates[t]
  nowcast1 <- nowcasts_without_dow[[t]]
  nowcast2 <- nowcasts_with_dow[[t]]
  for (h in horizons) {
    date <- now + h

    # A date with no records has no row in cases_per_date: that is zero cases
    true_value <- cases_per_date$count[cases_per_date$dx_date == date]
    if (length(true_value) == 0) {
      true_value <- 0L
    }

    q_est1 <- unname(unlist(nowcast1$estimates[nowcast1$estimates$onset_date == date, q_cols]))
    q_est2 <- unname(unlist(nowcast2$estimates[nowcast2$estimates$onset_date == date, q_cols]))

    # Each nowcast must contribute exactly one estimate row for this date
    if (length(q_est1) != q_len || length(q_est2) != q_len) {
      stop("Expected exactly one nowcast estimate row for onset date ", date,
           " in the nowcast performed on ", now, ".")
    }

    data_est <- data.frame(onset_week = rep(date, q_len * 2),
                           now = rep(now, q_len * 2),
                           horizon = rep(h, q_len * 2),
                           quantile_level = rep(quantiles, 2),
                           predicted = c(q_est1, q_est2),
                           observed = rep(true_value, q_len * 2),
                           model = c(rep("Without DoW", q_len), rep("With DoW", q_len)))

    row_id <- row_id + 1L
    score_rows[[row_id]] <- data_est
  }
}

# Bind everything onto the typed template so that the result has the same
# columns even when there is nothing to score
data <- do.call(rbind, c(list(data), score_rows))

# Convert the long-format table into a scoringutils quantile-forecast object
nowcasts <- scoringutils::as_forecast_quantile(data)

# Score the nowcasts with the selected metrics only
wis_metrics <- get_metrics(
  nowcasts,
  select = c("wis", "overprediction", "underprediction", "dispersion")
)
scores <- scoringutils::score(nowcasts, wis_metrics)

# Summarise the scores by model
scores_per_model <- summarise_scores(scores, by = c("model"))

print(scores_per_model)
plot_wis(scores_per_model) +
  ggtitle("WIS with and without DoW")

## ----dow_estimates2, warning=FALSE, message=FALSE, fig.width=8, fig.height=4----

# For every nowcast date, plot the DoW effects that served as priors next to
# the effects estimated by that date's nowcast
for (t in seq_along(test_dates)) {
  now <- test_dates[t]

  # Summarise the DoW effects of the nowcast fitted with the DoW covariate
  nowcast_with_dow2 <- nowcasts_with_dow[[t]]
  dow_effects2 <- extract_dow_effect(nowcast_with_dow2)

  dow_df2 <- data.frame(
    Day = factor(weekdays_order, levels = weekdays_order),
    Mean = dow_effects2$means,
    Lower = dow_effects2$lower,
    Upper = dow_effects2$upper
  )

  # Label both tables, stack them, and fix the level order so that "Prior"
  # comes before "Posterior"
  dow_df$Type <- "Prior"
  dow_df2$Type <- "Posterior"
  combined_dow_df <- rbind(dow_df, dow_df2)
  combined_dow_df$Type <- factor(
    combined_dow_df$Type,
    levels = c("Prior", "Posterior")
  )

  plot_title <- paste(
    "Estimates of DoW effect for nowcast performed on",
    weekdays(now),
    now
  )

  # Use the same dodge for points and error bars so they stay aligned
  # side by side within each weekday
  dodge <- position_dodge(width = 0.4)

  p <- ggplot(
    combined_dow_df,
    aes(x = Day, y = Mean, color = Type, group = Type)
  ) +
    geom_point(position = dodge, size = 3) +
    geom_errorbar(
      aes(ymin = Lower, ymax = Upper),
      width = 0.2,
      position = dodge
    ) +
    scale_color_manual(values = c("Prior" = "blue", "Posterior" = "red")) +
    ylim(0.5, 4) +
    labs(
      title = plot_title,
      y = "DoW Effect on Expected Cases\n(Compared to Sunday)\n",
      x = ""
    ) +
    theme_minimal() +
    theme(legend.position = "top")

  print(p)
}