## ----include = FALSE----------------------------------------------------------
# Global knitr chunk options: fold each chunk's source and output into one
# block, and prefix every line of printed output with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

## ----setup--------------------------------------------------------------------
library(pcpr)

## ----helper-------------------------------------------------------------------
library(magrittr) # for the pipe %>%
library(ggplot2) # for plotting

# Draw a matrix as a heatmap.
#
# The matrix is transposed and reshaped to long format, so the plot's x axis
# indexes the columns of the input and the y axis indexes its rows. The y axis
# is then reversed, which puts the first row of the input at the top of the
# plot. All axis text, ticks and titles are removed.
#
# Arguments:
#   D      matrix to plot (anything `t()` accepts).
#   ...    extra arguments forwarded to `scale_fill_viridis_c()`.
#   lp     legend position, handed to `theme(legend.position = )`;
#          the default "none" hides the legend.
#   title  optional plot title, handed to `ggtitle()`.
#
# Returns a ggplot object. NA cells are filled white.
plot_matrix <- function(D, ..., lp = "none", title = NULL) {
  D <- t(D)
  # sprintf() over seq_len() yields no labels for a zero extent, whereas
  # `1:0` would count down and produce the bogus labels "C1", "C0".
  if (is.null(colnames(D))) {
    colnames(D) <- sprintf("C%d", seq_len(ncol(D)))
  }
  # check.names = FALSE keeps the column names exactly as they are in `D`.
  # The default would rewrite non-syntactic names (e.g. "1" -> "X1"), and
  # all_of(colnames(D)) below would then fail to find them.
  data.frame(D, check.names = FALSE) %>%
    dplyr::mutate(x = sprintf("R%d", seq_len(nrow(.)))) %>%
    tidyr::pivot_longer(
      tidyselect::all_of(colnames(D)),
      names_to = "y",
      values_to = "value"
    ) %>%
    # Fix the factor levels to order of appearance so the axes follow the
    # matrix order instead of being sorted alphabetically.
    dplyr::mutate(
      x = factor(x, levels = unique(x)),
      y = factor(y, levels = unique(y))
    ) %>%
    ggplot(aes(x = x, y = y, fill = value)) +
    geom_raster() +
    scale_y_discrete(limits = rev) +
    coord_equal() +
    scale_fill_viridis_c(na.value = "white", ...) +
    theme_minimal() +
    theme(
      axis.text.x = element_blank(),
      axis.ticks.x = element_blank(),
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      legend.position = lp,
      plot.margin = margin(0, 0, 0, 0),
      aspect.ratio = 1
    ) +
    ggtitle(title)
}

## ----sim data-----------------------------------------------------------------
# Simulate a dataset with sim_data() (default arguments) and unpack the
# elements used below. The `_0` copies of L and S serve as the reference
# matrices in the performance metrics at the end of the script.
# NOTE(review): presumably L is the low-rank component, S the sparse component
# and Z the noise that together make up D -- confirm against the sim_data()
# documentation.
data <- sim_data()
D <- data$D
L_0 <- data$L
S_0 <- data$S
Z_0 <- data$Z

## ----mat viz------------------------------------------------------------------
# Draw one heatmap per simulated matrix with the plot_matrix() helper.
# Each call sits at top level so the returned ggplot object is auto-printed.
plot_matrix(D)
plot_matrix(L_0)
plot_matrix(S_0)
plot_matrix(Z_0)

## ----matrix rank--------------------------------------------------------------
# Print the rank of the reference L matrix next to the rank of the full
# matrix D for comparison.
matrix_rank(L_0)
matrix_rank(D)

## ----lod----------------------------------------------------------------------
# Simulate a limit of detection (LOD) on D with q = 0.1, keep the resulting
# matrix and the LOD values, and print the latter.
# NOTE(review): presumably q is the quantile of each column used as that
# column's LOD, making `lod` one value per column -- confirm against the
# sim_lod() documentation.
lod_info <- sim_lod(D, q = 0.1)
D_lod <- lod_info$D_tilde
lod <- lod_info$lod
lod

## ----corrupt mat randomly-----------------------------------------------------
# Corrupt the LOD-processed matrix with sim_na() (perc = 0.05) and keep the
# corrupted matrix.
# NOTE(review): presumably sim_na() sets that fraction of entries to NA --
# confirm against its documentation.
corrupted_data <- sim_na(D_lod, perc = 0.05)
D_tilde <- corrupted_data$D_tilde
# Expand lod / sqrt(2) into a matrix shaped like D_tilde. byrow = TRUE lays the
# vector out along each row, so (assuming `lod` has one entry per column --
# TODO confirm) column j holds lod[j] / sqrt(2) in every row.
lod_root2 <- matrix(
  lod / sqrt(2),
  nrow = nrow(D_tilde),
  ncol = ncol(D_tilde), byrow = TRUE
)
# Replace every entry flagged in the sim_lod() mask with the matching
# lod / sqrt(2) value, then plot the result.
# NOTE(review): the assignment does not look at the current value of D_tilde,
# so an NA that sim_na() placed at a flagged position is overwritten too --
# confirm that is intended.
lod_idxs <- which(lod_info$tilde_mask == 1)
D_tilde[lod_idxs] <- lod_root2[lod_idxs]
plot_matrix(D_tilde)

## ----sing---------------------------------------------------------------------
# Pass D_tilde and its column means (computed with NAs removed) to
# impute_matrix(), then plot the output of sing() as connected points.
# NOTE(review): presumably impute_matrix() fills only the NA entries of each
# column and sing() returns the singular values of its input -- confirm.
D_imputed <- impute_matrix(D_tilde, apply(D_tilde, 2, mean, na.rm = TRUE))
singular_values <- sing(D_imputed)
plot(singular_values, type = "b")

## ----gs, eval = FALSE, echo = TRUE--------------------------------------------
# eta_0 <- get_pcp_defaults(D_tilde)$eta
# etas <- data.frame("eta" = sort(c(0.1 * eta_0, eta_0 * seq(1, 10, 2))))
# # to get progress bar, could wrap this
# # in a call to progressr::with_progress({ gs <- grid_search_cv(...) })
# gs <- grid_search_cv(
#   D_tilde,
#   pcp_fn = rrmc,
#   grid = etas, r = 5, LOD = lod,
#   parallel_strategy = "multisession",
#   num_workers = 16,
#   verbose = FALSE
# )

## ----real gs, eval = TRUE, echo = FALSE---------------------------------------
# Load a saved grid-search result rather than running the (commented-out)
# grid_search_cv() call. The path is relative, so it is resolved against the
# current working directory.
gs <- readRDS("rds_files/quickstart-gs.rds")

## ----gs results---------------------------------------------------------------
# Take r and eta from the first row of the grid-search summary table (eta
# rounded to 3 decimals), then print the whole table.
# NOTE(review): this assumes summary_stats is ordered best-first -- confirm
# against the grid_search_cv() documentation.
r_star <- gs$summary_stats$r[1]
eta_star <- round(gs$summary_stats$eta[1], 3)
gs$summary_stats

## ----rrmc---------------------------------------------------------------------
# Fit rrmc() on the corrupted matrix using the r and eta selected above and
# the simulated LOD values.
pcp_model <- rrmc(D_tilde, r = r_star, eta = eta_star, LOD = lod)

## ----obj----------------------------------------------------------------------
# Line plot of the model's `objective` element.
# NOTE(review): presumably one objective value per iteration -- confirm.
plot(pcp_model$objective, type = "l")

## ----output L-----------------------------------------------------------------
# Heatmap of the model's L matrix, followed by its rank.
plot_matrix(pcp_model$L)
matrix_rank(pcp_model$L)

## ----sparse-------------------------------------------------------------------
# Histogram of the entries of the model's S matrix.
hist(pcp_model$S)
# Apply hard_threshold() with thresh = 0.4 and store the result back into the
# model, so every later use of pcp_model$S sees the thresholded matrix.
# NOTE(review): presumably this zeroes entries whose magnitude is below the
# threshold, with 0.4 read off the histogram above -- confirm.
pcp_model$S <- hard_threshold(pcp_model$S, thresh = 0.4)
plot_matrix(pcp_model$S)
sparsity(pcp_model$S)

## ----pca----------------------------------------------------------------------
# Baseline for the comparison below: proj_rank_r() applied to the mean-imputed
# matrix with the same r that was used for the rrmc() fit.
# NOTE(review): presumably a rank-r projection of its input (PCA-style) --
# confirm against the proj_rank_r() documentation.
L_pca <- proj_rank_r(D_imputed, r = r_star)

## ----performance metrics------------------------------------------------------
# One-row summary of recovery quality. Each *_rel_err column is a Frobenius
# norm of (reference - estimate) divided by the Frobenius norm of the
# reference; the last two columns report the rank of the fitted L and the
# sparsity() of the fitted S.
local({
  # The reference norms are shared denominators, so compute each one once.
  # local() keeps these temporaries out of the global environment.
  L_0_norm <- norm(L_0, "F")
  S_0_norm <- norm(S_0, "F")
  data.frame(
    Obs_rel_err = norm(L_0 - D_imputed, "F") / L_0_norm,
    PCA_L_rel_err = norm(L_0 - L_pca, "F") / L_0_norm,
    PCP_L_rel_err = norm(L_0 - pcp_model$L, "F") / L_0_norm,
    PCP_S_rel_err = norm(S_0 - pcp_model$S, "F") / S_0_norm,
    PCP_L_rank = matrix_rank(pcp_model$L),
    PCP_S_sparsity = sparsity(pcp_model$S)
  )
})