- R Deep Learning Essentials
- Mark Hodnett, Joshua F. Wiley
- 2021-08-13 15:34:30
The problem of overfitting data – the consequences explained
A common issue in machine learning is overfitting data. Generally, overfitting refers to a model performing better on the data used to train it than on data not used to train it (holdout data, future real-world use, and so on). Overfitting occurs when a model memorizes part of the training data and fits what is essentially noise in that data. Accuracy on the training data is high, but because the noise changes from one dataset to the next, that accuracy does not carry over to unseen data; in other words, the model does not generalize well.
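To make the idea of memorization concrete, here is a small illustrative sketch in base R. It is not from the book. It uses simulated two-class data in which 20% of the labels are flipped at random to act as noise, and a hand-written 1-nearest-neighbour classifier. The helper names make_data, knn1 and acc are hypothetical. Because 1-NN simply looks up the closest stored training point, it reproduces every training label, including the noisy ones. Its training accuracy is therefore 100%, while its accuracy on freshly simulated holdout data is lower.

```r
# Illustrative sketch (simulated data): a 1-nearest-neighbour classifier
# "memorizes" noisy training labels and does worse on new data.
set.seed(123)

# Hypothetical data generator: two uniform features, class 1 above the
# diagonal, with a fraction `noise` of labels flipped at random.
make_data <- function(n, noise = 0.2) {
  x <- matrix(runif(2 * n), ncol = 2)
  y <- as.integer(x[, 1] + x[, 2] > 1)
  flip <- runif(n) < noise
  y[flip] <- 1L - y[flip]
  list(x = x, y = y)
}

# 1-NN: return the label of the closest training point (squared Euclidean distance)
knn1 <- function(train_x, train_y, new_x) {
  apply(new_x, 1, function(p) {
    d <- colSums((t(train_x) - p)^2)
    train_y[which.min(d)]
  })
}

# Accuracy as a percentage
acc <- function(y, yhat) 100 * mean(y == yhat)

train <- make_data(300)
test  <- make_data(300)

train_acc <- acc(train$y, knn1(train$x, train$y, train$x))  # each point is its own nearest neighbour
test_acc  <- acc(test$y,  knn1(train$x, train$y, test$x))
print(sprintf("train accuracy = %1.2f%%, test accuracy = %1.2f%%", train_acc, test_acc))
```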
Overfitting can occur at any time, but it tends to become more severe as the ratio of parameters to information increases. Usually, this can be thought of as the ratio of parameters to observations, but not always. For example, suppose we have a very imbalanced dataset where the outcome we want to predict is a rare event that occurs in 1 of every 5 million cases. In that case, a sample of 15 million may contain only 3 positive cases. Even though the sample size is large, the amount of information about the outcome is low. To take a simple but extreme case, imagine fitting a straight line to two data points. The fit will be perfect: on those two training points, your linear regression model will appear to account for all of the variation in the data. However, if we then applied that line to another 1,000 cases, it might not fit very well at all.
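Both points in this paragraph can be checked with a few lines of base R. The sketch below is illustrative and uses simulated data. The data-generating line y = 3 + 0.5x with unit-variance noise is an arbitrary assumption, and the helper name simulate is hypothetical. The sketch first computes the expected number of positives in the rare-event example. It then fits a straight line to two points and compares the line's in-sample error with its error on 1,000 new cases drawn from the same process.

```r
# Rare-event example from the text: 1 in 5 million cases, 15 million observations
expected_positives <- 15e6 / 5e6   # expected number of positive cases
print(expected_positives)          # 3

# Straight line through two points: 2 parameters, 2 observations
set.seed(2021)
# Hypothetical generator: assumed "true" line plus Gaussian noise
simulate <- function(n) {
  x <- runif(n, 0, 10)
  y <- 3 + 0.5 * x + rnorm(n, sd = 1)
  data.frame(x = x, y = y)
}

train <- simulate(2)
fit <- lm(y ~ x, data = train)

# In-sample error: the line passes through both points, so this is essentially 0
train_mse <- mean(residuals(fit)^2)

# Out-of-sample error on 1,000 new cases from the same process
new_cases <- simulate(1000)
test_mse <- mean((new_cases$y - predict(fit, newdata = new_cases))^2)
print(c(train_mse = train_mse, test_mse = test_mse))
```

The in-sample error says the model is perfect, while the out-of-sample error is at least as large as the noise variance. This is the same optimism the text describes, in its most extreme form.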
In the previous sections, we generated out-of-sample predictions for our models; that is, we evaluated accuracy on test (or holdout) data. But we never checked whether our models were overfitting, which requires comparing that accuracy with the accuracy on the training data. We can examine how well a model generalizes by computing the accuracy of its in-sample predictions and comparing it with its out-of-sample accuracy. For digits.m4, the in-sample accuracy is 84.7%, compared to 81.7% on the holdout data. That is a loss of 3.0 percentage points. Put differently, using the training data to evaluate model performance would have produced an overly optimistic estimate of accuracy, overestimating it by 3.0 percentage points:
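# In-sample predictions: predict() is called without new data,
# so it returns predictions for the training data
digits.yhat4.train <- predict(digits.m4)
# Convert the network's output matrix to one (1-based) class index per row
digits.yhat4.train <- encodeClassLabels(digits.yhat4.train)
# Subtract 1 so that class indices line up with the digit labels 0-9
accuracy <- 100.0*sum(I(digits.yhat4.train - 1)==digits.y)/length(digits.y)
print(sprintf(" accuracy = %1.2f%%",accuracy))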
[1] " accuracy = 84.70%"
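The in-sample versus out-of-sample comparison can be wrapped in a small helper. The sketch below is hypothetical: the function names accuracy_summary and optimism_gap are not part of nnet, RSNNS or caret. It compares labels as character strings so that factor labels (such as digits.y) and numeric predictions (such as digits.yhat4.train - 1) compare cleanly. It also attaches an exact (Clopper-Pearson) 95% confidence interval from base R's binom.test(). As far as we know, caret::confusionMatrix() computes AccuracyLower and AccuracyUpper the same way, but treat that as an assumption. The toy label vectors are made up purely for illustration.

```r
# Hypothetical helper: accuracy (in %) with an exact binomial 95% CI
accuracy_summary <- function(truth, pred) {
  # Compare as character so factor and numeric labels match by value
  correct <- sum(as.character(truth) == as.character(pred))
  n <- length(truth)
  ci <- binom.test(correct, n)$conf.int   # exact Clopper-Pearson interval
  c(Accuracy = 100 * correct / n,
    AccuracyLower = 100 * ci[1],
    AccuracyUpper = 100 * ci[2])
}

# Hypothetical helper: in-sample minus out-of-sample accuracy, in percentage points
optimism_gap <- function(train_truth, train_pred, test_truth, test_pred) {
  ins  <- accuracy_summary(train_truth, train_pred)
  outs <- accuracy_summary(test_truth, test_pred)
  list(in_sample = ins,
       out_of_sample = outs,
       gap = unname(ins["Accuracy"] - outs["Accuracy"]))
}

# Toy labels (made up): true digits as a factor, predictions as numbers
train_truth <- factor(c(0, 1, 2, 3))
train_pred  <- c(0, 1, 2, 2)   # 3 of 4 correct
test_truth  <- factor(c(0, 1, 2, 3))
test_pred   <- c(0, 0, 2, 2)   # 2 of 4 correct

res <- optimism_gap(train_truth, train_pred, test_truth, test_pred)
print(res$gap)   # 25
```

Applied to digits.m4, the gap would correspond to the 84.7% versus 81.7% comparison in the text, that is, 3.0 percentage points.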
Since we fitted several models of varying complexity earlier, we can compare in-sample and out-of-sample performance measures across them to see how much each one overfits, that is, how overly optimistic its in-sample accuracy is. The code here should be easy to follow. We call the predict function for each model without passing in any new data; this returns predictions for the data the model was trained with. The rest of the code is boilerplate to create the plot.
# In-sample predictions: predict() with no new data returns predictions
# for the data each model was trained on
digits.yhat1.train <- predict(digits.m1)
digits.yhat2.train <- predict(digits.m2)
digits.yhat3.train <- predict(digits.m3)
digits.yhat4.train <- predict(digits.m4)
digits.yhat4.train <- encodeClassLabels(digits.yhat4.train)
# Overall statistics to keep from each confusion matrix
# (AccuracyNull is the no-information rate)
measures <- c("AccuracyNull", "Accuracy", "AccuracyLower", "AccuracyUpper")
# Confusion matrices for training (in-sample) and test (out-of-sample) data;
# the RSNNS class indices are shifted by 1 to match the digit labels 0-9
n5.insample <- caret::confusionMatrix(xtabs(~digits.y + digits.yhat1.train))
n5.outsample <- caret::confusionMatrix(xtabs(~digits.test.y + digits.yhat1))
n10.insample <- caret::confusionMatrix(xtabs(~digits.y + digits.yhat2.train))
n10.outsample <- caret::confusionMatrix(xtabs(~digits.test.y + digits.yhat2))
n40.insample <- caret::confusionMatrix(xtabs(~digits.y + digits.yhat3.train))
n40.outsample <- caret::confusionMatrix(xtabs(~digits.test.y + digits.yhat3))
n40b.insample <- caret::confusionMatrix(xtabs(~digits.y + I(digits.yhat4.train - 1)))
n40b.outsample <- caret::confusionMatrix(xtabs(~ digits.test.y + I(digits.yhat4 - 1)))
shrinkage <- rbind(
cbind(Size = 5, Sample = "In", as.data.frame(t(n5.insample$overall[measures]))),
cbind(Size = 5, Sample = "Out", as.data.frame(t(n5.outsample$overall[measures]))),
cbind(Size = 10, Sample = "In", as.data.frame(t(n10.insample$overall[measures]))),
cbind(Size = 10, Sample = "Out", as.data.frame(t(n10.outsample$overall[measures]))),
cbind(Size = 40, Sample = "In", as.data.frame(t(n40.insample$overall[measures]))),
cbind(Size = 40, Sample = "Out", as.data.frame(t(n40.outsample$overall[measures]))),
cbind(Size = 40, Sample = "In", as.data.frame(t(n40b.insample$overall[measures]))),
cbind(Size = 40, Sample = "Out", as.data.frame(t(n40b.outsample$overall[measures])))
)
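# The first six rows come from nnet models, the last two from RSNNS
shrinkage$Pkg <- rep(c("nnet", "RSNNS"), c(6, 2))
library(ggplot2)
# Offset the In and Out points horizontally so they do not overlap
dodge <- position_dodge(width=0.4)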
ggplot(shrinkage, aes(interaction(Size, Pkg, sep = " : "), Accuracy,
ymin = AccuracyLower, ymax = AccuracyUpper,
shape = Sample, linetype = Sample)) +
geom_point(size = 2.5, position = dodge) +
geom_errorbar(width = .25, position = dodge) +
xlab("") + ylab("Accuracy + 95% CI") +
theme_classic() +
theme(legend.key.size = unit(1, "cm"), legend.position = c(.8, .2))
The code produces the following plot, which shows the accuracy metrics and their 95% confidence intervals. One thing we notice from this plot is that, as the models get more complex, the gap between in-sample and out-of-sample performance widens. This highlights that more complex models tend to overfit; that is, they perform better on the in-sample data than on unseen out-of-sample data.
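The same pattern can be reproduced without neural networks. The following base R sketch uses simulated data rather than the book's digits models. The sine-wave signal, noise level and sample sizes are arbitrary assumptions. It fits polynomial regressions of increasing degree to 12 training points. As in the book's code, in-sample predictions come from calling predict() without new data. Training error can only go down as the degree rises, because each model nests the previous one. At degree 11 (12 coefficients for 12 observations) the curve passes through every training point, much like the two-point line discussed earlier. The out-of-sample error does not improve in the same way, so the gap between the two typically widens at the higher degrees.

```r
# Illustrative sketch (simulated data): in-sample vs out-of-sample error
# as model complexity (polynomial degree) increases.
set.seed(42)

true_f  <- function(x) sin(2 * pi * x)   # assumed "true" signal
n_train <- 12
x_train <- seq(0, 1, length.out = n_train)
y_train <- true_f(x_train) + rnorm(n_train, sd = 0.3)
x_test  <- runif(500)
y_test  <- true_f(x_test) + rnorm(500, sd = 0.3)

mse <- function(y, yhat) mean((y - yhat)^2)

degrees <- 1:11   # degree 11 plus intercept = 12 parameters for 12 observations
results <- data.frame(degree = degrees, train_mse = NA_real_, test_mse = NA_real_)

for (i in seq_along(degrees)) {
  d <- degrees[i]
  fit <- lm(y ~ poly(x, d), data = data.frame(x = x_train, y = y_train))
  # predict() without new data gives in-sample (training) predictions
  results$train_mse[i] <- mse(y_train, predict(fit))
  # predict() with new data gives out-of-sample predictions
  results$test_mse[i] <- mse(y_test, predict(fit, newdata = data.frame(x = x_test)))
}

# Positive gap = how optimistic the in-sample error is
results$gap <- results$test_mse - results$train_mse
print(round(results, 4))
```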