Enjoy R: Stratified sampling by group variables

Some years ago I published an article that provided a couple of functions for splitting a dataset into training and test sets using stratified sampling.

Stratified sampling is a statistical technique that draws the same percentage of rows from each of the groups present in the data. This is useful in supervised learning because it guards against unbalanced splits, where a group (especially a rare one) ends up over- or underrepresented in one of the sets, and it keeps each set representative of the whole dataset.
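As a minimal illustration of the idea, independent of the functions below, here is a base-R sketch that draws the same fraction of rows from each stratum, where a stratum is one combination of values of the grouping columns. The helper `strat_frac()` and the toy data are hypothetical, written only for this example.

```r
# Hypothetical helper: draw the same fraction of rows from every stratum,
# where a stratum is one observed combination of the values in group_cols
strat_frac <- function(data, group_cols, frac, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # One factor level per observed combination of the grouping columns
  key <- interaction(data[group_cols], drop = TRUE)
  # Row indices belonging to each stratum
  idx <- split(seq_len(nrow(data)), key)
  # Sample round(frac * size) indices within each stratum
  picked <- unlist(lapply(idx, function(i) {
    i[sample.int(length(i), round(frac * length(i)))]
  }), use.names = FALSE)
  data[sort(picked), , drop = FALSE]
}

# Toy data: group "a" is nine times as frequent as group "b"
df <- data.frame(g = rep(c("a", "b"), times = c(90, 10)), v = 1:100)
s <- strat_frac(df, "g", 0.5, seed = 1)
table(s$g)   # 45 rows from "a", 5 rows from "b"
```

Both groups keep the 9:1 ratio of the original data, whatever the seed.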

After years of silence, a comment showed up under my article a couple of months ago. It said it’d be nice to have:

  1. the possibility of using more than one column for stratification;
  2. a split into training, validation, and test sets (not just training and test).

I really appreciated those remarks, because I believe they address real problems that practitioners encounter when training a supervised model. So I decided to update those functions accordingly.

To be honest, it was already possible to use more than one group variable, but I never showed how. Moreover, a split into three sets instead of two could be obtained by applying the same procedure twice: first to the whole (initial) dataset, and then to the resulting test set.
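Applying a two-way stratified split twice can be sketched in base R as follows. For a 60/20/20 split, the first pass keeps 60% of each group for training; the second pass splits the held-out 40% again, taking 0.2 / (1 - 0.6) = 0.5 of it for validation. The helper `strat_split2()` and the toy data are hypothetical, not the functions from the original article.

```r
# Hypothetical two-way stratified split: returns a logical vector that is
# TRUE for the rows assigned to the first part
strat_split2 <- function(data, group_cols, frac) {
  key <- interaction(data[group_cols], drop = TRUE)
  in_first <- logical(nrow(data))
  for (i in split(seq_len(nrow(data)), key)) {
    in_first[i[sample.int(length(i), round(frac * length(i)))]] <- TRUE
  }
  in_first
}

set.seed(1)
df <- data.frame(g = rep(c("a", "b"), times = c(500, 100)), v = rnorm(600))

# Pass 1: 60% of each group goes to training, the rest is held out
first <- strat_split2(df, "g", 0.6)
train <- df[first, ]
rest  <- df[!first, ]

# Pass 2: split the held-out 40%, taking 0.2 / (1 - 0.6) of it for validation
second <- strat_split2(rest, "g", 0.2 / (1 - 0.6))
valid  <- rest[second, ]
test   <- rest[!second, ]

# Rows per group in each set: "a" 300/100/100, "b" 60/20/20
sapply(list(train = train, valid = valid, test = test), function(d) table(d$g))
```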

Anyway, let’s see the new code.

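strat_sample <- function(data, gr_variab, 
                         tr_percent_train, tr_percent_valid, 
                         tresh_valid = 0, tresh_test = 0, seed) {

  # Training + validation shares must leave room for a test set
  stopifnot(tr_percent_train + tr_percent_valid > 0, 
            tr_percent_train + tr_percent_valid < 1)
  
  if(require(dplyr) && require(magrittr)) {
    
    if(!missing(seed)) set.seed(seed)
    
    # gr_variab may hold one or several column names (character vector or list)
    names0 <- names(data)
    gr_variab <- match(unlist(gr_variab), names0)
    stopifnot(!anyNA(gr_variab))
    
    # Temporarily rename columns clashing with the helper columns used below
    names(data) <- make.unique(c("n", "trainRows", "validRows", 
                                 "SET", names0))[-(1:4)]
    gr_variab <- names(data)[gr_variab]
    
    # Shuffle the rows, then compute cumulative cut-offs within each group
    # (pick() requires dplyr >= 1.1.0)
    data %<>%
      slice_sample(prop = 1) %>%
      group_by(pick(all_of(gr_variab))) %>%
      mutate(n = n(), 
             trainRows = tr_percent_train * n,
             validRows = trainRows + tr_percent_valid * n)
    
    # validRows is cumulative: validation gets validRows - trainRows rows,
    # test gets n - validRows rows
    with(data, if(any(validRows - trainRows < tresh_valid))
      warning("Validation Set: zero or too few observations in one or more groups"))
    with(data, if(any(n - validRows < tresh_test))
      warning("Test Set: zero or too few observations in one or more groups"))
    
    # Label each row by its position within its shuffled group
    data %<>%
      mutate(SET = ifelse(row_number() <= trainRows, 
                          "Train", 
                          ifelse(row_number() <= validRows,
                                 "Validation",
                                 "Test"))) %>%
      select(-n, -trainRows, -validRows) %>%
      ungroup()
    
    # Restore the original names; the label column becomes SET (or SET.1, ...)
    names(data) <- make.unique(c(names0, "SET"))
    
    data
    
  } 
}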


extract_set <- function(data, whichSET) {
  
  stopifnot(is.element(whichSET, c("Train", 
                                   "Validation", 
                                   "Test")))
  
  if(require(dplyr) && require(magrittr)) {
    
    # strat_sample() appends the set labels as the last column
    variab <- names(data)[ncol(data)]
    
    # Keep the rows of the requested set, then drop the label column
    data %>%
      filter(.data[[variab]] == .env$whichSET) %>%
      select(-all_of(variab))
    
  }
}
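Inside `strat_sample()`, each shuffled group is labelled by position: rows up to `trainRows` go to training, rows up to the cumulative `validRows` go to validation, and the remaining rows go to test. A toy illustration for a single group of 10 rows (base R; `n_rows` and `pos` are illustrative names) shows the mechanics, and why the validation and test sizes are `validRows - trainRows` and `n - validRows` respectively.

```r
n_rows <- 10                     # size of one group after shuffling
tr_percent_train <- 0.6
tr_percent_valid <- 0.2

trainRows <- tr_percent_train * n_rows              # 6
validRows <- trainRows + tr_percent_valid * n_rows  # 8: a cumulative cut-off

pos <- seq_len(n_rows)           # plays the role of row_number() within the group
SET <- ifelse(pos <= trainRows, "Train",
              ifelse(pos <= validRows, "Validation", "Test"))

table(SET)                       # Test 2, Train 6, Validation 2
c(valid_size = validRows - trainRows, test_size = n_rows - validRows)  # 2 and 2
```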

Let’s now create a sample dataset with three grouping variables (x, y, and z).

n <- 1e+5

set.seed(386)

Df <- data.frame(V1 = rnorm(n),
                 x = sample(letters[1:2], n, replace = TRUE),
                 V2 = rt(n, df = 4),
                 V3 = rpois(n, lambda = 1),
                 y = sample(letters[1:4], n, replace = TRUE,
                            prob = c(.33, .33, .33, .01)),
                 z = sample(letters[1:2], n, replace = TRUE))

Let’s split the data into training (60%), validation (20%), and test (20%) sets. I want a warning if at least one group has fewer than 300 rows in the validation set, and another warning if at least one group has fewer than 300 rows in the test set.
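Before running the split, the sampling probabilities used to build Df give a rough idea of which groups should trigger those warnings. This back-of-the-envelope sketch (base R; the names `p_x`, `p_y`, `p_z`, and `grid` are illustrative) computes the expected size of each x/y/z group and of its 20% validation share, which is also the expected test share.

```r
n <- 1e+5
# Level probabilities used when simulating x, y, and z
p_x <- c(a = .5, b = .5)
p_y <- c(a = .33, b = .33, c = .33, d = .01)
p_z <- c(a = .5, b = .5)

# One row per x/y/z combination with its expected size
grid <- expand.grid(x = names(p_x), y = names(p_y), z = names(p_z),
                    stringsAsFactors = FALSE)
grid$expected_n <- n * p_x[grid$x] * p_y[grid$y] * p_z[grid$z]
grid$expected_valid <- 0.2 * grid$expected_n   # same as the expected test share

# Groups expected to fall below the 300-row threshold:
# only the four y = "d" groups, with about 250 rows each and about 50 per set
grid[grid$expected_valid < 300, ]
```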

groups <- strat_sample(data = Df, 
                       gr_variab = list("x", "y", "z"), 
                       tr_percent_train = .6, 
                       tr_percent_valid = .2, 
                       tresh_valid = 300, 
                       tresh_test = 300)

## Warning messages:
## 1: In eval(expr, envir, enclos) :
##   Validation Set: zero or too few observations in one or more groups
## 2: In eval(expr, envir, enclos) :
##   Test Set: zero or too few observations in one or more groups

The warnings say that at least one group has fewer than 300 rows in the validation set, and at least one has fewer than 300 rows in the test set. Let’s check this for variable y, whose level “d” occurs in only 1% of the rows.

with(groups, table(y, SET))
##    SET
## y    Test Train Validation
##   a  6628 19882       6628
##   b  6593 19777       6592
##   c  6570 19708       6570
##   d   211   631        210

As the warnings said, level “d” has fewer than 300 rows in both the validation and the test set.
Finally, let’s check whether each x/y/z group was split into roughly 60%, 20%, and 20% of its rows.

with(groups, prop.table(table(paste0(x, y, z), SET), 1))
##      SET
##            Test     Train Validation
##   aaa 0.2011126 0.5991051  0.1997823
##   aab 0.1952526 0.6048691  0.1998783
##   aba 0.2019184 0.5990772  0.1990044
##   abb 0.2045263 0.5962441  0.1992296
##   aca 0.1941877 0.6006809  0.2051313
##   acb 0.2059075 0.5995362  0.1945563
##   ada 0.1974790 0.6092437  0.1932773
##   adb 0.1821561 0.6282528  0.1895911
##   baa 0.2069582 0.5908495  0.2021923
##   bab 0.1965864 0.6052536  0.1981600
##   bba 0.1935918 0.6046227  0.2017855
##   bbb 0.1999515 0.6000971  0.1999515
##   bca 0.2016313 0.5970973  0.2012714
##   bcb 0.1982950 0.6026686  0.1990363
##   bda 0.1660377 0.5849057  0.2490566
##   bdb 0.2535714 0.5785714  0.1678571
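Within each group the labels are assigned at fixed positions after shuffling: the first floor(0.6 * n) rows go to training and the rows up to floor(0.8 * n) go to validation. So, when every x/y/z combination is its own stratum, each share in a table like the one above can differ from its target only through this rounding, that is, by less than 1/n for a group of n rows. A small base-R check of that bound over an illustrative range of group sizes (50 to 300 rows):

```r
n <- 50:300                                  # illustrative group sizes
train_cut <- 0.6 * n                         # mirrors trainRows
valid_cut <- train_cut + 0.2 * n             # mirrors the cumulative validRows

train_share <- floor(train_cut) / n
valid_share <- (floor(valid_cut) - floor(train_cut)) / n
test_share  <- (n - floor(valid_cut)) / n

# Deviation from the 60/20/20 targets stays below 1/n for every size
all(abs(train_share - 0.6) < 1 / n)          # TRUE
all(abs(valid_share - 0.2) < 1 / n)          # TRUE
all(abs(test_share  - 0.2) < 1 / n)          # TRUE
```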
