---
title: "Multinomial Logistic Regression with Heavy-Tailed Priors"
author: "Longhai Li and Steven Liu"
date: "`r Sys.Date()`"
bibliography: HTLR.bib
output:
  rmarkdown::html_document:
    toc: true
    toc_float: false
    toc_depth: 4
    number_sections: true
vignette: >
  %\VignetteIndexEntry{intro}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

# Data Generation

Load the necessary libraries:

```{r setup}
library(HTLR)
library(bayesplot)
```

The data-generating scheme is described in @li2018fully.

There are 4 groups of features:

* feature #1: marginally related to the response

* feature #2: marginally unrelated to the response, but correlated with feature #1

* features #3 - #10: marginally related to the response and correlated with each other

* features #11 - #2000: noise features with no relationship to the response y

```{r}
SEED <- 1234

n <- 510
p <- 2000

## mean of each feature (row) in each of the 3 classes (column) for the first 10 features;
## feature 1 differs across classes, feature 2 does not, features 3-10 differ for class 3
means <- rbind(
  c(0, 1, 0),
  c(0, 0, 0),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1),
  c(0, 0, 1)
) * 2

## the remaining p - 10 noise features have zero mean in every class
means <- rbind(means, matrix(0, p - 10, 3))

## factor loading matrix: the first 10 features load on the first 3 latent factors,
## which induces the correlations described above (feature 2 with feature 1,
## features 3-10 with each other)
A <- diag(1, p)

A[1:10, 1:3] <-
  rbind(
    c(1, 0, 0),
    c(2, 1, 0),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1),
    c(0, 0, 1)
  )

set.seed(SEED)
dat <- gendata_FAM(n, means, A, sd_g = 0.5, stdx = TRUE)
str(dat)
```
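As a quick sanity check on the generating scheme (an extra step, not part of the original setup), the class-conditional means of a signal feature and a noise feature can be compared; `X` and `y` are the components shown by `str(dat)` above:

```{r}
# Feature 1 should have a different mean in one class; a noise feature
# (e.g. feature 1000) should have roughly the same mean in all classes.
tapply(dat$X[, 1], dat$y, mean)
tapply(dat$X[, 1000], dat$y, mean)
```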

Look at the correlation between features:
```{r}
# require(corrplot)
cor(dat$X[ , 1:11]) %>% corrplot::corrplot(tl.pos = "n")
```

Split the data into training and testing sets:
```{r}
set.seed(SEED)
dat <- split_data(dat$X, dat$y, n.train = 500)
str(dat)
```
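A quick check (added here) that all three classes are represented in both the training and test sets, assuming `y.tr` and `y.te` are plain label vectors:

```{r}
table(dat$y.tr)
table(dat$y.te)
```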

# Model Fitting

Fit an HTLR model with all default settings:
```{r}
set.seed(SEED)
system.time(
  fit.t <- htlr(dat$x.tr, dat$y.tr)
)
print(fit.t)
```

Fit a second model with a customized prior, more iterations, and the warm-up history kept for later inspection:
```{r}
set.seed(SEED)
system.time(
  fit.t2 <- htlr(X = dat$x.tr, y = dat$y.tr, 
                 prior = htlr_prior("t", df = 1, logw = -20, sigmab0 = 1500), 
                 iter = 4000, init = "bcbc", keep.warmup.hist = TRUE)
)
print(fit.t2)
```
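The prior configuration passed above can also be inspected on its own (a minimal sketch; the exact contents of the returned object depend on the installed HTLR version):

```{r}
str(htlr_prior("t", df = 1, logw = -20, sigmab0 = 1500))
```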

# Model Inspection 

Look at point summaries of the posteriors of selected parameters:
```{r}
summary(fit.t2, features = c(1:10, 100, 200, 1000, 2000), method = median)
```
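These point summaries can also be used for an informal variable screening, for example by ranking features by the magnitude of their posterior median coefficient for class 2. The sketch below is an ad hoc check added here, not an HTLR feature-selection routine:

```{r}
# Rank features by |posterior median coefficient| for class 2,
# excluding the intercept; the signal features should come out on top.
med.coef <- as.matrix(fit.t2, k = 2) %>% apply(2, median)
head(sort(abs(med.coef[names(med.coef) != "Intercept"]), decreasing = TRUE), 10)
```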

Plot interval estimates from posterior draws using [bayesplot](https://mc-stan.org/bayesplot/index.html):
```{r}
post.t <- as.matrix(fit.t2, k = 2)
## intercept, signal coefficients (V1-V3), and one noise coefficient (V1000)
mcmc_intervals(post.t, pars = c("Intercept", "V1", "V2", "V3", "V1000"))
```
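The same draws can also be displayed as density areas; the sketch below uses bayesplot's `mcmc_areas()` with a 90% inner interval, but any bayesplot function that accepts a draws matrix works the same way:

```{r}
mcmc_areas(post.t, pars = c("Intercept", "V1", "V2", "V3", "V1000"), prob = 0.9)
```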

Trace plot of MCMC draws:
```{r}
as.matrix(fit.t2, k = 2, include.warmup = TRUE) %>%
  mcmc_trace(c("V1", "V1000"), facet_args = list("nrow" = 2), n_warmup = 2000)
```

The coefficients of the unrelated (noise) features are not updated in every iteration because of the restricted Gibbs sampling scheme [@li2018fully], which greatly reduces the computational cost.
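A simple way to see the resulting shrinkage in the draws extracted above is to compare the posterior quantiles of a signal coefficient with those of a noise coefficient (a small check added here, not part of the original analysis):

```{r}
# The noise coefficient V1000 should be tightly concentrated around zero,
# unlike the signal coefficient V1.
apply(post.t[, c("V1", "V1000")], 2, quantile, probs = c(0.05, 0.5, 0.95))
```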

# Make Predictions

A glance at the prediction accuracy:
```{r}
y.class <- predict(fit.t, dat$x.te, type = "class")
y.class
print(paste0("prediction accuracy of model 1 = ", 
             sum(y.class == dat$y.te) / length(y.class)))

y.class2 <- predict(fit.t2, dat$x.te, type = "class")
print(paste0("prediction accuracy of model 2 = ", 
             sum(y.class2 == dat$y.te) / length(y.class2)))

```
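Beyond raw accuracy, a confusion table shows where the predictions disagree with the truth; `as.vector()` is used below in case the class predictions are returned as a one-column matrix:

```{r}
table(predicted = as.vector(y.class2), actual = as.vector(dat$y.te))
```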

More details about the prediction results:
```{r}
predict(fit.t, dat$x.te, type = "response") %>%
  evaluate_pred(y.true = dat$y.te)
```
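
The same evaluation can be run for the second model for comparison; the call is identical, just with `fit.t2`:

```{r}
predict(fit.t2, dat$x.te, type = "response") %>%
  evaluate_pred(y.true = dat$y.te)
```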