Binary Classifier Diagnostics
While `kstest` and `roc` provide diagnostic measures for comparing model performance, we may also want graphs and tables to document that performance. `bcdiag` makes this easy.
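We first simulate one million observations from a logistic model:

```julia
using ROCKS
using Random
using Distributions
using BenchmarkTools

Random.seed!(888)
# one million observations of a single predictor
const x = rand(Uniform(-5, 5), 1_000_000)
# true log-odds: linear in x plus a little noise
const logit = -3.0 .+ 0.5 .* x .+ rand(Normal(0, 0.1), length(x))
# convert log-odds to probabilities
const prob = @. 1.0 / (1.0 + exp(-logit))
# draw the binary outcome from those probabilities
const target = rand(length(x)) .<= prob
```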
`kstest`:

```julia
kstest(target, prob)
```

```
(n = 1000000, n1 = 94410, n0 = 905590, baserate = 0.09441, ks = 0.49180459544452004, ksarg = 0.09004201376264864, ksdep = 0.362772)
```
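The statistic can be sketched directly from its definition. This is a minimal sketch, not the ROCKS implementation, assuming that `ks` is the largest gap between the cumulative share of class 1 and the cumulative share of class 0 when observations are taken from the highest score down, that `ksarg` is the score at which that gap occurs, and that `ksdep` is the fraction of observations at or above that score. `ks_sketch` is a hypothetical name.

```julia
# Hypothetical sketch of a KS computation; not the ROCKS implementation.
# Assumes higher scores indicate class 1 (target == true).
function ks_sketch(target::AbstractVector{Bool}, prob::AbstractVector{<:Real})
    n = length(prob)
    n1 = count(target)
    n0 = n - n1
    idx = sortperm(prob; rev = true)       # highest score first
    c1 = 0                                 # class 1 seen so far
    c0 = 0                                 # class 0 seen so far
    ks, ksarg, ksdep = 0.0, NaN, 0.0
    i = 1
    while i <= n
        # extend j over every observation tied with the score at position i
        j = i
        while j < n && prob[idx[j + 1]] == prob[idx[i]]
            j += 1
        end
        for k in i:j
            if target[idx[k]]
                c1 += 1
            else
                c0 += 1
            end
        end
        # gap between the two cumulative distributions at this cutoff
        d = c1 / n1 - c0 / n0
        if d > ks
            ks, ksarg, ksdep = d, float(prob[idx[i]]), j / n
        end
        i = j + 1
    end
    return (n = n, n1 = n1, n0 = n0, baserate = n1 / n,
            ks = ks, ksarg = ksarg, ksdep = ksdep)
end
```

Tied scores are processed as one block, so a cutoff never splits observations with equal scores. If these assumptions match the package's definitions, `ks_sketch(target, prob)` should agree with the `kstest` output above.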
`roc`:

```julia
roc(target, prob)
```

```
(conc = 69527841929, tied = 393224, disc = 15968516747, auc = 0.8132243271922474, gini = 0.6264486543844947)
```
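The numbers above are consistent with AUC = (conc + tied/2) / (n1 × n0) and Gini = 2·AUC − 1: conc + tied + disc = 85,496,751,900 = 94,410 × 905,590, the number of (class 1, class 0) pairs. Below is a minimal sketch of that pair counting, not the ROCKS implementation, assuming a pair is concordant when the class-1 observation has the higher score. `roc_pairs` is a hypothetical name; it sorts once rather than comparing all n1 × n0 pairs.

```julia
# Hypothetical sketch of concordance counting; not the ROCKS implementation.
# A (class 1, class 0) pair is concordant when the class-1 score is higher,
# tied when the scores are equal, and discordant otherwise.
function roc_pairs(target::AbstractVector{Bool}, prob::AbstractVector{<:Real})
    n = length(prob)
    idx = sortperm(prob)          # lowest score first
    below0 = 0                    # class-0 observations with a strictly lower score
    conc = 0
    tied = 0
    i = 1
    while i <= n
        # block i:j holds all observations sharing the score at position i
        j = i
        while j < n && prob[idx[j + 1]] == prob[idx[i]]
            j += 1
        end
        c1 = count(k -> target[idx[k]], i:j)
        c0 = (j - i + 1) - c1
        conc += c1 * below0       # each class-1 obs beats every lower class-0 obs
        tied += c1 * c0           # equal scores within the block
        below0 += c0
        i = j + 1
    end
    n1 = count(target)
    n0 = n - n1
    disc = n1 * n0 - conc - tied
    auc = (conc + tied / 2) / (n1 * n0)
    return (conc = conc, tied = tied, disc = disc, auc = auc, gini = 2 * auc - 1)
end
```

For example, with scores `[0.9, 0.1, 0.4, 0.4, 0.2]` and targets `[true, false, true, false, true]` there are 3 × 2 = 6 pairs: 4 concordant, 1 tied and 1 discordant, giving an AUC of 4.5 / 6 = 0.75.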
Both functions are fast, even with one million observations:
```julia
@benchmark kstest($target, $prob)
```

```
BenchmarkTools.Trial: 37 samples with 1 evaluations.
 Range (min … max):  132.672 ms … 147.931 ms  ┊ GC (min … max): 0.00% … 0.98%
 Time  (median):     134.942 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   136.402 ms ±   4.403 ms  ┊ GC (mean ± σ):  0.52% ± 0.62%

 Memory estimate: 46.02 MiB, allocs estimate: 19.
```
```julia
@benchmark roc($target, $prob)
```

```
BenchmarkTools.Trial: 55 samples with 1 evaluations.
 Range (min … max):  90.971 ms … 93.215 ms    ┊ GC (min … max): 0.00% … 1.58%
 Time  (median):     91.127 ms                ┊ GC (median):    0.00%
 Time  (mean ± σ):   91.270 ms ± 439.958 μs   ┊ GC (mean ± σ):  0.12% ± 0.41%

 Memory estimate: 7.75 MiB, allocs estimate: 8.
```
bcdiag
In addition to numeric metrics, we often want plots and tables as part of the final model documentation. The `bcdiag` function generates both easily.
Running `bcdiag` prints a quick summary:

```julia
mdiag = bcdiag(target, prob)
```

```
Base rate: 0.0944  n: 1000000  n1: 94410  n0: 905590
ks: 0.4918  occurs at value of 0.09004201376264864 depth of 0.362772
roc: 0.8132  concordant pairs: 69527841929  tied pairs: 393224  discordant pairs: 15968516747
Gini: 0.6264
```
The returned structure lets us create the following plots and tables, which help us understand:
- how well the model separates the two classes
- how accurate the predicted probabilities are
- how to set the cutoff for maximum accuracy
- how the model performs at varying cutoff depths
ksplot
`ksplot` plots the cumulative distributions of class 1 (true positive rate) and class 0 (false positive rate) against depth.

```julia
ksplot(mdiag)
```

It shows where the maximum separation between the two distributions occurs.
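The two curves can be sketched by sorting on the score and, at each of a fixed number of equal-depth cutoffs, recording the share of each class captured so far. This is a minimal sketch, not the ROCKS implementation; `ks_curve` and its `groups` keyword are hypothetical, and unlike a tie-aware computation it may split observations with equal scores across a cutoff.

```julia
# Hypothetical sketch of the curves behind a KS plot; not the ROCKS implementation.
# Sorts by descending score and records, at each of `groups` equal-depth cutoffs,
# the share of class 1 (TPR) and of class 0 (FPR) captured so far.
function ks_curve(target::AbstractVector{Bool}, prob::AbstractVector{<:Real};
                  groups::Int = 100)
    n = length(prob)
    n1 = count(target)
    n0 = n - n1
    sorted = target[sortperm(prob; rev = true)]   # outcomes, best score first
    cum1 = cumsum(Int.(sorted))                    # class 1 captured up to each rank
    depth = [k / groups for k in 1:groups]
    cut = [round(Int, d * n) for d in depth]       # last rank included at each depth
    tpr = [cut[k] == 0 ? 0.0 : cum1[cut[k]] / n1 for k in 1:groups]
    fpr = [cut[k] == 0 ? 0.0 : (cut[k] - cum1[cut[k]]) / n0 for k in 1:groups]
    return (depth = depth, tpr = tpr, fpr = fpr)
end
```

The largest vertical gap, `maximum(c.tpr .- c.fpr)`, approximates the KS statistic at the resolution of the chosen number of groups.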
rocplot
`rocplot` plots the true positive rate against the false positive rate (depth is implicit).

```julia
rocplot(mdiag)
```

A perfect model has an AUC of 1; a random model has an AUC of 0.5.
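The AUC is the area under this curve. This is a minimal sketch, not the ROCKS implementation, assuming the curve is traced one distinct score at a time from the highest score down, so that a block of tied scores moves diagonally. With that convention the trapezoid area equals (conc + tied/2) / (n1 × n0), the pair-counting form of the AUC. `roc_curve_auc` is a hypothetical name.

```julia
# Hypothetical sketch; not the ROCKS implementation.
# Traces the ROC curve from the highest score down, one distinct score at a time,
# and integrates it with the trapezoidal rule using integer counts.
function roc_curve_auc(target::AbstractVector{Bool}, prob::AbstractVector{<:Real})
    n = length(prob)
    n1 = count(target)
    n0 = n - n1
    idx = sortperm(prob; rev = true)
    tp = 0                         # class 1 at or above the current cutoff
    fp = 0                         # class 0 at or above the current cutoff
    fpr = [0.0]
    tpr = [0.0]
    twice_area = 0                 # 2 * area * n1 * n0, kept as an integer
    i = 1
    while i <= n
        j = i
        while j < n && prob[idx[j + 1]] == prob[idx[i]]
            j += 1
        end
        c1 = count(k -> target[idx[k]], i:j)
        c0 = (j - i + 1) - c1
        # trapezoid between the previous point and the new one
        twice_area += c0 * (2 * tp + c1)
        tp += c1
        fp += c0
        push!(fpr, fp / n0)
        push!(tpr, tp / n1)
        i = j + 1
    end
    return (fpr = fpr, tpr = tpr, auc = twice_area / (2 * n1 * n0))
end
```

Accumulating the doubled area as an integer avoids rounding error until the final division.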
biasplot
Both `ksplot` and `rocplot` rely on the model's ability to rank-order the observations; the score value itself doesn't matter. For example, applying any strictly increasing transformation to the score would leave `ks` and `auc` unchanged. On some occasions the score value does matter and the probabilities need to be accurate, for example, in expected return calculations. We therefore need to check whether the probabilities are accurate. `biasplot` does this by plotting the observed response rate against the predicted response rate to look for systematic bias. This is also called a calibration graph.

```julia
biasplot(mdiag)
```

An unbiased model lies on the diagonal; a systematic shift off the diagonal indicates over- or underestimation of the true probability.
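The data behind a calibration graph can be sketched by sorting on the predicted probability, cutting the observations into equal-sized groups, and comparing the mean prediction with the observed response rate in each group. This is a minimal sketch assuming equal-count groups by descending score; `calibration_table` is hypothetical, and ROCKS may form its groups differently. It uses the `Statistics` standard library.

```julia
using Statistics

# Hypothetical sketch of the data behind a calibration (bias) plot.
# Groups are equal-count slices of the observations sorted by descending score.
function calibration_table(target::AbstractVector{Bool}, prob::AbstractVector{<:Real};
                           groups::Int = 10)
    n = length(prob)
    idx = sortperm(prob; rev = true)
    # group g (1-based) covers ranks bounds[g]+1 through bounds[g+1]
    bounds = [round(Int, g * n / groups) for g in 0:groups]
    predicted = Float64[]
    observed = Float64[]
    for g in 1:groups
        rows = idx[(bounds[g] + 1):bounds[g + 1]]
        isempty(rows) && continue
        push!(predicted, mean(prob[rows]))    # average predicted probability
        push!(observed, mean(target[rows]))   # observed response rate
    end
    return (predicted = predicted, observed = observed)
end
```

Plotting `observed` against `predicted` and comparing with the line y = x gives the calibration view; points above the line mean the model underestimates the probability in that group.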
accuracyplot
People often refer to (TP + TN) / N as the accuracy of the model, that is, the proportion of cases it classifies correctly. It is also used to compare model performance: a model with higher accuracy is a better model. A probability-based classifier needs a cutoff to turn a probability into a predicted class. So what cutoff achieves maximum accuracy?

There are many approaches to choosing the best cutoff. One is to assign utility values to the four outcomes [TP, FP, FN, TN] and pick the cutoff that maximizes their sum. Accuracy uses the utility values [1, 0, 0, 1], giving TP + TN. You can also assign negative values to penalize misclassification.

Note that this differs from `kstest`: maximum separation of the cumulative distributions (each normalized to 100%) does not account for the difference in class sizes; for example, class 1 may be only 2% of the cases.

```julia
accuracyplot(mdiag)
```
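The utility approach can be sketched by sweeping the cutoff down through the distinct scores and totalling the utility of the resulting TP, FP, FN and TN counts. This is a minimal sketch, not the ROCKS implementation; `best_cutoff` and its `utility` keyword are hypothetical, and it assumes an observation is predicted as class 1 when its score is at or above the cutoff.

```julia
# Hypothetical sketch: choose the cutoff that maximizes total utility.
# An observation is predicted as class 1 when its score is >= cutoff.
# utility = (u_tp, u_fp, u_fn, u_tn); accuracy corresponds to (1, 0, 0, 1).
function best_cutoff(target::AbstractVector{Bool}, prob::AbstractVector{<:Real};
                     utility = (1, 0, 0, 1))
    n = length(prob)
    n1 = count(target)
    n0 = n - n1
    utp, ufp, ufn, utn = utility
    total(t, f) = utp * t + ufp * f + ufn * (n1 - t) + utn * (n0 - f)
    # start with a cutoff above every score: nothing is predicted as class 1
    best_cut, best_val = Inf, total(0, 0)
    idx = sortperm(prob; rev = true)
    tp = 0
    fp = 0
    i = 1
    while i <= n
        # lower the cutoff to the next distinct score, taking all ties at once
        j = i
        while j < n && prob[idx[j + 1]] == prob[idx[i]]
            j += 1
        end
        c1 = count(k -> target[idx[k]], i:j)
        tp += c1
        fp += (j - i + 1) - c1
        val = total(tp, fp)
        if val > best_val
            best_cut, best_val = float(prob[idx[i]]), val
        end
        i = j + 1
    end
    return (cutoff = best_cut, utility = best_val)
end
```

For targets `[true, false, true, true]` with scores `[0.9, 0.8, 0.7, 0.1]`, accuracy utility picks the cutoff 0.1 (3 of 4 correct), while `utility = (1, -2, 0, 0)` penalizes the false positive at 0.8 and picks 0.9 instead.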
liftcurve
`liftcurve` plots the actual and predicted response rates against depth, normalized so that the base rate equals 1.

```julia
liftcurve(mdiag)
```

We can easily see where the model performs better than average, about the same as average, or below average.
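Assuming lift is the group response rate divided by the base rate, which the `liftable` output later on this page is consistent with, the first decile works out as follows:

```math
\text{lift}_g = \frac{\text{cntObs}_g / \text{count}_g}{n_1 / n},
\qquad
\text{lift}_1 = \frac{32486 / 100000}{94410 / 1000000} = \frac{0.32486}{0.09441} \approx 3.441
```

A lift of 3.44 means the top 10% of observations respond at about 3.4 times the overall rate.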
cumliftcurve
`cumliftcurve` is similar to `liftcurve`; the difference is that it plots the cumulative response rate from the top of the model.

```julia
cumliftcurve(mdiag)
```
Tables
`bcdiag` uses 100 groups by default, which works well for the plots above. For tables such as decile reports, we may want to run `bcdiag` with only 10 groups and then generate the tables:

```julia
mdiag10 = bcdiag(target, prob; groups = 10)
```

```
Base rate: 0.0944  n: 1000000  n1: 94410  n0: 905590
ks: 0.4918  occurs at value of 0.09004201376264864 depth of 0.362772
roc: 0.8132  concordant pairs: 69527841929  tied pairs: 393224  discordant pairs: 15968516747
Gini: 0.6264
```
liftable
`liftable` is the table from which `liftcurve` is plotted.

```julia
liftable(mdiag10)
```
| Row | grp | depth | count | cntObs | cntPrd | rrObs | rrPred | liftObs | liftPrd |
|---|---|---|---|---|---|---|---|---|---|
| | Int32 | Float64 | Int64 | Int64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | 0 | 0.1 | 100000 | 32486 | 32362.6 | 0.32486 | 0.323626 | 3.44095 | 3.42787 |
| 2 | 1 | 0.2 | 100000 | 22260 | 22327.0 | 0.2226 | 0.22327 | 2.3578 | 2.3649 |
| 3 | 2 | 0.3 | 100000 | 14934 | 14899.3 | 0.14934 | 0.148993 | 1.58182 | 1.57815 |
| 4 | 3 | 0.4 | 100000 | 9720 | 9610.26 | 0.0972 | 0.0961026 | 1.02955 | 1.01793 |
| 5 | 4 | 0.5 | 100000 | 6050 | 6059.78 | 0.0605 | 0.0605978 | 0.640822 | 0.641858 |
| 6 | 5 | 0.6 | 100000 | 3736 | 3767.23 | 0.03736 | 0.0376723 | 0.395721 | 0.399028 |
| 7 | 6 | 0.7 | 100000 | 2386 | 2321.27 | 0.02386 | 0.0232127 | 0.252727 | 0.245872 |
| 8 | 7 | 0.8 | 100000 | 1421 | 1421.43 | 0.01421 | 0.0142143 | 0.150514 | 0.150559 |
| 9 | 8 | 0.9 | 100000 | 880 | 866.2 | 0.0088 | 0.008662 | 0.0932105 | 0.0917488 |
| 10 | 9 | 1.0 | 100000 | 537 | 522.95 | 0.00537 | 0.0052295 | 0.0568796 | 0.0553914 |
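As a consistency check on the columns, the sketch below recomputes `rrObs` and `liftObs` from `cntObs` and `count`, assuming `rrObs = cntObs / count` and `liftObs = rrObs / baserate`, with the base rate 94410 / 1000000 reported by `kstest`. It uses only the numbers printed in the table above.

```julia
# Observed counts per decile, copied from the liftable output (cntObs column)
cnt_obs = [32486, 22260, 14934, 9720, 6050, 3736, 2386, 1421, 880, 537]
count_per_group = fill(100_000, 10)     # count column
baserate = 94_410 / 1_000_000           # n1 / n from kstest

rr_obs = cnt_obs ./ count_per_group     # observed response rate per decile
lift_obs = rr_obs ./ baserate           # response rate relative to the base rate

for (g, (rr, lift)) in enumerate(zip(rr_obs, lift_obs))
    println(g - 1, "  rrObs = ", rr, "  liftObs = ", round(lift; digits = 5))
end
```

The `cntObs` column also sums to 94410, the number of class-1 observations.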
cumliftable
`cumliftable` is the cumulative version of `liftable`.

```julia
cumliftable(mdiag10)
```
| Row | grp | depth | count | cumObs | cumPrd | crObs | crPrd | liftObs | liftPrd |
|---|---|---|---|---|---|---|---|---|---|
| | Int32 | Float64 | Int64 | Int64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | 0 | 0.1 | 100000 | 32486 | 32362.6 | 0.32486 | 0.323626 | 3.44095 | 3.42787 |
| 2 | 1 | 0.2 | 200000 | 54746 | 54689.6 | 0.27373 | 0.273448 | 2.89938 | 2.89639 |
| 3 | 2 | 0.3 | 300000 | 69680 | 69588.8 | 0.232267 | 0.231963 | 2.46019 | 2.45697 |
| 4 | 3 | 0.4 | 400000 | 79400 | 79199.1 | 0.1985 | 0.197998 | 2.10253 | 2.09721 |
| 5 | 4 | 0.5 | 500000 | 85450 | 85258.9 | 0.1709 | 0.170518 | 1.81019 | 1.80614 |
| 6 | 5 | 0.6 | 600000 | 89186 | 89026.1 | 0.148643 | 0.148377 | 1.57444 | 1.57162 |
| 7 | 6 | 0.7 | 700000 | 91572 | 91347.4 | 0.130817 | 0.130496 | 1.38563 | 1.38223 |
| 8 | 7 | 0.8 | 800000 | 92993 | 92768.8 | 0.116241 | 0.115961 | 1.23124 | 1.22827 |
| 9 | 8 | 0.9 | 900000 | 93873 | 93635.0 | 0.104303 | 0.104039 | 1.10479 | 1.10199 |
| 10 | 9 | 1.0 | 1000000 | 94410 | 94158.0 | 0.09441 | 0.094158 | 1.0 | 0.99733 |
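The cumulative columns follow from `liftable` by accumulating from the top decile down. This is a minimal sketch, assuming `cumObs` is the running sum of `cntObs`, `crObs` is that sum divided by the number of observations so far, and `liftObs` here is `crObs` divided by the base rate; the counts are copied from the `liftable` output.

```julia
# Per-decile observed counts, copied from the liftable output (cntObs column)
cnt_obs = [32486, 22260, 14934, 9720, 6050, 3736, 2386, 1421, 880, 537]
count_per_group = fill(100_000, 10)
baserate = 94_410 / 1_000_000

cum_obs = cumsum(cnt_obs)               # cumObs: responders from the top down
cum_count = cumsum(count_per_group)     # observations from the top down
cr_obs = cum_obs ./ cum_count           # crObs: cumulative response rate
cum_lift = cr_obs ./ baserate           # liftObs in cumliftable

for g in eachindex(cum_obs)
    println(g - 1, "  cumObs = ", cum_obs[g], "  crObs = ", round(cr_obs[g]; digits = 6),
            "  liftObs = ", round(cum_lift[g]; digits = 5))
end
```

By construction the last row covers every observation, so its cumulative response rate equals the base rate and its lift is 1.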