Same/Other/All resampler

The goal of this vignette is to explain the older resamplers, ResamplingVariableSizeTrainCV and ResamplingSameOtherCV, which output data that are useful for visualizing the train/test splits. If you do not need to visualize the train/test splits, then it is recommended to instead use the newer resampler, ResamplingSameOtherSizesCV (see the other vignette).

These resamplers quantify the extent to which it is possible to train on one data subset, and predict on another. This kind of problem occurs frequently in many different problem domains.

The ideas are similar to my previous blog posts about how to do this in Python and R. Below we explain how to use mlr3resampling for this purpose, in simulated regression and classification problems. To use this method on real data, the important sections to read below are named “Benchmark: computing test error,” which show how to create these cross-validation experiments using mlr3 code.
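For real data, the key step beyond a standard mlr3 setup is declaring which column defines the subsets, via the subset column role (available when mlr3resampling is loaded). The sketch below mirrors the task setup used later in this vignette; the table and column names (my.dt, group) are hypothetical, not part of the package API,

library(data.table)
## Hypothetical data: "group" says which subset each row belongs to.
my.dt <- data.table(x=rnorm(20), y=rnorm(20), group=rep(1:2, each=10))
my.task <- mlr3::TaskRegr$new("my_data", my.dt, target="y")
my.task$col_roles$subset <- "group"  # train/predict across this column.
my.task$col_roles$stratum <- "group" # keep folds balanced within groups.
my.task$col_roles$feature <- "x"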

Simulated regression problems

We begin by generating some data which can be used with regression algorithms. Assume there is a data set with some rows from one person, and some rows from another,

N <- 300
library(data.table)
set.seed(1)
abs.x <- 2
reg.dt <- data.table(
  x=runif(N, -abs.x, abs.x),
  person=rep(1:2, each=0.5*N))
reg.pattern.list <- list(
  easy=function(x, person)x^2,
  impossible=function(x, person)(x^2+person*3)*(-1)^person)
reg.task.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  yname <- paste0("y_",task_id)
  reg.dt[, (yname) := f(x,person)+rnorm(N)][]
  task.dt <- reg.dt[, c("x","person",yname), with=FALSE]
  reg.task <- mlr3::TaskRegr$new(
    task_id, task.dt, target=yname)
  reg.task$col_roles$subset <- "person"  # column that defines the subsets.
  reg.task$col_roles$stratum <- "person" # stratify folds by person.
  reg.task$col_roles$feature <- "x"
  reg.task.list[[task_id]] <- reg.task
}
reg.dt
#>               x person      y_easy y_impossible
#>           <num>  <int>       <num>        <num>
#>   1: -0.9379653      1  1.32996609    -2.918082
#>   2: -0.5115044      1  0.24307692    -3.866062
#>   3:  0.2914135      1 -0.23314657    -3.837799
#>   4:  1.6328312      1  1.73677545    -7.221749
#>   5: -1.1932723      1 -0.06356159    -5.877792
#>  ---                                           
#> 296:  0.7257701      2 -2.48130642     5.180948
#> 297: -1.6033236      2  1.20453459     9.604312
#> 298: -1.5243898      2  1.89966190     7.511988
#> 299: -1.7982414      2  3.47047566    11.035397
#> 300:  1.7170157      2  0.60541972    10.719685

The table above shows some simulated data for two regression problems: the easy pattern (y_easy) is a function of x alone, so it is the same for both people, whereas the impossible pattern (y_impossible) is different for each person.

Static visualization of simulated data

First we reshape the data using the code below,

(reg.tall <- nc::capture_melt_single(
  reg.dt,
  task_id="easy|impossible",
  value.name="y"))
#>               x person    task_id           y
#>           <num>  <int>     <char>       <num>
#>   1: -0.9379653      1       easy  1.32996609
#>   2: -0.5115044      1       easy  0.24307692
#>   3:  0.2914135      1       easy -0.23314657
#>   4:  1.6328312      1       easy  1.73677545
#>   5: -1.1932723      1       easy -0.06356159
#>  ---                                         
#> 596:  0.7257701      2 impossible  5.18094849
#> 597: -1.6033236      2 impossible  9.60431191
#> 598: -1.5243898      2 impossible  7.51198770
#> 599: -1.7982414      2 impossible 11.03539747
#> 600:  1.7170157      2 impossible 10.71968480

The table above is in a more convenient form for the visualization, which we create using the code below,

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.tall)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      space="free",
      scales="free")+
    scale_y_continuous(
      breaks=seq(-100, 100, by=2))
}

Figure: scatterplot of the simulated data, with one panel for each combination of task_id and person.

In the simulated data above, we can see that the easy pattern is the same for both people, whereas the impossible pattern is different for each person.
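As a quick numerical sanity check (an extra step, not part of the original analysis), we can compare the range of the labels between people within each task; the ranges should be similar between people for the easy task, and clearly different for the impossible task,

reg.tall[, .(min.y=min(y), max.y=max(y)), by=.(task_id, person)]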

Benchmark: computing test error

In the code below, we define the Same/Other cross-validation method, which does K-fold cross-validation for each combination of train subsets (same/other/all).

(reg_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#>  $ folds: int 3
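The printed parameters show that the resampler defaults to 3 cross-validation folds. The number of folds can be changed via the standard mlr3 parameter interface, before the resampling is instantiated; we keep the default in this vignette,

reg_same_other$param_set$values$folds <- 3 # default; increase for more splits.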

In the code below, we define two learners to compare,

(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#> 
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types:  [response], se
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, selected_features

In the code below, we define the benchmark grid, which is all combinations of tasks (easy and impossible), learners (rpart and featureless), and the one resampling method.

(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_same_other))
#>          task          learner    resampling
#>        <char>           <char>        <char>
#> 1:       easy       regr.rpart same_other_cv
#> 2:       easy regr.featureless same_other_cv
#> 3: impossible       regr.rpart same_other_cv
#> 4: impossible regr.featureless same_other_cv

In the code below, we execute the benchmark experiment (in parallel using the multisession future plan).

if(FALSE){#for CRAN.
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#>  nr    task_id       learner_id resampling_id iters warnings errors
#>   1       easy       regr.rpart same_other_cv    18        0      0
#>   2       easy regr.featureless same_other_cv    18        0      0
#>   3 impossible       regr.rpart same_other_cv    18        0      0
#>   4 impossible regr.featureless same_other_cv    18        0      0

The code below computes the test error for each split,

reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#>    train.subsets test.fold test.subset person iteration                  test
#>           <char>     <int>       <int>  <int>     <int>                <list>
#> 1:           all         1           1      1         1  1, 3, 5, 6,12,13,...
#>                    train                                uhash    nr
#>                   <list>                               <char> <int>
#> 1:  4, 7, 9,10,18,20,... 7eb31050-513c-4ba8-8551-1616e5820596     1
#>               task task_id                       learner learner_id
#>             <list>  <char>                        <list>     <char>
#> 1: <TaskRegr:easy>    easy <LearnerRegrRpart:regr.rpart> regr.rpart
#>                 resampling resampling_id       prediction regr.mse algorithm
#>                     <list>        <char>           <list>    <num>    <char>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionRegr> 1.638015     rpart
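Each row of reg.bench.score above corresponds to one train/test split. Before plotting, a quick numerical summary can be computed with data.table, for example the mean test error over splits for each combination of task, algorithm, and train subsets,

reg.bench.score[, .(
  mean.mse=mean(regr.mse), splits=.N
), by=.(task_id, algorithm, train.subsets)]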

The code below visualizes the resulting test error numbers.

if(require(animint2)){
  ggplot()+
    scale_x_log10()+
    geom_point(aes(
      regr.mse, train.subsets, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      task_id ~ person,
      labeller=label_both,
      scales="free")
}

Figure: test error (regr.mse) for each split, with one panel for each combination of task_id and person.

It is clear from the plot above that rpart achieves low test error on the impossible task only when the same person is present in the train set, whereas on the easy task rpart beats featureless no matter which subsets are used for training.

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

inst <- reg.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
## Grid of x values on which to show each learned prediction function.
grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  ## Combine the task data with the resampling instance meta-data.
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[1] <- "y"
  ## One row per (set.name, row_id) in this train/test split.
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  ## Rows in neither train nor test are marked as unused.
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  i.learner <- reg.bench.row$learner[[1]]
  ## Predictions of the learned function over the grid of x values.
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#>          task_id iteration   algorithm     x        y
#>           <char>     <int>      <char> <num>    <num>
#>    1:       easy         1       rpart -2.00 3.557968
#>    2:       easy         1       rpart -1.96 3.557968
#>    3:       easy         1       rpart -1.92 3.557968
#>    4:       easy         1       rpart -1.88 3.557968
#>    5:       easy         1       rpart -1.84 3.557968
#>   ---                                                
#> 7268: impossible        18 featureless  1.84 7.204232
#> 7269: impossible        18 featureless  1.88 7.204232
#> 7270: impossible        18 featureless  1.92 7.204232
#> 7271: impossible        18 featureless  1.96 7.204232
#> 7272: impossible        18 featureless  2.00 7.204232
(point.dt <- rbindlist(point.dt.list))
#>           task_id iteration set.name row_id           y          x  fold person
#>            <char>     <int>   <char>  <int>       <num>      <num> <int>  <int>
#>     1:       easy         1     test      1  1.32996609 -0.9379653     1      1
#>     2:       easy         1    train      2  0.24307692 -0.5115044     3      1
#>     3:       easy         1     test      3 -0.23314657  0.2914135     1      1
#>     4:       easy         1    train      4  1.73677545  1.6328312     2      1
#>     5:       easy         1     test      5 -0.06356159 -1.1932723     1      1
#>    ---                                                                         
#> 21596: impossible        18    train    296  5.18094849  0.7257701     1      2
#> 21597: impossible        18    train    297  9.60431191 -1.6033236     1      2
#> 21598: impossible        18     test    298  7.51198770 -1.5243898     3      2
#> 21599: impossible        18    train    299 11.03539747 -1.7982414     1      2
#> 21600: impossible        18     test    300 10.71968480  1.7170157     3      2
#>        subset display_row
#>         <int>       <int>
#>     1:      1           1
#>     2:      1         101
#>     3:      1           2
#>     4:      1          51
#>     5:      1           3
#>    ---                   
#> 21596:      2         198
#> 21597:      2         199
#> 21598:      2         299
#> 21599:      2         200
#> 21600:      2         300
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  ## Add a column whose name makes facet labels read "person/subset".
  DT[, "person/subset" := person]
}
make_person_subset(point.dt)
make_person_subset(reg.bench.score)

if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y, color=algorithm, subset=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_point(aes(
        train.subsets, regr.mse, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=reg.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        scales="free"),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}
Interactive figure: animint data visualization; clicking a split in the test-error or diagram plot updates the displayed predictions.
, and elements. var add_legend = function(p_name, p_info){ // case of multiple legends, d3 reads legend structure in as an array var tdRight = element.select("td."+p_name+"_legend"); var legendkeys = d3.keys(p_info.legend); for(var i=0; i-1){ // aesthetics that would draw a rect legend_svgs.append("rect") .attr("x", 2) .attr("y", 2) .attr("width", 10) .attr("height", 10) .style("stroke-width", function(d){return d["polygonsize"]||1;}) .style("stroke-dasharray", function(d){ return linetypesize2dasharray(d["polygonlinetype"], d["size"]||2); }) .style("stroke", function(d){return d["polygoncolour"] || "#000000";}) .style("fill", function(d){return d["polygonfill"] || "#FFFFFF";}) .style("opacity", function(d){return d["polygonalpha"]||1;}); } if(l_info.geoms.indexOf("text")>-1){ // aesthetics that would draw a rect legend_svgs.append("text") .attr("x", 10) .attr("y", 14) .style("fill", function(d){return d["textcolour"]||1;}) .style("text-anchor", "middle") .attr("font-size", function(d){return d["textsize"]||1;}) .text("a"); } if(l_info.geoms.indexOf("path")>-1){ // aesthetics that would draw a line legend_svgs.append("line") .attr("x1", 1).attr("x2", 19).attr("y1", 7).attr("y2", 7) .style("stroke-width", function(d){ return linescale(d["pathsize"])||2; }) .style("stroke-dasharray", function(d){ return linetypesize2dasharray(d["pathlinetype"], d["pathsize"] || 2); }) .style("stroke", function(d){return d["pathcolour"] || "#000000";}) .style("opacity", function(d){return d["pathalpha"]||1;}); } if(l_info.geoms.indexOf("point")>-1){ // aesthetics that would draw a point legend_svgs.append("circle") .attr("cx", 10) .attr("cy", 7) .attr("r", function(d){return pointscale(d["pointsize"])||4;}) .style("stroke", function(d){return d["pointcolour"] || "#000000";}) .style("fill", function(d){ return d["pointfill"] || d["pointcolour"] || "#000000"; }) .style("opacity", function(d){return d["pointalpha"]||1;}); } legend_rows.append("td") .attr("align", "left") // TODO: right for numbers? .attr("class", "legend_entry_label") .attr("id", function(d){ return d["id"]+"_label"; }) .style("font-size", function(d){ return d["text_size"]}) .text(function(d){ return d["label"];}); } } // Download the main description of the interactive plot. d3.json(json_file, function (error, response) { if(response.hasOwnProperty("title")){ // This selects the title of the web page, outside of wherever // the animint is defined, usually a
-- so it is OK to use // global d3.select here. d3.select("title").text(response.title); } // Add plots. for (var p_name in response.plots) { add_plot(p_name, response.plots[p_name]); add_legend(p_name, response.plots[p_name]); // Append style sheet to document head. css.appendChild(document.createTextNode(styles.join(" "))); document.head.appendChild(css); } // Then add selectors and start downloading the first data subset. for (var s_name in response.selectors) { add_selector(s_name, response.selectors[s_name]); } // Update the scales/axes of the plots if needed // We do this so that the plots zoom in initially after loading for (var p_name in response.plots) { if(response.plots[p_name].axis_domains !== null){ for(var xy in response.plots[p_name].axis_domains){ var selectors = response.plots[p_name].axis_domains[xy].selectors; if(!isArray(selectors)){ selectors = [selectors]; } update_scales(p_name, xy, selectors[0], response.selectors[selectors[0]].selected); } } } //////////////////////////////////////////// // Widgets at bottom of page //////////////////////////////////////////// element.append("br"); if(response.hasOwnProperty("source")){ element.append("a") .attr("id","a_source_href") .attr("href", response.source) .text("source"); } // loading table. var show_hide_table = element.append("button") .text("Show download status table"); show_hide_table .on("click", function(){ if(this.textContent == "Show download status table"){ loading.style("display", ""); show_hide_table.text("Hide download status table"); }else{ loading.style("display", "none"); show_hide_table.text("Show download status table"); } }); var loading = element.append("table") .style("display", "none"); Widgets["loading"] = loading; var tr = loading.append("tr"); tr.append("th").text("geom"); tr.append("th").attr("class", "chunk").text("selected chunk"); tr.append("th").attr("class", "downloaded").text("downloaded"); tr.append("th").attr("class", "total").text("total"); tr.append("th").attr("class", "status").text("status"); // Add geoms and construct nest operators. for (var g_name in response.geoms) { add_geom(g_name, response.geoms[g_name]); } // Animation control widgets. 
var show_message = "Show animation controls"; // add a button to view the animation widgets var show_hide_animation_controls = element.append("button") .text(show_message) .attr("id", viz_id + "_show_hide_animation_controls") .on("click", function(){ if(this.textContent == show_message){ time_table.style("display", ""); show_hide_animation_controls.text("Hide animation controls"); }else{ time_table.style("display", "none"); show_hide_animation_controls.text(show_message); } }) ; // table of the animint widgets var time_table = element.append("table") .style("display", "none"); var first_tr = time_table.append("tr"); var first_th = first_tr.append("th"); // if there's a time variable, add a button to pause the animint if(response.time){ Animation.next = {}; Animation.ms = response.time.ms; Animation.variable = response.time.variable; Animation.sequence = response.time.sequence; Widgets["play_pause"] = first_th.append("button") .text("Play") .attr("id", "play_pause") .on("click", function(){ if(this.textContent == "Play"){ Animation.play(); }else{ Animation.pause(false); } }) ; } first_tr.append("th").text("milliseconds"); if(response.time){ var second_tr = time_table.append("tr"); second_tr.append("td").text("updates"); second_tr.append("td").append("input") .attr("id", "updates_ms") .attr("type", "text") .attr("value", Animation.ms) .on("change", function(){ Animation.pause(false); Animation.ms = this.value; Animation.play(); }) ; } for(s_name in Selectors){ var s_info = Selectors[s_name]; if(!s_info.hasOwnProperty("duration")){ s_info.duration = 0; } } var selector_array = d3.keys(Selectors); var duration_rows = time_table.selectAll("tr.duration") .data(selector_array) .enter() .append("tr"); duration_rows .append("td") .text(function(s_name){return s_name;}); var duration_tds = duration_rows.append("td"); var duration_inputs = duration_tds .append("input") .attr("id", function(s_name){ return viz_id + "_duration_ms_" + s_name; }) .attr("type", "text") .on("change", function(s_name){ Selectors[s_name].duration = this.value; }) .attr("value", function(s_name){ return Selectors[s_name].duration; }); // selector widgets var toggle_message = "Show selection menus"; var show_or_hide_fun = function(){ if(this.textContent == toggle_message){ selector_table.style("display", ""); show_hide_selector_widgets.text("Hide selection menus"); d3.select(".urltable").style("display","") }else{ selector_table.style("display", "none"); show_hide_selector_widgets.text(toggle_message); d3.select(".urltable").style("display","none") } } var show_hide_selector_widgets = element.append("button") .text(toggle_message) .attr("class", "show_hide_selector_widgets") .on("click", show_or_hide_fun) ; // adding a table for selector widgets var selector_table = element.append("table") .style("display", "none") .attr("class", "table_selector_widgets") ; var selector_first_tr = selector_table.append("tr"); selector_first_tr .append("th") .text("Variable") ; selector_first_tr .append("th") .text("Selected value(s)") ; // looping through and adding a row for each selector for(s_name in Selectors) { var s_info = Selectors[s_name]; // for .variable .value selectors, levels is undefined and we do // not want to make a selectize widget. // TODO: why does it take so long to initialize the selectize // widget when there are many (>1000) values? 
if(isArray(s_info.levels)){ // If there were no geoms that specified clickSelects for this // selector, then there is no way to select it other than the // selectize widgets (and possibly legends). So in this case // we show the selectize widgets by default. var selector_widgets_hidden = show_hide_selector_widgets.text() == toggle_message; var has_no_clickSelects = !Selectors[s_name].hasOwnProperty("clickSelects") var has_no_legend = !Selectors[s_name].hasOwnProperty("legend") if(selector_widgets_hidden && has_no_clickSelects && has_no_legend){ var node = show_hide_selector_widgets.node(); show_or_hide_fun.apply(node); } // removing "." from name so it can be used in ids var s_name_id = legend_class_name(s_name); // adding a row for each selector var selector_widget_row = selector_table .append("tr") .attr("class", function() { return s_name_id + "_selector_widget"; }) ; selector_widget_row.append("td").text(s_name); // adding the selector var selector_widget_select = selector_widget_row .append("td") .append("select") .attr("class", function() { return s_name_id + "_input"; }) .attr("placeholder", function() { return "Toggle " + s_name; }); // adding an option for each level of the variable selector_widget_select.selectAll("option") .data(s_info.levels) .enter() .append("option") .attr("value", function(d) { return d; }) .text(function(d) { return d; }); // making sure that the first option is blank selector_widget_select .insert("option") .attr("value", "") .text(function() { return "Toggle " + s_name; }); // calling selectize var selectize_selector = to_select + ' .' + s_name_id + "_input"; if(s_info.type == "single") { // setting up array of selector and options var selector_values = []; for(i in s_info.levels) { selector_values[i] = { id: s_name.concat("___", s_info.levels[i]), text: s_info.levels[i] }; } // the id of the first selector var selected_id = s_name.concat("___", s_info.selected); // if single selection, only allow one item var $temp = $(selectize_selector) .selectize({ create: false, valueField: 'id', labelField: 'text', searchField: ['text'], options: selector_values, items: [selected_id], maxItems: 1, allowEmptyOption: true, onChange: function(value) { // extracting the name and the level to update var selector_name = value.split("___")[0]; var selected_level = value.split("___")[1]; // updating the selector update_selector(selector_name, selected_level); } }) ; } else { // multiple selection: // setting up array of selector and options var selector_values = []; if(typeof s_info.levels == "object") { for(i in s_info.levels) { selector_values[i] = { id: s_name.concat("___", s_info.levels[i]), text: s_info.levels[i] }; } } else { selector_values[0] = { id: s_name.concat("___", s_info.levels), text: s_info.levels }; } // setting up an array to contain the initally selected elements var initial_selections = []; for(i in s_info.selected) { initial_selections[i] = s_name.concat("___", s_info.selected[i]); } // construct the selectize var $temp = $(selectize_selector) .selectize({ create: false, valueField: 'id', labelField: 'text', searchField: ['text'], options: selector_values, items: initial_selections, maxItems: s_info.levels.length, allowEmptyOption: true, onChange: function(value) { // if nothing is selected, remove what is currently selected if(value == null) { // extracting the selector ids from the options var the_ids = Object.keys($(this)[0].options); // the name of the appropriate selector var selector_name = the_ids[0].split("___")[0]; // the previously selected 
elements var old_selections = Selectors[selector_name].selected; // updating the selector for each of the old selections old_selections.forEach(function(element) { update_selector(selector_name, element); }); } else { // value is not null: // grabbing the name of the selector from the selected value var selector_name = value[0].split("___")[0]; // identifying the levels that should be selected var specified_levels = []; for(i in value) { specified_levels[i] = value[i].split("___")[1]; } // the previously selected entries old_selections = Selectors[selector_name].selected; // the levels that need to have selections turned on specified_levels .filter(function(n) { return old_selections.indexOf(n) == -1; }) .forEach(function(element) { update_selector(selector_name, element); }) ; // the levels that need to be turned off // - same approach old_selections .filter(function(n) { return specified_levels.indexOf(n) == -1; }) .forEach(function(element) { update_selector(selector_name, element); }) ; }//value==null }//onChange })//selectize ; }//single or multiple selection. selectized_array[s_name] = $temp[0].selectize; }//levels, is.variable.value } // close for loop through selector widgets // If this is an animation, then start downloading all the rest of // the data, and start the animation. if (response.time) { var i, prev, cur; for (var i = 0; i < Animation.sequence.length; i++) { if (i == 0) { prev = Animation.sequence[Animation.sequence.length-1]; } else { prev = Animation.sequence[i - 1]; } cur = Animation.sequence[i]; Animation.next[prev] = cur; } Animation.timer = null; Animation.play = function(){ if(Animation.timer == null){ // only play if not already playing. // as shown on http://bl.ocks.org/mbostock/3808234 Animation.timer = setInterval(update_next_animation, Animation.ms); Widgets["play_pause"].text("Pause"); } }; Animation.play_after_visible = false; Animation.pause = function(play_after_visible){ Animation.play_after_visible = play_after_visible; clearInterval(Animation.timer); Animation.timer = null; Widgets["play_pause"].text("Play"); }; var s_info = Selectors[Animation.variable]; Animation.done_geoms = {}; s_info.update.forEach(function(g_name){ var g_info = Geoms[g_name]; if(g_info.chunk_order.length == 1 && g_info.chunk_order[0] == Animation.variable){ g_info.seq_i = Animation.sequence.indexOf(s_info.selected); g_info.seq_count = 0; Animation.done_geoms[g_name] = 0; download_next(g_name); } }); Animation.play(); all_geom_names = d3.keys(response.geoms); // This code starts/stops the animation timer when the page is // hidden, inspired by // http://stackoverflow.com/questions/1060008 function onchange (evt) { if(document.visibilityState == "visible"){ if(Animation.play_after_visible){ Animation.play(); } }else{ if(Widgets["play_pause"].text() == "Pause"){ Animation.pause(true); } } }; document.addEventListener("visibilitychange", onchange); } // update_selector_url() var check_func=function(){ var status_array = $('.status').map(function(){ return $.trim($(this).text()); }).get(); status_array=status_array.slice(1) return status_array.every(function(elem){ return elem === "displayed"}); } if(window.location.hash) { var fragment=window.location.hash; fragment=fragment.slice(1); fragment=decodeURI(fragment) var frag_array=fragment.split(/(.*?})/); frag_array=frag_array.filter(function(x){ return x!=""}) frag_array.forEach(function(selector_string){ var selector_hash=selector_string.split("="); var selector_nam=selector_hash[0]; var selector_values=selector_hash[1]; var re = 
/\{(.*?)\}/; selector_values = re.exec(selector_values)[1]; var array_values = selector_values.split(','); if(Selectors.hasOwnProperty(selector_nam)){ var s_info = Selectors[selector_nam] if(s_info.type=="single"){//TODO fix array_values.forEach(function(element) { wait_until_then(100, check_func, update_selector,selector_nam,element) if(response.time)Animation.pause(true) }); }else{ var old_selections = Selectors[selector_nam].selected; // the levels that need to have selections turned on array_values .filter(function(n) { return old_selections.indexOf(n) == -1; }) .forEach(function(element) { wait_until_then(100, check_func, update_selector,selector_nam,element) if(response.time){ Animation.pause(true) } }); old_selections .filter(function(n) { return array_values.indexOf(n) == -1; }) .forEach(function(element) { wait_until_then(100, check_func, update_selector,selector_nam,element) if(response.time){ Animation.pause(true) } }); }//if(single) else multiple selection }//if(Selectors.hasOwnProperty(selector_nam)) })//frag_array.forEach }//if(window.location.hash) }); };

if(FALSE){
  animint2pages(viz, "2023-12-13-train-predict-subsets-regression")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-regression/

Simulated classification problems

The previous section investigated a simulated regression problem, whereas in this section we simulate a binary classification problem. Assume there is a data set with some rows from one person, some rows from another,

N <- 200
library(data.table)
(full.dt <- data.table(
  label=factor(rep(c("spam","not spam"), l=N)),
  person=rep(1:2, each=0.5*N)
)[, signal := ifelse(label=="not spam", 0, 3)][])
#>         label person signal
#>        <fctr>  <int>  <num>
#>   1:     spam      1      3
#>   2: not spam      1      0
#>   3:     spam      1      3
#>   4: not spam      1      0
#>   5:     spam      1      3
#>  ---                       
#> 196: not spam      2      0
#> 197:     spam      2      3
#> 198: not spam      2      0
#> 199:     spam      2      3
#> 200: not spam      2      0

Above, each row has a person ID, either 1 or 2. We can imagine a spam filtering system that has training data for multiple people (here just two). Each row in the table above represents a message which has been labeled as spam or not, by one of the two people. Can we train on one person, and accurately predict on the other? To do that we will need some features, which we generate/simulate below:

set.seed(1)
n.people <- length(unique(full.dt$person))
for(person.i in 1:n.people){
  use.signal.vec <- list(
    easy=rep(if(person.i==1)TRUE else FALSE, N),
    impossible=full.dt$person==person.i)
  for(task_id in names(use.signal.vec)){
    use.signal <- use.signal.vec[[task_id]]
    full.dt[
    , paste0("x",person.i,"_",task_id) := ifelse(
      use.signal, signal, 0
    )+rnorm(N)][]
  }
}
full.dt
#>         label person signal    x1_easy x1_impossible    x2_easy x2_impossible
#>        <fctr>  <int>  <num>      <num>         <num>      <num>         <num>
#>   1:     spam      1      3  2.3735462     3.4094018  1.0744410    -0.3410670
#>   2: not spam      1      0  0.1836433     1.6888733  1.8956548     1.5024245
#>   3:     spam      1      3  2.1643714     4.5865884 -0.6029973     0.5283077
#>   4: not spam      1      0  1.5952808    -0.3309078 -0.3908678     0.5421914
#>   5:     spam      1      3  3.3295078     0.7147645 -0.4162220    -0.1366734
#>  ---                                                                         
#> 196: not spam      2      0 -1.0479844    -0.9243128  0.7682782    -1.0293917
#> 197:     spam      2      3  4.4411577     1.5929138 -0.8161606     2.9890743
#> 198: not spam      2      0 -1.0158475     0.0450106 -0.4361069    -1.2249912
#> 199:     spam      2      3  3.4119747    -0.7151284  0.9047050     0.4038886
#> 200: not spam      2      0 -0.3810761     0.8652231 -0.7630863     1.1691226

In the table above, there are two sets of two features:

  • For easy features, one is correlated with the label (x1_easy), and one is random noise (x2_easy), so the algorithm just needs to learn to ignore the noise feature, and concentrate on the signal feature. That should be possible given data from either person (same signal in each person).
  • Each impossible feature is correlated with the label when the feature number is the same as the person number, and is just noise when the numbers differ. So if the algorithm has access to train data from the correct person (the same person as in the test set, say person 2), then it needs to learn to use the corresponding feature x2_impossible. But if the algorithm does not have access to that person, then the best it can do is the same as featureless (predict the most frequent class label in the train data). The code below checks these correlations numerically.
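
As a quick numerical check of the claims above (a sketch, not part of the original analysis), the code below computes the within-person correlation between each feature and a numeric version of the label. The easy features should show the same correlation pattern for both people, whereas each impossible feature should only be correlated with the label for the matching person.

full.dt[, {
  label.num <- as.numeric(label=="spam")
  ## correlation of each feature column with the label, by person.
  lapply(.SD, function(x)round(cor(x, label.num), 2))
}, by=person, .SDcols=patterns("^x")]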

Static visualization of simulated data

Below we reshape the data to a table which is more suitable for visualization:

(scatter.dt <- nc::capture_melt_multiple(
  full.dt,
  column="x[12]",
  "_",
  task_id="easy|impossible"))
#>         label person signal    task_id         x1         x2
#>        <fctr>  <int>  <num>     <char>      <num>      <num>
#>   1:     spam      1      3       easy  2.3735462  1.0744410
#>   2: not spam      1      0       easy  0.1836433  1.8956548
#>   3:     spam      1      3       easy  2.1643714 -0.6029973
#>   4: not spam      1      0       easy  1.5952808 -0.3908678
#>   5:     spam      1      3       easy  3.3295078 -0.4162220
#>  ---                                                        
#> 396: not spam      2      0 impossible -0.9243128 -1.0293917
#> 397:     spam      2      3 impossible  1.5929138  2.9890743
#> 398: not spam      2      0 impossible  0.0450106 -1.2249912
#> 399:     spam      2      3 impossible -0.7151284  0.4038886
#> 400: not spam      2      0 impossible  0.8652231  1.1691226

Below we visualize the pattern for each person and feature type:

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x1, x2, color=label),
      shape=1,
      data=scatter.dt)+
    facet_grid(
      task_id ~ person,
      labeller=label_both)
}

plot of chunk unnamed-chunk-13

In the plot above, it is apparent that

  • for easy features (top), the two label classes differ in x1 values for both people. So it should be possible/easy to train on person 1, and predict accurately on person 2.
  • for impossible features (bottom), the two people have different label patterns. For person 1, the two label classes differ in x1 values, whereas for person 2, the two label classes differ in x2 values. So it should be impossible to train on person 1, and predict accurately on person 2.

Benchmark: computing test error

We use the code below to create a list of classification tasks, for use in the mlr3 framework.

class.task.list <- list()
for(task_id in c("easy","impossible")){
  feature.names <- grep(task_id, names(full.dt), value=TRUE)
  task.col.names <- c(feature.names, "label", "person")
  task.dt <- full.dt[, task.col.names, with=FALSE]
  this.task <- mlr3::TaskClassif$new(
    task_id, task.dt, target="label")
  this.task$col_roles$subset <- "person"
  this.task$col_roles$stratum <- c("person","label")
  this.task$col_roles$feature <- setdiff(names(task.dt), this.task$col_roles$stratum)
  class.task.list[[task_id]] <- this.task
}
class.task.list
#> $easy
#> <TaskClassif:easy> (200 x 3)
#> * Target: label
#> * Properties: twoclass, strata
#> * Features (2):
#>   - dbl (2): x1_easy, x2_easy
#> * Strata: person, label
#> 
#> $impossible
#> <TaskClassif:impossible> (200 x 3)
#> * Target: label
#> * Properties: twoclass, strata
#> * Features (2):
#>   - dbl (2): x1_impossible, x2_impossible
#> * Strata: person, label

Note in the code above that person is assigned roles subset and stratum, whereas label is assigned roles target and stratum. When adapting the code above to real data, the important part is the mlr3::TaskClassif line which tells mlr3 what data set to use, and what columns should be used for target/subset/stratum.
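
For example, here is a minimal sketch of how to adapt this to your own data (my.dt, outcome, group, f1, and f2 are hypothetical placeholder names, not from this vignette):

library(data.table)
## hypothetical data: outcome is the label, group is the subset ID.
my.dt <- data.table(
  outcome=factor(rep(c("yes","no"), 50)),
  group=rep(1:2, each=50),
  f1=rnorm(100),
  f2=rnorm(100))
my.task <- mlr3::TaskClassif$new("my_data", my.dt, target="outcome")
my.task$col_roles$subset <- "group"                # defines same/other subsets
my.task$col_roles$stratum <- c("group", "outcome") # stratified CV folds
my.task$col_roles$feature <- c("f1", "f2")         # inputs for learning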

The code below creates the Same versus Other cross-validation resampling object, which defines a K-fold cross-validation experiment (default of K=3 folds),

(class_same_other <- mlr3resampling::ResamplingSameOtherCV$new())
#> <ResamplingSameOtherCV> : Same versus Other Cross-Validation
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 1
#>  $ folds: int 3

The code below is used to define the learning algorithms to test,

(class.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerClassifRpart$new(),
  mlr3::LearnerClassifFeatureless$new()))
#> [[1]]
#> <LearnerClassifRpart:classif.rpart>: Classification Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, multiclass, selected_features,
#>   twoclass, weights
#> 
#> [[2]]
#> <LearnerClassifFeatureless:classif.featureless>: Featureless Classification Learner
#> * Model: -
#> * Parameters: method=mode
#> * Packages: mlr3
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, multiclass,
#>   selected_features, twoclass

The code below defines the grid of tasks, learners, and resamplings.

(class.bench.grid <- mlr3::benchmark_grid(
  class.task.list,
  class.learner.list,
  class_same_other))
#>          task             learner    resampling
#>        <char>              <char>        <char>
#> 1:       easy       classif.rpart same_other_cv
#> 2:       easy classif.featureless same_other_cv
#> 3: impossible       classif.rpart same_other_cv
#> 4: impossible classif.featureless same_other_cv

The code below runs the benchmark experiment grid. Note that each iteration can be parallelized by declaring a future plan.

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
  class.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 72 rows with 4 resampling runs
#>  nr    task_id          learner_id resampling_id iters warnings errors
#>   1       easy       classif.rpart same_other_cv    18        0      0
#>   2       easy classif.featureless same_other_cv    18        0      0
#>   3 impossible       classif.rpart same_other_cv    18        0      0
#>   4 impossible classif.featureless same_other_cv    18        0      0

Below we compute scores (test error) for each resampling iteration, and show the first row of the result.

class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
#>    train.subsets test.fold test.subset person iteration                  test
#>           <char>     <int>       <int>  <int>     <int>                <list>
#> 1:           all         1           1      1         1  1, 2, 8,11,12,18,...
#>                    train                                uhash    nr
#>                   <list>                               <char> <int>
#> 1:  3, 4, 5, 6, 9,10,... c89df2a9-3bc9-477d-b981-dbfcafb0504d     1
#>                  task task_id                             learner    learner_id
#>                <list>  <char>                              <list>        <char>
#> 1: <TaskClassif:easy>    easy <LearnerClassifRpart:classif.rpart> classif.rpart
#>                 resampling resampling_id          prediction classif.ce
#>                     <list>        <char>              <list>      <num>
#> 1: <ResamplingSameOtherCV> same_other_cv <PredictionClassif> 0.08823529
#>    algorithm
#>       <char>
#> 1:     rpart

Finally we plot the test error values below.

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      classif.ce, train.subsets, color=algorithm),
      shape=1,
      data=class.bench.score)+
    facet_grid(
      person ~ task_id,
      labeller=label_both,
      scales="free")
}

plot of chunk unnamed-chunk-20

It is clear from the plot above that

  • for the easy task, training on the same subset is just as good as training on all or other subsets.
  • for the impossible task, we must train on the same subset for minimal test error; training on all is almost as good, because the pattern in person 1 is orthogonal to the pattern in person 2; training on other is just as bad as featureless, because the patterns differ between the two people.
  • in a real data task, training on other will most likely not be quite as bad as in the impossible task above, but also not as good as in the easy task (the numeric summary below makes these differences concrete).
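
To complement the plot, the code below (a sketch, not in the original text) computes the mean test error over all splits, for each combination of task, train subsets, and algorithm:

class.bench.score[, .(
  mean.test.error=mean(classif.ce),
  n.splits=.N
), by=.(task_id, train.subsets, algorithm)]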

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

inst <- class.bench.score$resampling[[1]]$instance
rect.expand <- 0.2
grid.value.dt <- scatter.dt[
, lapply(.SD, function(x)do.call(seq, c(as.list(range(x)), l=21)))
, .SDcols=c("x1","x2")]
grid.class.dt <- data.table(
  label=full.dt$label[1],
  do.call(
    CJ, grid.value.dt
  )
)
class.pred.dt.list <- list()
class.point.dt.list <- list()
for(score.i in 1:nrow(class.bench.score)){
  class.bench.row <- class.bench.score[score.i]
  task.dt <- data.table(
    class.bench.row$task[[1]]$data(),
    class.bench.row$resampling[[1]]$instance$id.dt)
  names(task.dt)[2:3] <- c("x1","x2")
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=class.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ][]
  class.point.dt.list[[score.i]] <- data.table(
    class.bench.row[, .(task_id, iteration)],
    i.points)
  if(class.bench.row$algorithm!="featureless"){
    i.learner <- class.bench.row$learner[[1]]
    i.learner$predict_type <- "prob"
    i.task <- class.bench.row$task[[1]]
    setnames(grid.class.dt, names(i.task$data()))
    grid.class.task <- mlr3::TaskClassif$new(
      "grid", grid.class.dt, target="label")
    pred.grid <- as.data.table(
      i.learner$predict(grid.class.task)
    )[, data.table(grid.class.dt, prob.spam)]
    names(pred.grid)[2:3] <- c("x1","x2")
    pred.wide <- dcast(pred.grid, x1 ~ x2, value.var="prob.spam")
    prob.mat <- as.matrix(pred.wide[,-1])
    contour.list <- contourLines(
      grid.value.dt$x1, grid.value.dt$x2, prob.mat, levels=0.5)
    class.pred.dt.list[[score.i]] <- data.table(
      class.bench.row[, .(
        task_id, iteration, algorithm
      )],
      data.table(contour.i=seq_along(contour.list))[, {
        do.call(data.table, contour.list[[contour.i]])[, .(level, x1=x, x2=y)]
      }, by=contour.i]
    )
  }
}
(class.pred.dt <- rbindlist(class.pred.dt.list))
#>         task_id iteration algorithm contour.i level       x1        x2
#>          <char>     <int>    <char>     <int> <num>    <num>     <num>
#>   1:       easy         1     rpart         1   0.5 1.856156 -3.008049
#>   2:       easy         1     rpart         1   0.5 1.856156 -2.606579
#>   3:       easy         1     rpart         1   0.5 1.856156 -2.205109
#>   4:       easy         1     rpart         1   0.5 1.856156 -1.803639
#>   5:       easy         1     rpart         1   0.5 1.856156 -1.402169
#>  ---                                                                  
#> 766: impossible        18     rpart         1   0.5 3.743510  1.225096
#> 767: impossible        18     rpart         1   0.5 4.158037  1.225096
#> 768: impossible        18     rpart         1   0.5 4.572564  1.225096
#> 769: impossible        18     rpart         1   0.5 4.987091  1.225096
#> 770: impossible        18     rpart         1   0.5 5.401618  1.225096
(class.point.dt <- rbindlist(class.point.dt.list))
#>           task_id iteration set.name row_id    label         x1         x2
#>            <char>     <int>   <char>  <int>   <fctr>      <num>      <num>
#>     1:       easy         1     test      1     spam  2.3735462  1.0744410
#>     2:       easy         1     test      2 not spam  0.1836433  1.8956548
#>     3:       easy         1    train      3     spam  2.1643714 -0.6029973
#>     4:       easy         1    train      4 not spam  1.5952808 -0.3908678
#>     5:       easy         1    train      5     spam  3.3295078 -0.4162220
#>    ---                                                                    
#> 14396: impossible        18    train    196 not spam -0.9243128 -1.0293917
#> 14397: impossible        18    train    197     spam  1.5929138  2.9890743
#> 14398: impossible        18    train    198 not spam  0.0450106 -1.2249912
#> 14399: impossible        18    train    199     spam -0.7151284  0.4038886
#> 14400: impossible        18    train    200 not spam  0.8652231  1.1691226
#>         fold person subset display_row
#>        <int>  <int>  <int>       <int>
#>     1:     1      1      1           1
#>     2:     1      1      1           2
#>     3:     2      1      1          35
#>     4:     2      1      1          36
#>     5:     2      1      1          37
#>    ---                                
#> 14396:     2      2      2         166
#> 14397:     2      2      2         167
#> 14398:     1      2      2         133
#> 14399:     1      2      2         134
#> 14400:     2      2      2         168

set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
make_person_subset <- function(DT){
  DT[, "person/subset" := person]
}
make_person_subset(class.point.dt)
make_person_subset(class.bench.score)
if(require(animint2)){
  viz <- animint(
    title="Train/predict on subsets, classification",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      scale_color_manual(values=c(spam="black","not spam"="white"))+
      geom_point(aes(
        x1, x2, color=label, fill=set.name),
        showSelected="iteration",
        size=3,
        stroke=2,
        shape=21,
        data=class.point.dt)+
      geom_path(aes(
        x1, x2, 
        subset=paste(algorithm, iteration, contour.i)),
        showSelected=c("iteration","algorithm"),
        color=algo.colors[["rpart"]],
        data=class.pred.dt)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both,
        space="free",
        scales="free")+
      scale_y_continuous(
        breaks=seq(-100, 100, by=2)),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      theme(panel.margin=grid::unit(1, "lines"))+
      scale_y_continuous(
        "Classification error on test set",
        breaks=seq(0, 1, by=0.25))+
      scale_fill_manual(values=algo.colors)+
      scale_x_discrete(
        "People/subsets in train set")+
      geom_hline(aes(
        yintercept=yint),
        data=data.table(yint=0.5),
        color="grey50")+
      geom_point(aes(
        train.subsets, classif.ce, fill=algorithm),
        shape=1,
        size=5,
        stroke=2,
        color="black",
        color_off=NA,
        clickSelects="iteration",
        data=class.bench.score)+
      facet_grid(
        task_id ~ `person/subset`,
        labeller=label_both),
    diagram=ggplot()+
      ggtitle("Select train/test split")+
      theme_bw()+
      theme_animint(height=300)+
      facet_grid(
        . ~ train.subsets,
        scales="free",
        space="free")+
      scale_size_manual(values=c(subset=3, fold=1))+
      scale_color_manual(values=c(subset="orange", fold="grey50"))+
      geom_rect(aes(
        xmin=-Inf, xmax=Inf,
        color=rows,
        size=rows,
        ymin=display_row, ymax=display_end),
        fill=NA,
        data=inst$viz.rect.dt)+
      scale_fill_manual(values=set.colors)+
      geom_rect(aes(
        xmin=iteration-rect.expand, ymin=display_row,
        xmax=iteration+rect.expand, ymax=display_end,
        fill=set.name),
        clickSelects="iteration",
        data=inst$viz.set.dt)+
      geom_text(aes(
        ifelse(rows=="subset", Inf, -Inf),
        (display_row+display_end)/2,
        hjust=ifelse(rows=="subset", 1, 0),
        label=paste0(rows, "=", ifelse(rows=="subset", subset, fold))),
        data=data.table(train.name="same", inst$viz.rect.dt))+
      scale_x_continuous(
        "Split number / cross-validation iteration")+
      scale_y_continuous(
        "Row number"),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingSameOtherCV.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-13-train-predict-subsets-classification")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-13-train-predict-subsets-classification/

Conclusion

In this section we have shown how to use mlr3resampling for comparing test error of models trained on same/all/other subsets.

Variable size train resampler

The goal of this section is to explain how to use ResamplingVariableSizeTrainCV, which can be used to determine how many train data are necessary to provide accurate predictions on a given test set.

Simulated regression problems

The code below creates data for simulated regression problems. First we define a vector of input values,

N <- 300
abs.x <- 10
set.seed(1)
x.vec <- runif(N, -abs.x, abs.x)
str(x.vec)
#>  num [1:300] -4.69 -2.56 1.46 8.16 -5.97 ...

Below we define a list of two true regression functions (tasks in mlr3 terminology) for our simulated data,

reg.pattern.list <- list(
  sin=sin,
  constant=function(x)0)

The constant function represents a regression problem which can be solved by always predicting the mean value of outputs (featureless is the best possible learning algorithm). The sin function will be used to generate data with a non-linear pattern that will need to be learned. Below we use a for loop over these two functions/tasks, to simulate the data which will be used as input to the learning algorithms:

library(data.table)
reg.task.list <- list()
reg.data.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  task.dt <- data.table(
    x=x.vec,
    y = f(x.vec)+rnorm(N,sd=0.5))
  reg.data.list[[task_id]] <- data.table(task_id, task.dt)
  reg.task.list[[task_id]] <- mlr3::TaskRegr$new(
    task_id, task.dt, target="y"
  )
}
(reg.data <- rbindlist(reg.data.list))
#>       task_id         x          y
#>        <char>     <num>      <num>
#>   1:      sin -4.689827  1.2248390
#>   2:      sin -2.557522 -0.5607042
#>   3:      sin  1.457067  0.8345056
#>   4:      sin  8.164156  0.4875994
#>   5:      sin -5.966361 -0.4321800
#>  ---                              
#> 596: constant  3.628850 -0.6728968
#> 597: constant -8.016618  0.5168327
#> 598: constant -7.621949 -0.4058882
#> 599: constant -8.991207  0.9008627
#> 600: constant  8.585078  0.8857710

In the table above, the input is x, and the output is y. Below we visualize these data, with one task in each facet/panel:

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.data)+
    facet_grid(task_id ~ ., labeller=label_both)
}

plot of chunk unnamed-chunk-24

In the plot above we can see two different simulated data sets (constant and sin). Note that the code above used the animint2 package, which provides interactive extensions to the static graphics of the ggplot2 package (see the interactive data viz section below).

Visualizing instance table

In the code below, we define a K-fold cross-validation experiment, with K=3 folds.

reg_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
reg_size_cv$param_set$values$train_sizes <- 6
reg_size_cv
#> <ResamplingVariableSizeTrainCV> : Cross-Validation with variable size train sets
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 4
#>  $ folds         : int 3
#>  $ min_train_data: int 10
#>  $ random_seeds  : int 3
#>  $ train_sizes   : int 6

In the output above we can see the parameters of the resampling object, all of which should be integer scalars (a sketch of modifying them appears after the list below):

  • folds is the number of cross-validation folds.
  • min_train_data is the minimum number of train data to consider.
  • random_seeds is the number of random seeds, each of which determines a different random ordering of the train data. The random ordering determines which data are included in small train set sizes.
  • train_sizes is the number of train set sizes, evenly spaced on a log scale, from min_train_data to the max number of train data (determined by folds).
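
All of these parameters can be modified via param_set before instantiation; for example (the values below are illustrative, not the defaults):

other_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
other_size_cv$param_set$values$folds <- 5           # 5-fold CV
other_size_cv$param_set$values$min_train_data <- 20 # smallest train set size
other_size_cv$param_set$values$random_seeds <- 2    # 2 random orderings
other_size_cv$param_set$values$train_sizes <- 4     # 4 sizes on a log scale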

Below we instantiate the resampling on one of the tasks:

reg_size_cv$instantiate(reg.task.list[["sin"]])
reg_size_cv$instance
#> $iteration.dt
#>     test.fold  seed small_stratum_size train_size_i train_size
#>         <int> <int>              <int>        <int>      <int>
#>  1:         1     1                 10            1         10
#>  2:         1     1                 18            2         18
#>  3:         1     1                 33            3         33
#>  4:         1     1                 60            4         60
#>  5:         1     1                110            5        110
#>  6:         1     1                200            6        200
#>  7:         1     2                 10            1         10
#>  8:         1     2                 18            2         18
#>  9:         1     2                 33            3         33
#> 10:         1     2                 60            4         60
#> 11:         1     2                110            5        110
#> 12:         1     2                200            6        200
#> 13:         1     3                 10            1         10
#> 14:         1     3                 18            2         18
#> 15:         1     3                 33            3         33
#> 16:         1     3                 60            4         60
#> 17:         1     3                110            5        110
#> 18:         1     3                200            6        200
#> 19:         2     1                 10            1         10
#> 20:         2     1                 18            2         18
#> 21:         2     1                 33            3         33
#> 22:         2     1                 60            4         60
#> 23:         2     1                110            5        110
#> 24:         2     1                200            6        200
#> 25:         2     2                 10            1         10
#> 26:         2     2                 18            2         18
#> 27:         2     2                 33            3         33
#> 28:         2     2                 60            4         60
#> 29:         2     2                110            5        110
#> 30:         2     2                200            6        200
#> 31:         2     3                 10            1         10
#> 32:         2     3                 18            2         18
#> 33:         2     3                 33            3         33
#> 34:         2     3                 60            4         60
#> 35:         2     3                110            5        110
#> 36:         2     3                200            6        200
#> 37:         3     1                 10            1         10
#> 38:         3     1                 18            2         18
#> 39:         3     1                 33            3         33
#> 40:         3     1                 60            4         60
#> 41:         3     1                110            5        110
#> 42:         3     1                200            6        200
#> 43:         3     2                 10            1         10
#> 44:         3     2                 18            2         18
#> 45:         3     2                 33            3         33
#> 46:         3     2                 60            4         60
#> 47:         3     2                110            5        110
#> 48:         3     2                200            6        200
#> 49:         3     3                 10            1         10
#> 50:         3     3                 18            2         18
#> 51:         3     3                 33            3         33
#> 52:         3     3                 60            4         60
#> 53:         3     3                110            5        110
#> 54:         3     3                200            6        200
#>     test.fold  seed small_stratum_size train_size_i train_size
#>                           train                  test iteration train_min_size
#>                          <list>                <list>     <int>          <int>
#>  1: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         1             10
#>  2: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         2             18
#>  3: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         3             33
#>  4: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         4             60
#>  5: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         5            110
#>  6: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         6            200
#>  7: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         7             10
#>  8: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         8             18
#>  9: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         9             33
#> 10: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        10             60
#> 11: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        11            110
#> 12: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        12            200
#> 13:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        13             10
#> 14:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        14             18
#> 15:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        15             33
#> 16:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        16             60
#> 17:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        17            110
#> 18:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        18            200
#> 19: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        19             10
#> 20: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        20             18
#> 21: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        21             33
#> 22: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        22             60
#> 23: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        23            110
#> 24: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        24            200
#> 25: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        25             10
#> 26: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        26             18
#> 27: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        27             33
#> 28: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        28             60
#> 29: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        29            110
#> 30: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        30            200
#> 31:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        31             10
#> 32:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        32             18
#> 33:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        33             33
#> 34:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        34             60
#> 35:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        35            110
#> 36:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        36            200
#> 37: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        37             10
#> 38: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        38             18
#> 39: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        39             33
#> 40: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        40             60
#> 41: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        41            110
#> 42: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        42            200
#> 43: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        43             10
#> 44: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        44             18
#> 45: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        45             33
#> 46: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        46             60
#> 47: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        47            110
#> 48: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        48            200
#> 49:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        49             10
#> 50:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        50             18
#> 51:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        51             33
#> 52:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        52             60
#> 53:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        53            110
#> 54:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        54            200
#>                           train                  test iteration train_min_size
#> 
#> $id.dt
#>      row_id  fold
#>       <int> <int>
#>   1:      1     1
#>   2:      2     3
#>   3:      3     3
#>   4:      4     2
#>   5:      5     3
#>  ---             
#> 296:    296     2
#> 297:    297     1
#> 298:    298     1
#> 299:    299     3
#> 300:    300     2

Above we see the instance, which the user does not normally need to examine; for informational purposes, it contains the following data (a small query example follows the list):

  • iteration.dt has one row for each train/test split,
  • id.dt has one row for each data point.
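
For example, the query below (a sketch) shows the splits in iteration.dt for one test fold and one random seed:

reg_size_cv$instance$iteration.dt[
  test.fold==1 & seed==1,
  .(test.fold, seed, train_size, iteration)]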

Benchmark: computing test error

In the code below, we define two learners to compare,

(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#> 
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types:  [response], se
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, selected_features

The code above defines

  • regr.rpart: Regression Tree learning algorithm, which should be able to learn the non-linear pattern in the sin data (if there are enough data in the train set); a quick training illustration follows this list.
  • regr.featureless: Featureless Regression learning algorithm, which should be optimal for the constant data, and can be used as a baseline in the sin data. When the rpart learner gets smaller prediction error rates than featureless, then we know that it has learned some non-trivial relationship between inputs and outputs.
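
As a quick illustration (a sketch, not part of the original benchmark), the code below trains the rpart learner on the full sin task and prints the fitted tree:

if(requireNamespace("rpart")){
  rpart.learner <- mlr3::LearnerRegrRpart$new()
  rpart.learner$train(reg.task.list[["sin"]])
  print(rpart.learner$model)
}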

In the code below, we define the benchmark grid, which is all combinations of tasks (constant and sin), learners (rpart and featureless), and the one resampling method.

(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_size_cv))
#>        task          learner             resampling
#>      <char>           <char>                 <char>
#> 1:      sin       regr.rpart variable_size_train_cv
#> 2:      sin regr.featureless variable_size_train_cv
#> 3: constant       regr.rpart variable_size_train_cv
#> 4: constant regr.featureless variable_size_train_cv

In the code below, we execute the benchmark experiment (optionally in parallel using the multisession future plan).

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 216 rows with 4 resampling runs
#>  nr  task_id       learner_id          resampling_id iters warnings errors
#>   1      sin       regr.rpart variable_size_train_cv    54        0      0
#>   2      sin regr.featureless variable_size_train_cv    54        0      0
#>   3 constant       regr.rpart variable_size_train_cv    54        0      0
#>   4 constant regr.featureless variable_size_train_cv    54        0      0

The code below computes the test error for each split, and visualizes the information stored in the first row of the result:

reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#>    test.fold  seed small_stratum_size train_size_i train_size
#>        <int> <int>              <int>        <int>      <int>
#> 1:         1     1                 10            1         10
#>                          train                  test iteration train_min_size
#>                         <list>                <list>     <int>          <int>
#> 1: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         1             10
#>                                   uhash    nr           task task_id
#>                                  <char> <int>         <list>  <char>
#> 1: e3dabfcd-49c7-428b-bb5d-ec6df8f744fa     1 <TaskRegr:sin>     sin
#>                          learner learner_id                      resampling
#>                           <list>     <char>                          <list>
#> 1: <LearnerRegrRpart:regr.rpart> regr.rpart <ResamplingVariableSizeTrainCV>
#>             resampling_id       prediction  regr.mse algorithm
#>                    <char>           <list>     <num>    <char>
#> 1: variable_size_train_cv <PredictionRegr> 0.8008255     rpart

The output above contains all of the results related to a particular train/test split. For our purposes, the most interesting columns are:

  • test.fold is the cross-validation fold ID.
  • seed is the random seed used to determine the train set order.
  • train_size is the number of observations in the train set.
  • train and test are vectors of row numbers assigned to each set.
  • iteration is an ID for the train/test split, for a particular learning algorithm and task. It is the row number of iteration.dt (see instance above), which has one row for each unique combination of test.fold, seed, and train_size.
  • learner is the mlr3 learner object, which can be used to compute predictions on new data (including a grid of inputs, to show predictions in the visualization below); a minimal sketch follows this list.
  • regr.mse is the mean squared error on the test set.
  • algorithm is the name of the learning algorithm (same as learner_id but without the regr. prefix).
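For example, the hypothetical sketch below (an addition to the original analysis; the object names one.learner and new.dt are ours) uses the stored model from the first split to predict on a few new inputs, via the standard mlr3 predict_newdata method:

one.learner <- reg.bench.score$learner[[1]]  ## trained model from the first split
new.dt <- data.table(x=c(-1, 0, 1))          ## new feature values
one.learner$predict_newdata(new.dt)          ## PredictionRegr on the new data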

The code below visualizes the resulting test error values.

train_size_vec <- unique(reg.bench.score$train_size)
if(require(animint2)){
  ggplot()+
    scale_x_log10(
      breaks=train_size_vec)+
    scale_y_log10()+
    geom_line(aes(
      train_size, regr.mse,
      group=paste(algorithm, seed),
      color=algorithm),
      data=reg.bench.score)+
    geom_point(aes(
      train_size, regr.mse, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      test.fold~task_id,
      labeller=label_both,
      scales="free")
}

(Figure: test error versus train set size, one line per algorithm and random seed, one panel per test fold and task.)

Above we plot the test error for each fold and train set size. There is a different panel for each task and test fold. Each line represents a random seed (ordering of data in train set), and each dot represents a specific train set size. So the plot above shows that some variation in test error, for a given test fold, is due to the random ordering of the train data.

Below we summarize each train set size, by taking the mean and standard deviation over the random seeds.

reg.mean.dt <- dcast(
  reg.bench.score,
  task_id + train_size + test.fold + algorithm ~ .,
  list(mean, sd),
  value.var="regr.mse")
if(require(animint2)){
  ggplot()+
    scale_x_log10(
      breaks=train_size_vec)+
    scale_y_log10()+
    geom_ribbon(aes(
      train_size,
      ymin=regr.mse_mean-regr.mse_sd,
      ymax=regr.mse_mean+regr.mse_sd,
      fill=algorithm),
      alpha=0.5,
      data=reg.mean.dt)+
    geom_line(aes(
      train_size, regr.mse_mean, color=algorithm),
      data=reg.mean.dt)+
    facet_grid(
      test.fold~task_id,
      labeller=label_both,
      scales="free")
}

(Figure: mean ± SD of test error versus train set size, one panel per test fold and task.)

The plot above shows a line for the mean and a ribbon for the standard deviation, over the three random seeds. It is clear that

  • in the constant task, featureless always has prediction error rates smaller than or equal to rpart's, which indicates that rpart sometimes overfits for large sample sizes.
  • in the sin task, more than 30 samples are required for rpart to be more accurate than featureless, which indicates that rpart has learned a non-trivial relationship between input and output (a programmatic check of this threshold is sketched below).
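The sketch below (an addition to the original analysis; mean.wide is our name) makes the second observation programmatic: it reshapes reg.mean.dt to one column per algorithm, then finds the smallest train set size at which rpart's mean error drops below featureless's, for each task and test fold.

## Sketch: smallest train size where rpart beats featureless, on average.
mean.wide <- dcast(
  reg.mean.dt,
  task_id + test.fold + train_size ~ algorithm,
  value.var="regr.mse_mean")
mean.wide[
  rpart < featureless,
  .(min_train_size=min(train_size)),
  by=.(task_id, test.fold)]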

Interactive data viz

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  ## Combine the task data with the resampling instance id.dt, which has
  ## one row per data point, including its fold assignment.
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  ## Expand each set name to its vector of assigned row ids.
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  ## Join onto the task data; rows in neither set are marked "unused".
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  ## Compute predictions of the trained model over the grid of inputs.
  i.learner <- reg.bench.row$learner[[1]]
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#>         task_id iteration   algorithm     x           y
#>          <char>     <int>      <char> <num>       <num>
#>     1:      sin         1       rpart -10.0  0.25011658
#>     2:      sin         1       rpart  -9.8  0.25011658
#>     3:      sin         1       rpart  -9.6  0.25011658
#>     4:      sin         1       rpart  -9.4  0.25011658
#>     5:      sin         1       rpart  -9.2  0.25011658
#>    ---                                                 
#> 21812: constant        54 featureless   9.2 -0.03385654
#> 21813: constant        54 featureless   9.4 -0.03385654
#> 21814: constant        54 featureless   9.6 -0.03385654
#> 21815: constant        54 featureless   9.8 -0.03385654
#> 21816: constant        54 featureless  10.0 -0.03385654
(point.dt <- rbindlist(point.dt.list))
#>         task_id iteration set.name row_id          y         x  fold
#>          <char>     <int>   <char>  <int>      <num>     <num> <int>
#>     1:      sin         1     test      1  1.2248390 -4.689827     1
#>     2:      sin         1   unused      2 -0.5607042 -2.557522     3
#>     3:      sin         1   unused      3  0.8345056  1.457067     3
#>     4:      sin         1   unused      4  0.4875994  8.164156     2
#>     5:      sin         1   unused      5 -0.4321800 -5.966361     3
#>    ---                                                              
#> 64796: constant        54    train    296 -0.6728968  3.628850     2
#> 64797: constant        54    train    297  0.5168327 -8.016618     1
#> 64798: constant        54    train    298 -0.4058882 -7.621949     1
#> 64799: constant        54     test    299  0.9008627 -8.991207     3
#> 64800: constant        54    train    300  0.8857710  8.585078     2
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
if(require(animint2)){
  viz <- animint(
    title="Variable size train set, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_size_manual(values=c(
        featureless=3,
        rpart=2))+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y,
        color=algorithm,
        size=algorithm,
        group=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ .,
        labeller=label_both),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(width=500)+
      theme(
        panel.margin=grid::unit(1, "lines"),
        legend.position="none")+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_color_manual(values=algo.colors)+
      scale_x_log10(
        "Train set size",
        breaks=train_size_vec)+
      geom_line(aes(
        train_size, regr.mse,
        group=paste(algorithm, seed),
        color=algorithm),
        clickSelects="seed",
        alpha_off=0.2,
        showSelected="algorithm",
        size=4,
        data=reg.bench.score)+
      facet_grid(
        test.fold~task_id,
        labeller=label_both,
        scales="free")+
      geom_point(aes(
        train_size, regr.mse,
        color=algorithm),
        size=5,
        stroke=3,
        fill="black",
        fill_off=NA,
        showSelected=c("algorithm","seed"),
        clickSelects="iteration",
        data=reg.bench.score),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/Simulations.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-26-train-sizes-regression")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it at https://tdhock.github.io/2023-12-26-train-sizes-regression/

The interactive data viz consists of two plots:

  • The first plot shows the data, with each point colored according to the set it was assigned to, in the currently selected split/iteration. The red/blue lines additionally show the learned prediction functions for the currently selected split/iteration.
  • The second plot shows the test error rates, as a function of train set size. Clicking a line selects the corresponding random seed, which makes the corresponding points on that line appear. Clicking a point selects the corresponding iteration (seed, test fold, and train set size).

Simulated classification problems

Whereas the section above focused on regression (the output is a real number), in this section we simulate a binary classification problem (the output is a factor with two levels).

class.N <- 300
class.abs.x <- 1
rclass <- function(){
  runif(class.N, -class.abs.x, class.abs.x)
}
library(data.table)
set.seed(1)
class.x.dt <- data.table(x1=rclass(), x2=rclass())
class.fun.list <- list(
  constant=function(...)0.5,
  xor=function(x1, x2)xor(x1>0, x2>0))
class.data.list <- list()
class.task.list <- list()
for(task_id in names(class.fun.list)){
  class.fun <- class.fun.list[[task_id]]
  y <- factor(ifelse(
    class.x.dt[, class.fun(x1, x2)+rnorm(class.N, sd=0.5)]>0.5,
    "spam", "not"))
  task.dt <- data.table(class.x.dt, y)
  this.task <- mlr3::TaskClassif$new(
    task_id, task.dt, target="y")
  this.task$col_roles$stratum <- "y"
  class.task.list[[task_id]] <- this.task
  class.data.list[[task_id]] <- data.table(task_id, task.dt)
}
(class.data <- rbindlist(class.data.list))
#>       task_id         x1           x2      y
#>        <char>      <num>        <num> <fctr>
#>   1: constant -0.4689827  0.347424466   spam
#>   2: constant -0.2557522 -0.810284289    not
#>   3: constant  0.1457067 -0.014807758   spam
#>   4: constant  0.8164156 -0.076896319    not
#>   5: constant -0.5966361 -0.249566938   spam
#>  ---                                        
#> 596:      xor  0.3628850  0.297101895    not
#> 597:      xor -0.8016618 -0.040328411    not
#> 598:      xor -0.7621949 -0.009871789   spam
#> 599:      xor -0.8991207 -0.240254817    not
#> 600:      xor  0.8585078 -0.099029126   spam

The simulated data table above consists of two input features (x1 and x2) along with an output/label to predict (y). Below we count the number of times each label appears in each task:

class.data[, .(count=.N), by=.(task_id, y)]
#>     task_id      y count
#>      <char> <fctr> <int>
#> 1: constant   spam   143
#> 2: constant    not   157
#> 3:      xor   spam   145
#> 4:      xor    not   155

The table above shows that the spam label is the minority class in both tasks (not is the majority class, so featureless will always predict not).
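Since featureless predicts the majority class, its expected test error rate is roughly the proportion of minority-class labels, which the sketch below computes (an addition, not part of the original benchmark):

## Sketch: error rate implied for a learner that always predicts "not".
class.data[, .(baseline.error=mean(y=="spam")), by=task_id]

Below we visualize the data in the feature space: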

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x1, x2, color=y),
      shape=1,
      data=class.data)+
    facet_grid(. ~ task_id, labeller=label_both)+
    coord_equal()
}

(Figure: x2 versus x1, colored by label y, one panel per task.)

The plot above shows how the output y is related to the two inputs x1 and x2, for the two tasks.

  • For the constant task, the two inputs are not related to the output.
  • For the xor task, the spam label is associated with either x1 or x2 being negative (but not both); a quick numeric check is sketched below.
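The sketch below (an addition to the original analysis) compares the proportion of spam labels inside and outside the xor region, i.e. where exactly one input is negative:

## Sketch: proportion of spam labels, by whether exactly one input is negative.
class.data[task_id=="xor", .(
  prop.spam=mean(y=="spam"), count=.N
), by=.(exactly.one.negative=xor(x1>0, x2>0))]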

In the mlr3 code below, we define a list of learners, our resampling method, and a benchmark grid:

class.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerClassifRpart$new(),
  mlr3::LearnerClassifFeatureless$new())
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
(class.bench.grid <- mlr3::benchmark_grid(
  class.task.list,
  class.learner.list,
  size_cv))
#>        task             learner             resampling
#>      <char>              <char>                 <char>
#> 1: constant       classif.rpart variable_size_train_cv
#> 2: constant classif.featureless variable_size_train_cv
#> 3:      xor       classif.rpart variable_size_train_cv
#> 4:      xor classif.featureless variable_size_train_cv

Below we run the learning algorithm for each of the train/test splits defined by our benchmark grid:

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
  class.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 180 rows with 4 resampling runs
#>  nr  task_id          learner_id          resampling_id iters warnings errors
#>   1 constant       classif.rpart variable_size_train_cv    45        0      0
#>   2 constant classif.featureless variable_size_train_cv    45        0      0
#>   3      xor       classif.rpart variable_size_train_cv    45        0      0
#>   4      xor classif.featureless variable_size_train_cv    45        0      0

Below we compute scores (test error) for each resampling iteration, and show the first row of the result.

class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
#>    test.fold  seed small_stratum_size train_size_i train_size
#>        <int> <int>              <int>        <int>      <int>
#> 1:         1     1                 10            1         21
#>                          train                  test iteration train_min_size
#>                         <list>                <list>     <int>          <int>
#> 1: 132,239, 10,216,245,276,...  5, 6, 8,21,23,28,...         1             21
#>                                   uhash    nr                   task  task_id
#>                                  <char> <int>                 <list>   <char>
#> 1: 6e475cca-b056-41c6-b61e-00299178f921     1 <TaskClassif:constant> constant
#>                                learner    learner_id
#>                                 <list>        <char>
#> 1: <LearnerClassifRpart:classif.rpart> classif.rpart
#>                         resampling          resampling_id          prediction
#>                             <list>                 <char>              <list>
#> 1: <ResamplingVariableSizeTrainCV> variable_size_train_cv <PredictionClassif>
#>    classif.ce algorithm
#>         <num>    <char>
#> 1:  0.4257426     rpart

The output above has columns very similar to the regression example in the previous section. The main difference is the classif.ce column, which is the classification error on the test set.
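For example, classif.ce can be re-computed from the stored prediction object; a minimal sketch using the first row (one.pred is our name):

one.pred <- class.bench.score$prediction[[1]]
mean(one.pred$truth != one.pred$response)  ## fraction of misclassified test points
one.pred$score(mlr3::msr("classif.ce"))    ## same value via an mlr3 measure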

Finally we plot the test error values below.

if(require(animint2)){
  ggplot()+
    geom_line(aes(
      train_size, classif.ce,
      group=paste(algorithm, seed),
      color=algorithm),
      data=class.bench.score)+
    geom_point(aes(
      train_size, classif.ce, color=algorithm),
      shape=1,
      data=class.bench.score)+
    facet_grid(
      task_id ~ test.fold,
      labeller=label_both,
      scales="free")+
    scale_x_log10()
}

(Figure: classification error versus train set size, one panel per task and test fold.)

It is clear from the plot above that

  • in the constant task, rpart does not have significantly lower error rates than featureless, which is expected, because the best prediction function is constant (predict the most frequent class; there is no relationship between inputs and output).
  • in the xor task, more than 30 samples are required for rpart to be more accurate than featureless, which indicates that rpart has learned a non-trivial relationship between inputs and output.

Exercise for the reader: compute and plot mean and SD for these classification tasks, similar to the plot for the regression tasks in the previous section.
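As a starting point, the aggregation step could mirror the regression code above (a sketch; class.mean.dt is our name, and the plotting is left to the reader):

## Sketch: mean and SD of classification error over random seeds.
class.mean.dt <- dcast(
  class.bench.score,
  task_id + train_size + test.fold + algorithm ~ .,
  list(mean, sd),
  value.var="classif.ce")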

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

class.grid.vec <- seq(-class.abs.x, class.abs.x, l=21)
class.grid.dt <- CJ(x1=class.grid.vec, x2=class.grid.vec)
class.pred.dt.list <- list()
class.point.dt.list <- list()
for(score.i in 1:nrow(class.bench.score)){
  class.bench.row <- class.bench.score[score.i]
  ## Combine the task data with the resampling instance id.dt (one row per
  ## data point, with fold assignment).
  task.dt <- data.table(
    class.bench.row$task[[1]]$data(),
    class.bench.row$resampling[[1]]$instance$id.dt)
  ## Expand each set name to its vector of assigned row ids, then join;
  ## rows in neither set are marked "unused".
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=class.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ][]
  class.point.dt.list[[score.i]] <- data.table(
    class.bench.row[, .(task_id, iteration)],
    i.points)
  ## Featureless predicts a constant probability, so it has no decision
  ## boundary; contours are only computed for rpart.
  if(class.bench.row$algorithm!="featureless"){
    i.learner <- class.bench.row$learner[[1]]
    i.learner$predict_type <- "prob"
    i.task <- class.bench.row$task[[1]]
    grid.class.task <- mlr3::TaskClassif$new(
      "grid", class.grid.dt[, label:=factor(NA, levels(task.dt$y))], target="label")
    ## Predict the probability of spam on the grid, then reshape to a
    ## matrix, in order to compute the decision boundary (0.5 contour).
    pred.grid <- as.data.table(
      i.learner$predict(grid.class.task)
    )[, data.table(class.grid.dt, prob.spam)]
    pred.wide <- dcast(pred.grid, x1 ~ x2, value.var="prob.spam")
    prob.mat <- as.matrix(pred.wide[,-1])
    ## Skip splits where the predicted probability is constant (no contour).
    if(length(table(prob.mat))>1){
      contour.list <- contourLines(
        class.grid.vec, class.grid.vec, prob.mat, levels=0.5)
      class.pred.dt.list[[score.i]] <- data.table(
        class.bench.row[, .(
          task_id, iteration, algorithm
        )],
        ## One row for each point on each contour line.
        data.table(contour.i=seq_along(contour.list))[, {
          do.call(data.table, contour.list[[contour.i]])[, .(level, x1=x, x2=y)]
        }, by=contour.i]
      )
    }
  }
}
(class.pred.dt <- rbindlist(class.pred.dt.list))
#>        task_id iteration algorithm contour.i level     x1          x2
#>         <char>     <int>    <char>     <int> <num>  <num>       <num>
#>    1: constant         1     rpart         1   0.5 0.0375 -1.00000000
#>    2: constant         1     rpart         1   0.5 0.0375 -0.90000000
#>    3: constant         1     rpart         1   0.5 0.0375 -0.80000000
#>    4: constant         1     rpart         1   0.5 0.0375 -0.70000000
#>    5: constant         1     rpart         1   0.5 0.0375 -0.60000000
#>   ---                                                                
#> 5190:      xor        45     rpart         2   0.5 0.6000  0.04888889
#> 5191:      xor        45     rpart         2   0.5 0.7000  0.04888889
#> 5192:      xor        45     rpart         2   0.5 0.8000  0.04888889
#> 5193:      xor        45     rpart         2   0.5 0.9000  0.04888889
#> 5194:      xor        45     rpart         2   0.5 1.0000  0.04888889
(class.point.dt <- rbindlist(class.point.dt.list))
#>         task_id iteration set.name row_id      y         x1           x2  fold
#>          <char>     <int>   <char>  <int> <fctr>      <num>        <num> <int>
#>     1: constant         1   unused      1   spam -0.4689827  0.347424466     3
#>     2: constant         1   unused      2    not -0.2557522 -0.810284289     2
#>     3: constant         1   unused      3   spam  0.1457067 -0.014807758     3
#>     4: constant         1    train      4    not  0.8164156 -0.076896319     3
#>     5: constant         1     test      5   spam -0.5966361 -0.249566938     1
#>    ---                                                                        
#> 53996:      xor        45    train    296    not  0.3628850  0.297101895     2
#> 53997:      xor        45    train    297    not -0.8016618 -0.040328411     2
#> 53998:      xor        45     test    298   spam -0.7621949 -0.009871789     3
#> 53999:      xor        45     test    299    not -0.8991207 -0.240254817     3
#> 54000:      xor        45    train    300   spam  0.8585078 -0.099029126     2

set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
if(require(animint2)){
  viz <- animint(
    title="Variable size train sets, classification",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme(panel.margin=grid::unit(1, "lines"))+
      theme_animint(width=600)+
      coord_equal()+
      scale_fill_manual(values=set.colors)+
      scale_color_manual(values=c(spam="black", not="white"))+
      geom_point(aes(
        x1, x2, color=y, fill=set.name),
        showSelected="iteration",
        size=3,
        stroke=2,
        shape=21,
        data=class.point.dt)+
      geom_path(aes(
        x1, x2, 
        group=paste(algorithm, iteration, contour.i)),
        showSelected=c("iteration","algorithm"),
        color=algo.colors[["rpart"]],
        data=class.pred.dt)+
      facet_grid(
        . ~ task_id,
        labeller=label_both,
        space="free",
        scales="free"),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      theme(panel.margin=grid::unit(1, "lines"))+
      scale_y_continuous(
        "Classification error on test set")+
      scale_color_manual(values=algo.colors)+
      scale_x_log10(
        "Train set size")+
      geom_line(aes(
        train_size, classif.ce,
        group=paste(algorithm, seed),
        color=algorithm),
        clickSelects="seed",
        alpha_off=0.2,
        showSelected="algorithm",
        size=4,
        data=class.bench.score)+
      facet_grid(
        test.fold~task_id,
        labeller=label_both,
        scales="free")+
      geom_point(aes(
        train_size, classif.ce,
        color=algorithm),
        size=5,
        stroke=3,
        fill="black",
        fill_off=NA,
        showSelected=c("algorithm","seed"),
        clickSelects="iteration",
        data=class.bench.score),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingVariableSizeTrainCV.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-27-train-sizes-classification")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it at https://tdhock.github.io/2023-12-27-train-sizes-classification/

The interactive data viz consists of two plots:

  • The first plot shows the data, with each point colored according to its label/y value (black outline for spam, white outline for not), and the set it was assigned to (fill color), in the currently selected split/iteration. The red lines additionally show the learned decision boundary for rpart, given the currently selected split/iteration. For constant, the ideal decision boundary is none (always predict the most frequent class), and for xor, the ideal decision boundary looks like a plus sign.
  • The second plot shows the test error rates, as a function of train set size. Clicking a line selects the corresponding random seed, which makes the corresponding points on that line appear. Clicking a point selects the corresponding iteration (seed, test fold, and train set size).

Conclusion

In this section we have shown how to use mlr3resampling to compare the test error of models trained on train sets of varying size.

Session info

sessionInfo()
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 22.04.4 LTS
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
#> 
#> locale:
#>  [1] LC_CTYPE=fr_FR.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=fr_FR.UTF-8        LC_COLLATE=C              
#>  [5] LC_MONETARY=fr_FR.UTF-8    LC_MESSAGES=fr_FR.UTF-8   
#>  [7] LC_PAPER=fr_FR.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: America/New_York
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] lgr_0.4.4              animint2_2024.6.6      directlabels_2024.1.21
#> [4] mlr3_0.20.0            ggplot2_3.5.1          data.table_1.15.99    
#> 
#> loaded via a namespace (and not attached):
#>  [1] utf8_1.2.4              future_1.33.1           generics_0.1.3         
#>  [4] stringi_1.8.3           listenv_0.9.1           digest_0.6.34          
#>  [7] magrittr_2.0.3          evaluate_0.23           grid_4.4.1             
#> [10] plyr_1.8.9              backports_1.4.1         fansi_1.0.6            
#> [13] mlr3resampling_2024.7.3 scales_1.3.0            RhpcBLASctl_0.23-42    
#> [16] mlr3tuning_1.0.0        codetools_0.2-20        mlr3measures_0.5.0     
#> [19] palmerpenguins_0.1.1    cli_3.6.2               rlang_1.1.3            
#> [22] crayon_1.5.2            parallelly_1.36.0       future.apply_1.11.1    
#> [25] munsell_0.5.0           commonmark_1.9.1        withr_3.0.0            
#> [28] nc_2024.2.21            tools_4.4.1             parallel_4.4.1         
#> [31] reshape2_1.4.4          RJSONIO_1.3-1.9         uuid_1.2-0             
#> [34] checkmate_2.3.1         dplyr_1.1.4             colorspace_2.1-0       
#> [37] globals_0.16.2          bbotk_1.0.0             vctrs_0.6.5            
#> [40] R6_2.5.1                mime_0.12               rpart_4.1.23           
#> [43] lifecycle_1.0.4         stringr_1.5.1           mlr3misc_0.15.1        
#> [46] pkgconfig_2.0.3         pillar_1.9.0            gtable_0.3.4           
#> [49] Rcpp_1.0.12             glue_1.7.0              paradox_1.0.0          
#> [52] xfun_0.45               tibble_3.2.1            tidyselect_1.2.0       
#> [55] highr_0.11              knitr_1.47              farver_2.1.1           
#> [58] labeling_0.4.3          compiler_4.4.1          quadprog_1.5-8         
#> [61] markdown_1.13