Enjoy R: Compare ROC curves of different multinomial classification models

When we think of a ROC curve, we usually associate it with a binary classification problem.

In the multiclass case it is used less often, partly because it loses much of its explanatory power.

However, it is worth using in this scenario too, if only to have one more diagnostic weapon at our disposal.

One possibility I came across while surfing the net is the OvA (one-vs-all) approach: pick one class at a time, say class C, and draw the ROC curve for C against not-C.

We end up with as many curves as there are classes in the qualitative response.

This holds for one specific model; but if we aim to compare n models with a k-class response, we end up with n × k curves on the plot, which gives a picture of the behavior of each model with respect to each class.
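As a tiny sketch of what the OvA binarization looks like for one class (a hypothetical 3-level factor, not data from this post):

```r
truth <- factor(c("a", "b", "c", "a", "c"))

# class "c" becomes the positive class, everything else the negative one
truth_vs_c <- as.integer(truth == "c")
truth_vs_c
# [1] 0 0 1 0 1
```

Repeating this for each level of the factor yields the k binary problems behind the k curves.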

To visually evaluate the overall (general) behavior of each model, we can add a central tendency measure, which might be:

  • a simple arithmetic mean;
  • a weighted mean, taking the number of observations belonging to each class as weights.

A weighted average might seem more correct, but it can also be more “misleading” in telling you your model is good: if almost all observations belong to class A, few to the others, and your model always predicts A, the weighted mean will say your model is quite OK, whereas a simple mean (which weighs the classes equally) will say it is a disaster. Relying only on the simple average may be too extreme, but at least it signals that something should be reviewed.
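To make the imbalance point concrete, here is a made-up per-class TPR vector for a model that always predicts A on heavily imbalanced data (numbers invented for illustration):

```r
tpr   <- c(A = 1, B = 0, C = 0)        # the model gets every A, misses all B and C
n_obs <- c(A = 9000, B = 700, C = 300) # class counts used as weights

mean(tpr)                  # 0.333...: the simple mean calls it a disaster
weighted.mean(tpr, n_obs)  # 0.9: the weighted mean says it is quite ok
```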

What follows is a procedure that ends with a graph presenting all the information you need to evaluate your models with a critical eye.

First, the function:

compare_multiROC <- function(truth, pred,
                             x_measure = "fpr", y_measure = "tpr") {
  
  stopifnot(is.list(pred))
  library(ROCR)
  
  truth <- ordered(truth); lev <- levels(truth)
  pred <- lapply(pred, function(x) ordered(x, levels = lev))
  
  # one-hot encode the response and the predictions: one 0/1 indicator
  # column per class (replaces the retired dummies package)
  to_dummies <- function(x) data.frame(sapply(lev, function(l) as.integer(x == l)))
  truth <- to_dummies(truth)
  pred <- lapply(pred, to_dummies)
  
  appr <- interp <- list()
  
  for (j in seq_along(pred)) {
    # one-vs-all ROC points for each class of model j
    # (in performance(), the first measure goes on the y axis)
    interp[[j]] <- lapply(seq_len(ncol(truth)), function(i) {
      predob <- prediction(pred[[j]][, i], truth[, i])
      perf <- performance(predob, y_measure, x_measure)
      cbind(perf@x.values[[1]], perf@y.values[[1]])
    })
    # linear interpolators, so every curve can be evaluated on a common grid
    appr[[j]] <- lapply(interp[[j]], function(f) approxfun(f[, 1], f[, 2]))
  }
  
  # common grid: all x values observed across models and classes
  Xs <- sort(unique(unlist(lapply(interp, function(x)
    lapply(x, function(X) X[, 1])))))
  
  # one data frame per model: x grid, per-class curves, simple and weighted averages
  lapply(seq_along(appr), function(j) {
    Ys <- sapply(appr[[j]], function(f) f(Xs)); colnames(Ys) <- lev
    data.frame(X = Xs, Ys, avg = rowMeans(Ys),
               w.avg = apply(Ys, 1, weighted.mean, colSums(truth)),
               check.names = FALSE)
  })
}
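For reference, a minimal sketch of the two ROCR calls at the heart of the loop, on toy binary data (not from the post); note that in `performance()` the first measure ends up on the y axis and the second on the x axis:

```r
library(ROCR)

# 0/1 predictions for one class against 0/1 truth for that class
predob <- prediction(c(0, 1, 1, 0, 1), c(0, 1, 0, 0, 1))
perf <- performance(predob, "tpr", "fpr")  # measure = tpr (y), x.measure = fpr (x)
cbind(fpr = perf@x.values[[1]], tpr = perf@y.values[[1]])
```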

Let’s create some random data to see what the output looks like:

set.seed(103)

n <- 1e+04

truth <- factor(sample(letters[1:3], n, T, c(.7, .2, .1)))

cl1 <- cl2 <- cl3 <- truth

trC <- truth == "c"; trClen <- sum(trC)

# building cl1: misclassify ~20% of the true c's as a
# (is.na(sample(...)) is just a Bernoulli draw with the given probability)
na <- is.na(sample(c(NA, 0), trClen, T, c(.2, .8))); cl1[trC][na] <- "a"

# building cl2: misclassify ~30% of the true c's as a
na <- is.na(sample(c(NA, 0), trClen, T, c(.3, .7))); cl2[trC][na] <- "a"

# building cl3: misclassify ~35% of the true c's as a
na <- is.na(sample(c(NA, 0), trClen, T, c(.35, .65))); cl3[trC][na] <- "a"

# applying the function
valuesROC <- compare_multiROC(truth, list(cl1, cl2, cl3))
names(valuesROC) <- paste0("cl", 1:3)
valuesROC
$cl1
           X a         b         c       avg     w.avg
1 0.00000000 0 0.5000000 0.3969124 0.2989708 0.1414000
2 0.06820428 1 0.5341021 0.4380455 0.6573826 0.8489559
3 0.09983526 1 0.5499176 0.4571218 0.6690131 0.8540833
4 0.11696870 1 0.5584843 0.4674547 0.6753130 0.8568606
5 1.00000000 1 1.0000000 1.0000000 1.0000000 1.0000000

$cl2
           X         a         b         c       avg     w.avg
1 0.00000000 0.0000000 0.5000000 0.3491036 0.2830345 0.1366000
2 0.06820428 0.6831683 0.5341021 0.3934975 0.5369227 0.6238100
3 0.09983526 1.0000000 0.5499176 0.4140860 0.6546679 0.8497625
4 0.11696870 1.0000000 0.5584843 0.4252381 0.6612408 0.8526221
5 1.00000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000

$cl3
           X         a         b         c       avg     w.avg
1 0.00000000 0.0000000 0.5000000 0.3232072 0.2744024 0.1340000
2 0.06820428 0.5830986 0.5341021 0.3693673 0.4955227 0.5516888
3 0.09983526 0.8535211 0.5499176 0.3907750 0.5980712 0.7453995
4 0.11696870 1.0000000 0.5584843 0.4023707 0.6536184 0.8503262
5 1.00000000 1.0000000 1.0000000 1.0000000 1.0000000 1.0000000

Now we should “melt” this data into a long data frame to be passed to ggplot.

library(reshape2); library(ggplot2); library(dplyr)

# melt the list of data frames into one long data frame,
# keeping the per-class curves and the averages in separate facets
dat <- mutate(melt(valuesROC, id.vars = "X"),
              facet = ordered(ifelse(variable %in% c("avg", "w.avg"),
                                     "Averages", "Class Values"),
                              levels = c("Class Values", "Averages")))

# X holds the false positive rate, value the true positive rate
names(dat)[-5] <- c("False positive rate", "Curve",
                    "True positive rate", "Model")

Finally, the plot:

ggplot(dat, aes(`False positive rate`, `True positive rate`,
                color = Curve, lty = Model)) +
  geom_path() +
  facet_wrap(~ facet) +
  scale_x_continuous(breaks = seq(0, 1, .1)) +
  scale_y_continuous(breaks = seq(0, 1, .1)) +
  geom_segment(x = 0, xend = 1, y = 0, yend = 1, color = "grey90") +
  geom_ribbon(aes(x = `False positive rate`, ymin = 0,
                  ymax = `False positive rate`),
              color = NA, fill = "pink", alpha = .2) +
  geom_ribbon(aes(x = `False positive rate`,
                  ymin = `False positive rate`, ymax = 1),
              color = NA, fill = "lightblue", alpha = .2) +
  theme_bw() +
  guides(linetype = guide_legend(
    override.aes = list(colour = NULL, fill = "white"))) +
  ggtitle("ROC curves of different multinomial classification models")

[Rplot01: the resulting graph — per-class ROC curves in the “Class Values” facet, simple and weighted averages in the “Averages” facet]


1 Comment

  1. Re-executing the code after a while, I noticed that the ggplot2 override.aes call produces an error (IIRC, it didn’t when I published this post). It’s not a big deal: you can just skip the two lines of code related to guides. Or, if you have time – which I don’t –, try to solve this little problem and post a comment with your solution 😉
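One untested possibility, assuming the error comes from the colour = NULL entry inside override.aes, is to give the legend keys a concrete colour instead of NULL:

```r
library(ggplot2)

# sketch of a workaround: no NULL values inside override.aes
guides(linetype = guide_legend(override.aes = list(colour = "black")))
```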
