Scikit-Learn's train_test_split() - Training, Testing and Validation Sets

Introduction

Scikit-Learn is one of the most widely-used Machine Learning libraries in Python. It's optimized and efficient - and its high-level API is simple and easy to use.

Scikit-Learn has a plethora of convenience tools and methods that make preprocessing, evaluating and other painstaking processes as easy as calling a single method - and splitting data between a training and testing set is no exception.

Generally speaking, the rule-of-thumb for splitting data is 80/20 - where 80% of the data is used for training a model, while 20% is used for testing it. This depends on the dataset you're working with, but an 80/20 split is very common and would get you through most datasets just fine.

In this guide - we'll take a look at how to use the train_test_split() method in Scikit-Learn, and how to configure its parameters so that you have control over the splitting process.

Installing Scikit-Learn

Assuming it isn't already installed - Scikit-Learn can easily be installed via pip:

$ pip install scikit-learn

Once installed, you can import the library itself via:

import sklearn

In most cases, people avoid importing the entire library, as it's pretty vast, and instead import only the specific classes or modules they'll be using.

Note: This tends to mean that people have a hefty import list when using Scikit-Learn.

Importance of Training and Testing Sets

The most common procedure when training a (basic) model in Machine Learning follows the same rough outline:

  • Acquiring and processing data which we'll feed into a model. Scikit-Learn has various datasets (iris, diabetes, digits...) that can be loaded and used for training, mainly for benchmarking/learning.
  • Splitting the data into training and test sets
  • Building a model and defining the architecture
  • Compiling the model
  • Training the model
  • Verifying the results
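As a rough, minimal sketch of these steps in Scikit-Learn (the choice of KNeighborsClassifier and random_state=42 are our own assumptions for illustration, not part of the outline above), note that classic Scikit-Learn estimators have no separate compile step, so "building" and "compiling" collapse into creating the estimator:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 1. Acquire the data (a built-in toy dataset, returned as an (X, y) tuple)
X, y = load_iris(return_X_y=True)

# 2. Split it into training and test sets (80/20 here)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 3. Build the model; there is no separate "compile" step for this estimator
model = KNeighborsClassifier(n_neighbors=5)

# 4. Train (fit) the model on the training set only
model.fit(X_train, y_train)

# 5. Verify the results on the held-out test set
test_accuracy = model.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.3f}")
```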

The training set is a subset of the whole dataset and we generally don't train a model on the entirety of the data. In non-generative models, a training set usually contains around 80% of the main dataset's data. As the name implies, it is used for training the model. This procedure is also referred to as fitting the model.

There are exceptions to this rule, though.

For instance, when training Generative Adversarial Networks (GANs) that generate images - how do you test the results? They're highly subjective in some cases, as they represent new instances that were never seen before. In most generative models, at least as of now, a human is typically required to judge the outputs, in which case a test set is totally redundant.

Additionally, sometimes you need more or less than 20% for testing, and if you're using techniques such as cross-validation, you might want slightly less testing data so as not to "take away" too much from the training data. For instance, if you have 1,000,000 instances in a dataset, holding out just 5% for a testing set amounts to 50,000 instances, which is most likely more than enough for any model to be tested on.
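To make that arithmetic concrete, here's a small sketch that holds out 5% of a stand-in dataset of 1,000,000 instances (just the row indices, generated with NumPy for illustration; random_state is an arbitrary choice):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for a dataset with 1,000,000 instances: just the row indices
indices = np.arange(1_000_000)

# Hold out 5% of the instances for testing
train_idx, test_idx = train_test_split(indices, test_size=0.05, random_state=0)

print(len(train_idx), len(test_idx))  # 950000 50000
```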

The test set is a subset of the whole dataset, and is used to evaluate the model and check how well it learned from the training set.

The model mustn't interact with or see the test set before evaluation. The data must be unseen when the model is first evaluated, otherwise it's not really testing the model.

What About Validation Sets?

Validation sets are a common sight in professional and academic models. Validation sets are taken out of the training set, and used during training to approximate the model's accuracy.

The testing set is fully disconnected until the model is finished training - but the validation set is used to validate it during training.

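Note: The validation set isn't used for training, and the model doesn't fit its parameters on that data. It's only used to evaluate the model after each epoch. Even so, the model indirectly learns from it: the validation results drive decisions such as when to stop training or which settings to keep, and those decisions shape the final model. That's why the validation set can't be used for testing.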

It's similar to how you learn something about your own knowledge when you hear that an answer is incorrect, even if you don't know why. This is why validation sets only approximate a model's accuracy, and testing sets are still required even when using a validation set.

They help approximate a model's actual performance during training, so you don't end up with an overfit model whose problems only become apparent once you evaluate it on the test set. You can also use validation sets to tune models and approximately evaluate their ability without exposing them to the test set.
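A minimal sketch of that tuning workflow, assuming a simple k-nearest-neighbours classifier on Iris (the model, the candidate values and the random_state are illustrative choices of ours): the test set is split off first, the hyperparameter is picked using only the validation set, and the test set is touched exactly once at the end.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Hold out a test set first; it stays untouched until the very end
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# Carve a validation set out of the remaining training data
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=1)

# Tune the hyperparameter (here: n_neighbors) using only the validation set
candidate_ks = [1, 3, 5, 7, 9]
valid_scores = {}
for k in candidate_ks:
    model = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
    valid_scores[k] = model.score(X_valid, y_valid)

best_k = max(valid_scores, key=valid_scores.get)

# Only now evaluate on the test set, once, with the chosen configuration
final_model = KNeighborsClassifier(n_neighbors=best_k).fit(X_train, y_train)
test_score = final_model.score(X_test, y_test)
print(best_k, valid_scores[best_k], test_score)
```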

Deep Learning frameworks such as Keras can display a val_accuracy alongside your regular training accuracy, which is a good indicator of overfitting. If the two start diverging, your model is overfitting during training, and once they diverge enough, you don't need to waste time training it further. Additionally, callbacks such as EarlyStopping can be used to automatically stop a model's training if the val_accuracy doesn't improve after n epochs, even if the training accuracy is increasing.
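The paragraph above describes Keras; as a Scikit-Learn analogue (a swapped-in example, not the Keras API), MLPClassifier can do something similar by itself: with early_stopping=True it sets aside validation_fraction of the training data as a validation set and stops once the validation score hasn't improved for n_iter_no_change consecutive epochs. The dataset, layer size and other values below are arbitrary assumptions:

```python
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)
X = X / 16.0  # pixel values range from 0 to 16; scale them to [0, 1]

# early_stopping=True holds out validation_fraction of the data internally and
# stops when the validation score fails to improve by at least tol for
# n_iter_no_change consecutive epochs
clf = MLPClassifier(
    hidden_layer_sizes=(64,),
    early_stopping=True,
    validation_fraction=0.1,
    n_iter_no_change=5,
    max_iter=200,
    random_state=0,
)
clf.fit(X, y)

print("Epochs run:", clf.n_iter_)
print("Best validation accuracy:", clf.best_validation_score_)
```

After fitting, n_iter_ tells you how many epochs actually ran, and validation_scores_ holds the validation accuracy recorded at each epoch.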

Creating a validation set is easy.

You can, quite literally, just run the train_test_split() method again on the training set (which was itself produced by train_test_split()) and extract a validation set from it. Alternatively, you can create a validation set manually.

The validation set size is typically chosen similarly to a testing set - anywhere between 10-20% of the training set is typical. For huge datasets, you can use much less than this, but for small datasets, you risk taking out too much, making it hard for the model to fit the data in the training set.
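For the manual route, one possible sketch (assuming NumPy's default_rng for shuffling; to keep it self-contained, the whole Iris dataset stands in for an already-split training set) shuffles the row indices and slices off 10% of them:

```python
import numpy as np
from sklearn.datasets import load_iris

# Stand-in for an existing training set
X, y = load_iris(return_X_y=True)

rng = np.random.default_rng(seed=0)

# Shuffle the row indices so the split isn't biased by the dataset's ordering
shuffled = rng.permutation(len(X))

# Reserve 10% of the rows for validation; the rest stay in the training set
n_valid = int(round(0.1 * len(X)))
valid_idx, train_idx = shuffled[:n_valid], shuffled[n_valid:]

X_train, y_train = X[train_idx], y[train_idx]
X_valid, y_valid = X[valid_idx], y[valid_idx]

print(X_train.shape, X_valid.shape)  # (135, 4) (15, 4)
```

Shuffling matters here because load_iris() returns the samples ordered by class, so slicing off the last rows without shuffling would produce a validation set containing only one class.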

In the following sections, we'll also take out a validation set using the same train_test_split() method.

Scikit-Learn's datasets Module

Several clean and popular datasets are built into Scikit-Learn, typically used for learning and for benchmarking models on simple tasks.

If you've ever read resources regarding Machine Learning in Python - you've probably seen some of these most popular datasets:

  • Iris - set of 3 classes (flowers), with 50 samples per class, used for classification.
  • Diabetes - set with a total of 442 samples, used for regression.
  • Digits - set of 10 classes (hand-written digits), with ~180 samples per class, used for classification.
  • Wine - set of 3 classes (of wine), with total of 178 samples, used for classification.
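A quick way to double-check these figures is to load each dataset and count its samples (and, for the classification datasets, the samples per class with np.bincount). This is just a sanity-check sketch:

```python
import numpy as np
from sklearn import datasets

loaders = {
    "iris": datasets.load_iris,
    "diabetes": datasets.load_diabetes,
    "digits": datasets.load_digits,
    "wine": datasets.load_wine,
}

n_samples = {}
per_class = {}
for name, loader in loaders.items():
    X, y = loader(return_X_y=True)
    n_samples[name] = X.shape[0]
    # Diabetes has a continuous (regression) target, so there are no classes to count
    if name != "diabetes":
        per_class[name] = np.bincount(y)

print(n_samples)          # {'iris': 150, 'diabetes': 442, 'digits': 1797, 'wine': 178}
print(per_class["iris"])  # [50 50 50]
print(per_class["wine"])  # [59 71 48]
```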

Each of these datasets can be loaded in through the datasets module with their respective function:

from sklearn import datasets

X_iris, y_iris = datasets.load_iris(return_X_y=True)
X_diabetes, y_diabetes = datasets.load_diabetes(return_X_y=True)
X_digits, y_digits = datasets.load_digits(return_X_y=True)
X_wine, y_wine = datasets.load_wine(return_X_y=True)

Alternatively, you can import the specific functions directly:

from sklearn.datasets import load_iris
from sklearn.datasets import load_diabetes
from sklearn.datasets import load_digits
from sklearn.datasets import load_wine

X_iris, y_iris = load_iris(return_X_y=True)
X_diabetes, y_diabetes = load_diabetes(return_X_y=True)
X_digits, y_digits = load_digits(return_X_y=True)
X_wine, y_wine = load_wine(return_X_y=True)

By default, these methods return a Bunch object containing the data and the targets (the data and their classes or values). However, if you set the return_X_y argument to True, a tuple of (data, targets) is returned, denoting the data you'll be training on and the targets you want your classification/regression model to hit.
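To see the difference between the two return types, here's a short sketch comparing the default Bunch with the return_X_y=True tuple for Iris:

```python
import numpy as np
from sklearn.datasets import load_iris

# Default: a Bunch, a dict-like object whose keys are also accessible as attributes
iris = load_iris()
print(type(iris).__name__)       # Bunch
print(iris.data.shape)           # (150, 4)
print(iris.target.shape)         # (150,)
print(iris.target_names.tolist())  # ['setosa', 'versicolor', 'virginica']
print(iris.feature_names)

# With return_X_y=True: just the (data, target) tuple
X, y = load_iris(return_X_y=True)
print(np.array_equal(X, iris.data), np.array_equal(y, iris.target))  # True True
```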

Splitting a Dataset with train_test_split()

The train_test_split() method resides in the sklearn.model_selection module:


from sklearn.model_selection import train_test_split

There are a couple of arguments we can set while working with this method - and the default is very sensible and performs a 75/25 split. In practice, Scikit-Learn's default values are generally reasonable and set to serve well for most tasks. However, it's worth knowing what these defaults are, in case they don't work that well.

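The two main arguments are train_size and test_size. As floats, their values range between 0 and 1 and denote the proportion of the dataset to include in each split; as integers, they denote an absolute number of samples. If you define just one of them, such as train_size, the other is set to the complement (test_size = 1 - train_size), and vice versa. If you define both as floats, their sum may be less than 1 (the remaining samples are simply left out), but it can't exceed 1.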
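A small sketch of these rules on Iris (150 samples; random_state is only fixed so the run is repeatable):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # 150 samples

# Floats are proportions and may sum to less than 1; the remainder is left out.
# 0.5 * 150 = 75 training samples, 0.25 * 150 = 37.5 rounded up to 38 test samples
X_tr, X_te, y_tr, y_te = train_test_split(X, y, train_size=0.5, test_size=0.25, random_state=0)
print(len(X_tr), len(X_te))  # 75 38

# Ints are absolute numbers of samples
X_tr2, X_te2, y_tr2, y_te2 = train_test_split(X, y, test_size=30, random_state=0)
print(len(X_tr2), len(X_te2))  # 120 30

# Proportions that add up to more than 1 are rejected
rejected = False
try:
    train_test_split(X, y, train_size=0.8, test_size=0.3)
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```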

Setting the train_size and test_size Arguments

This is the most common approach, which leaves us with 4 subsets - X_train, X_test, y_train and y_test:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

Without setting either train_size or test_size, the default values kick in, setting test_size to 0.25 and a complementary (0.75) train_size:

(112, 4)
(38, 4)
(112,)
(38,)

As you can see, train and test sets are split 75%/25%, as there are 112 instances in the X_train set, and 38 instances in the X_test set.

Some other split proportions are: 80%/20% (very common), 67%/33% and more rarely 50%/50%.

Setting any of these boils down to defining either one or both of the arguments in the train_test_split() method:

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, test_size=0.2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

All three of these calls would result in splits of the same size:

(120, 4)
(30, 4)
(120,)
(30,)
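Worth noting: while the three calls above produce splits of the same size, train_test_split() shuffles the data before splitting, so each call picks different rows unless you fix its random_state argument. A short sketch (the value 42 is arbitrary):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Same shapes, but without random_state each call shuffles differently
a_train, a_test, _, _ = train_test_split(X, y, test_size=0.2)
b_train, b_test, _, _ = train_test_split(X, y, test_size=0.2)

# Fixing random_state makes the split reproducible across calls
c_train, c_test, _, _ = train_test_split(X, y, test_size=0.2, random_state=42)
d_train, d_test, _, _ = train_test_split(X, y, test_size=0.2, random_state=42)

print(a_test.shape == b_test.shape)    # True
print(np.array_equal(c_test, d_test))  # True
```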

Creating a Validation Set with train_test_split()

Validation sets are really useful during training and make your life as a Data Scientist significantly easier.

Whenever possible, try to use a validation set.

There is no built-in function to extract a validation set from a training set. However, since this boils down to just splitting it like before, why not use the same train_test_split() method?

Let's re-use it to get our hands on a validation set, taking 10% of the data from the training set:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)

X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, train_size=0.9)

print(X_train.shape)
print(X_test.shape)
print(X_valid.shape)

This won't create a 70%-20%-10% split, as we're splitting 10% from the already split X_train, so we're actually ending up with a 72%-20%-8% split here:

(108, 4)
(30, 4)
(12, 4)

To account for this, you can either manually set a different proportion that compensates for it, or define your proportions upfront and calculate an updated split relative to the original dataset size, instead of the already truncated size.

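To split the data proportionally into a training, testing and validation set - we need to set the test_size argument on the second function call to:

$$
\text{test\_size} = \frac{\text{validation\_ratio}}{\text{train\_ratio} + \text{validation\_ratio}}
$$

The second call only operates on what's left after the test set has been removed, which is train_ratio + validation_ratio (that is, 1 - test_ratio) of the original data, so the validation share has to be expressed relative to that remainder.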

Let's load in the Diabetes dataset, as it has more instances (due to rounding, small datasets oftentimes produce slightly different splits even with the same ratios):

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)

print(X.shape)  # (442, 10)

Say we're aiming for an 80%/10%/10% split. Since 442 doesn't divide evenly, after rounding we end up with 352, 45 and 45 instances respectively. Let's define these ratios and split the dataset into a training, testing and validation set with the train_test_split() function:

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

train_ratio = 0.80
test_ratio = 0.10
validation_ratio = 0.10

X, y = load_diabetes(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio)

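# The second split only sees the remaining (train_ratio + validation_ratio) share of the data
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=validation_ratio/(train_ratio+validation_ratio))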

print(X_train.shape)
print(X_test.shape)
print(X_valid.shape)

This results in:

(352, 10)
(45, 10)
(45, 10)
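The exact counts come from rounding: as far as we can tell from Scikit-Learn's behaviour, a float test_size is rounded up to a whole number of samples and the training set gets the rest. The sketch below reproduces the 352/45/45 figures by hand with math.ceil and then checks them against real calls (random_state=0 is an arbitrary choice):

```python
import math
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
n = len(X)  # 442 samples

# First split: 10% of 442 is 44.2, rounded up to 45 test samples
n_test = math.ceil(0.10 * n)
n_rest = n - n_test  # 397 samples left for training + validation

# Second split: 0.10 / 0.90 of 397 is about 44.11, rounded up to 45 validation samples
n_valid = math.ceil((0.10 / 0.90) * n_rest)
n_train = n_rest - n_valid  # 352

print(n_train, n_test, n_valid)  # 352 45 45

# The same sizes from the actual function calls
X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.10, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_rest, y_rest, test_size=0.10 / 0.90, random_state=0
)
print(len(X_train), len(X_test), len(X_valid))  # 352 45 45
```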

Awesome! Our dataset has successfully been split into three sets, which we can now feed into a model and perform validation during training to tune the hyperparameters.
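If you split data this way often, the two calls can be wrapped in a small helper. The function below, train_valid_test_split(), is a hypothetical name of our own (not part of Scikit-Learn's API). It rescales the validation share relative to what's left after the test set is removed; with an uneven 70%/10%/20% split, using train_ratio + test_ratio as the denominator instead would give roughly 40 validation samples rather than 45.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split


def train_valid_test_split(X, y, train_ratio, valid_ratio, test_ratio, random_state=None):
    """Hypothetical helper: split X, y into train/validation/test sets by ratios of the full dataset."""
    if abs(train_ratio + valid_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError("train_ratio + valid_ratio + test_ratio must equal 1")

    # First call: peel off the test set
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=test_ratio, random_state=random_state
    )
    # Second call only sees (train_ratio + valid_ratio) of the data,
    # so rescale the validation share relative to that remainder
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_rest,
        y_rest,
        test_size=valid_ratio / (train_ratio + valid_ratio),
        random_state=random_state,
    )
    return X_train, X_valid, X_test, y_train, y_valid, y_test


X, y = load_diabetes(return_X_y=True)
X_train, X_valid, X_test, y_train, y_valid, y_test = train_valid_test_split(
    X, y, train_ratio=0.70, valid_ratio=0.10, test_ratio=0.20, random_state=0
)
print(X_train.shape, X_valid.shape, X_test.shape)  # (308, 10) (45, 10) (89, 10)
```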

Stratified Split

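Sometimes, there are different numbers of samples for each class in a dataset. Say, one class has 100 samples, the second one has 50, the third one 30, etc. Splitting without this in mind can create an issue when you're training a classification model: a random split may leave some classes under- or over-represented in the training or test set. (Regression targets have no classes, so this particular concern doesn't apply to them in the same way.)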

It is best to split the set in a way that preserves the proportions of the classes. This is called a stratified split.

Luckily, the train_test_split() method has an argument called stratify, which takes an array of class labels; the data is then split so that each subset keeps the same class proportions as that array:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

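In a lot of cases, you can simply pass the y NumPy array from your dataset as the stratify array. This ensures that the class proportions in the training and testing sets mirror those of the full dataset, so the test set stays representative and rare classes aren't accidentally left out of either set. Note that it doesn't remove the class imbalance itself.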
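A minimal sketch with made-up, imbalanced labels mirroring the 100/50/30 example above (the placeholder features are just row indices, and random_state is arbitrary):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Imbalanced toy labels: 100, 50 and 30 samples per class
y = np.array([0] * 100 + [1] * 50 + [2] * 30)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

print(np.bincount(y_train))  # [80 40 24]
print(np.bincount(y_test))   # [20 10  6]
```

With stratify=y, the 20% test set receives exactly 20% of each class; without it, the per-class counts would vary from one random split to the next.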

Conclusion

In this guide, we got familiar with some of the Scikit-Learn library and its datasets module. You've learned what training, testing and validation sets are, where they're applied and the benefits of validating your models.

We've taken a look at how to employ the train_test_split() method to split your data into a training and testing set, as well as how to separate a validation set, dynamically preserving the ratios of these sets.


David Landup, Editor

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2025 Stack Abuse. All rights reserved.