Split Train, Test and Validation Sets with TensorFlow Datasets - tfds

Introduction

TensorFlow Datasets, also known as tfds, is a library that serves as a wrapper to a wide selection of datasets, with its own functions to load, split and prepare datasets for Machine and Deep Learning, primarily with TensorFlow.

Note: While the TensorFlow Datasets library is used to get data, it's not used to preprocess data. That job is delegated to the TensorFlow Data (tf.data) library.

All of the datasets acquired through TensorFlow Datasets are wrapped into tf.data.Dataset objects - so you can programmatically obtain and prepare a wide variety of datasets easily! One of the first steps you'll be taking after loading and getting to know a dataset is a train/test/validation split.

In this guide, we'll take a look at what training, testing and validation sets are before learning how to load in and perform a train/test/validation split with TensorFlow Datasets.

Training and Testing Sets

When working on supervised learning tasks - you'll want to obtain a set of features and a set of labels for those features, either as separate entities or within a single Dataset. Just training the network on all of the data is fine and dandy - but you can't test its accuracy on that same data, since evaluating the model like that would be rewarding memorization instead of generalization.

Instead - we train the models on one part of the data, holding out a part of it to test the model once it's done training. The ratio between these two is commonly 80/20, and that's a fairly sensible default. Depending on the size of the dataset, you might opt for different ratios, such as 60/40 or even 90/10. If the dataset itself is very large, there's no need to dedicate a large percentage of samples to testing. For instance, if 1% of the dataset represents 1,000,000 samples - you probably don't need more than that for testing!
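As a minimal sketch of the idea in plain Python (independent of TFDS), a two-way split shuffles the samples once and then cuts off a fixed fraction for testing. The `holdout_split` name and the fixed seed are illustrative assumptions, not part of any library:

```python
import random

def holdout_split(samples, test_fraction=0.2, seed=42):
    """Hypothetical helper: shuffle a copy of `samples`, hold out `test_fraction` for testing."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n_test = round(len(shuffled) * test_fraction)
    # The first n_test shuffled samples form the test set, the rest the training set
    return shuffled[n_test:], shuffled[:n_test]

train, test = holdout_split(range(1000), test_fraction=0.2)
print(len(train), len(test))  # 800 200
```

Shuffling before cutting matters: if the data is sorted (for example, by class), slicing off the last 20% without shuffling would produce a test set that doesn't represent the whole distribution.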

For some models and architectures - you won't have any test set at all! For instance, when training Generative Adversarial Networks (GANs) that generate images - testing the model isn't as easy as comparing the real and predicted labels! In most generative models (music, text, video), at least as of now, a human is typically required to judge the outputs, in which case a test set is largely redundant.

The test set should be held out from the model until the testing stage, and it should only ever be used for inference - not training. It's common practice to define a test set and "forget it" until the end stages, where you evaluate the model's accuracy.

Validation Sets

A validation set is an extremely important and sometimes overlooked set. Validation sets are oftentimes described as taken "out of" test sets, since it's convenient to imagine, but really - they're separate sets. There's no set rule for split ratios, but it's common to have a validation set of similar size to the test set, or slightly smaller - anything along the lines of 75/15/10, 70/15/15, and 70/20/10.
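The same idea extends to three sets. Below is a sketch assuming a 70/15/15 train/validation/test ratio; the hypothetical `three_way_split` helper lets the test set absorb any leftover samples from rounding, so no sample is dropped or duplicated:

```python
import random

def three_way_split(samples, ratios=(0.70, 0.15, 0.15), seed=0):
    """Hypothetical helper: shuffle and split into train/validation/test by `ratios`."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    n_train = round(n * ratios[0])
    n_valid = round(n * ratios[1])
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    # Whatever remains (including rounding leftovers) becomes the test set
    test = shuffled[n_train + n_valid:]
    return train, valid, test

train, valid, test = three_way_split(range(1000))
print(len(train), len(valid), len(test))  # 700 150 150
```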

A validation set is used during training to approximately evaluate the model after each epoch. This helps guide training by giving "hints" as to whether the model is performing well or not. Additionally, you don't have to wait for all of the epochs to finish to get a more accurate glimpse at the model's actual performance.

Note: The validation set isn't used for training, and the model doesn't train on the validation set at any given point. It's used to validate the performance in a given epoch. Since it does affect the training process, the model indirectly trains on the validation set and thus, it can't be fully trusted for testing, but is a good approximation/proxy for updating beliefs during training.

This is analogous to knowing when you're wrong, but not knowing what the right answer is. Eventually, by updating your beliefs after realizing you're not right, you'll get closer to the truth without explicitly being told what it is. A validation set indirectly trains your knowledge.

Using a validation set - you can easily tell when a model has begun to overfit significantly in real time, and based on the disparity between the validation and training accuracies, you could opt to trigger responses - such as automatically stopping training, updating the learning rate, etc.
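Early stopping is the most common of these responses. In Keras, the built-in `tf.keras.callbacks.EarlyStopping` callback handles it for you; the plain-Python sketch below only illustrates the underlying logic, using a made-up validation-loss history and a hypothetical `early_stopping_epoch` function:

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Hypothetical helper: return the 0-based epoch at which training would stop, or None.

    Training stops once the validation loss has failed to improve on its best
    value by more than `min_delta` for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None  # never triggered: training runs for every epoch

# Illustrative history: validation loss falls, then rises as the model starts to overfit
history = [0.90, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.68]
print(early_stopping_epoch(history, patience=3))  # 6
```

The best validation loss here is at epoch 3; three epochs without improvement later (epoch 6), training is stopped.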

Split Train, Test and Validation Sets using TensorFlow Datasets

The load() function of the tfds module loads in a dataset, given its name. If it's not already downloaded on the local machine - it'll automatically download the dataset with a progress bar:

import tensorflow_datasets as tfds

# Load dataset
dataset, info = tfds.load("cifar10", as_supervised=True, with_info=True)

# Extract informative features
class_names = info.features["label"].names
n_classes = info.features["label"].num_classes

print(class_names) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(n_classes) # 10

One of the optional arguments you can pass into the load() function is the split argument.

The Split API allows you to define which splits of the dataset you want to load. For CIFAR-10, only a 'train' and a 'test' split are available by default - these are the "official" splits for this dataset. There's no 'validation' split.

Note: Each dataset has its own "official" splits. Some only have a 'train' split, some have 'train' and 'test' splits, and some also include a 'validation' split. These are the intended splits, and you can only use a split's string alias if the dataset defines that split. If a dataset contains only a 'train' split, you can still slice that training data into train/test/validation sets without issues.

These correspond to the tfds.Split.TRAIN, tfds.Split.TEST and tfds.Split.VALIDATION enums, which used to be exposed through the API in an earlier version.

You can slice a dataset into any number of sets, though we typically create three - train_set, test_set and valid_set:

# with_info=True would return a (datasets, info) tuple, which can't be unpacked into three sets
test_set, valid_set, train_set = tfds.load("cifar10",
                                           split=["test", "train[0%:20%]", "train[20%:]"],
                                           as_supervised=True)

print("Train set size: ", len(train_set)) # Train set size:  40000
print("Test set size: ", len(test_set))   # Test set size:  10000
print("Valid set size: ", len(valid_set)) # Valid set size:  10000

We've taken the 'test' split and extracted it into the test_set. The slice between 0% and 20% of the 'train' split is assigned to the valid_set, and everything beyond 20% is the train_set. This is confirmed by the sizes of the sets themselves as well.
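To make the slicing arithmetic concrete, here's a minimal sketch of how a spec like 'train[20%:]' maps to example indices, assuming the 50,000-example CIFAR-10 'train' split. `resolve_slice` is a hypothetical helper, not part of TFDS; it handles only non-negative bounds and may round differently from TFDS in edge cases:

```python
import re

# Matches e.g. "test", "train[20%:]" or "train[0:10000]" (non-negative bounds only)
_SPEC = re.compile(r"(\w+)(?:\[(\d*)(%?):(\d*)(%?)\])?")

def resolve_slice(spec, num_examples):
    """Hypothetical helper: map a split spec to (start, stop) indices into a split."""
    match = _SPEC.fullmatch(spec)
    if match is None:
        raise ValueError(f"unsupported split spec: {spec!r}")
    _, start_txt, start_pct, stop_txt, stop_pct = match.groups()

    def to_index(text, is_pct, default):
        if not text:  # None (no brackets) or "" (open-ended bound)
            return default
        value = int(text)
        # Percentages are converted to an absolute index into the split
        return round(num_examples * value / 100) if is_pct else value

    start = to_index(start_txt, start_pct, 0)
    stop = min(to_index(stop_txt, stop_pct, num_examples), num_examples)
    return start, stop

CIFAR10_TRAIN_SIZE = 50_000  # size of CIFAR-10's official 'train' split

for spec in ["train[0%:20%]", "train[20%:]"]:
    start, stop = resolve_slice(spec, CIFAR10_TRAIN_SIZE)
    print(f"{spec}: examples {start} to {stop - 1} ({stop - start} total)")
# train[0%:20%]: examples 0 to 9999 (10000 total)
# train[20%:]: examples 10000 to 49999 (40000 total)
```

The two slices meet exactly at index 10000, which is why the validation and training sets don't overlap.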

Instead of percentages, you can use absolute values or a mix of percentage and absolute values:

# Absolute value split
test_set, valid_set, train_set = tfds.load("cifar10", 
                                           split=["test", "train[0:10000]", "train[10000:]"],
                                           as_supervised=True)

print("Train set size: ", len(train_set)) # Train set size:  40000
print("Test set size: ", len(test_set))   # Test set size:  10000
print("Valid set size: ", len(valid_set)) # Valid set size:  10000


# Mixed notation split
# Examples 5000 to 25000 (20,000 examples) are left unassigned
test_set, valid_set, train_set = tfds.load("cifar10",
                                           split=["train[:2500]",      # First 2500 are assigned to `test_set`
                                                  "train[2500:5000]",  # 2500-5000 are assigned to `valid_set`
                                                  "train[50%:]"],      # 50%-100% (25000) assigned to `train_set`
                                           as_supervised=True)
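To double-check which examples the mixed-notation call above leaves unused, here's a small sketch that models each slice as a Python `range` over the 50,000-example 'train' split. It only reproduces the index arithmetic and doesn't load any data:

```python
TRAIN_SIZE = 50_000  # examples in CIFAR-10's 'train' split

# Index ranges [start, stop) selected by each split string in the call above
assigned = {
    "test_set": range(0, 2_500),                      # "train[:2500]"
    "valid_set": range(2_500, 5_000),                 # "train[2500:5000]"
    "train_set": range(TRAIN_SIZE // 2, TRAIN_SIZE),  # "train[50%:]"
}

used = set()
for name, indices in assigned.items():
    if not used.isdisjoint(indices):
        raise ValueError(f"{name} overlaps another split")
    used.update(indices)
    print(f"{name}: {len(indices)} examples")

unassigned = TRAIN_SIZE - len(used)
print("Unassigned:", unassigned)
# test_set: 2500 examples
# valid_set: 2500 examples
# train_set: 25000 examples
# Unassigned: 20000
```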

You can also take the union of splits with +, though this is less commonly used, since the examples from the combined splits end up interleaved:

train_and_test, half_of_train_and_test = tfds.load("cifar10",
                                                   split=['train+test', 'train[:50%]+test'],
                                                   as_supervised=True)
                                
print("Train+test: ", len(train_and_test))               # Train+test:  60000
print("Train[:50%]+test: ", len(half_of_train_and_test)) # Train[:50%]+test:  35000

These two sets are now heavily interleaved.
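The size of a union is simply the sum of the sizes of its parts. The hypothetical `union_size` helper below reproduces the two lengths printed above from CIFAR-10's split sizes; it handles only plain split names and percentage slices, and it says nothing about the order in which TFDS yields the examples:

```python
import re

# Sizes of CIFAR-10's official splits
SPLIT_SIZES = {"train": 50_000, "test": 10_000}

# Matches e.g. "train", "test" or "train[:50%]" (percentage bounds only)
_PART = re.compile(r"(\w+)(?:\[(?:(\d+)%)?:(?:(\d+)%)?\])?")

def union_size(spec):
    """Hypothetical helper: count the examples selected by a '+'-joined split spec."""
    total = 0
    for part in spec.split("+"):
        match = _PART.fullmatch(part.strip())
        if match is None:
            raise ValueError(f"unsupported split spec: {part!r}")
        name, start_pct, stop_pct = match.groups()
        n = SPLIT_SIZES[name]
        start = round(n * int(start_pct or 0) / 100)
        stop = round(n * int(stop_pct or 100) / 100)
        total += stop - start
    return total

print(union_size("train+test"))        # 60000
print(union_size("train[:50%]+test"))  # 35000
```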

Even Splits for N Sets

Again, you can create any number of splits simply by adding more entries to the split list:

split=["train[:10%]", "train[10%:20%]", "train[20%:30%]", "train[30%:40%]", ...]

However, if you're creating many splits, especially evenly sized ones, the strings you'll be passing in are very predictable. Instead of writing them by hand, you can generate the list of strings at a fixed interval (such as 10%). For exactly this purpose, the tfds.even_splits() function generates a list of split strings, given the name of the split to divide and the desired number of splits:

import tensorflow_datasets as tfds

s1, s2, s3, s4, s5 = tfds.even_splits('train', n=5)
# Each of these elements is just a string
split_list = [s1, s2, s3, s4, s5]
print(f"Type: {type(s1)}, contents: '{s1}'")
# Type: <class 'str'>, contents: 'train[0%:20%]'

for split in split_list:
    test_set = tfds.load("cifar10", split=split, as_supervised=True)
    print(f"Test set length for Split {split}: ", len(test_set))

This results in:

Test set length for Split train[0%:20%]:  10000
Test set length for Split train[20%:40%]:  10000
Test set length for Split train[40%:60%]:  10000
Test set length for Split train[60%:80%]:  10000
Test set length for Split train[80%:100%]:  10000

Alternatively, you can pass in the entire split_list as the split argument itself, to construct several split datasets outside of a loop:

ts1, ts2, ts3, ts4, ts5 = tfds.load("cifar10",
                                    split=split_list,
                                    as_supervised=True)
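Since tfds.even_splits() returns plain strings, the arithmetic behind it is easy to reproduce. The sketch below is a hypothetical stand-in, `make_even_splits`, that rounds boundaries to whole percents; the real function may format or round its boundaries differently (for instance, when `n` doesn't divide 100 evenly), so treat it as a way to understand the output, not as a replacement:

```python
def make_even_splits(prefix, n):
    """Hypothetical stand-in for tfds.even_splits: n contiguous percentage slices of `prefix`."""
    if not 1 <= n <= 100:
        raise ValueError("n must be between 1 and 100")
    # n + 1 boundaries from 0% to 100%, rounded to whole percents
    bounds = [round(100 * i / n) for i in range(n + 1)]
    return [f"{prefix}[{lo}%:{hi}%]" for lo, hi in zip(bounds, bounds[1:])]

print(make_even_splits("train", 5))
# ['train[0%:20%]', 'train[20%:40%]', 'train[40%:60%]', 'train[60%:80%]', 'train[80%:100%]']
print(make_even_splits("train", 3))
# ['train[0%:33%]', 'train[33%:67%]', 'train[67%:100%]']
```

Because each slice starts where the previous one ends, the generated splits cover the whole 'train' split without gaps or overlaps.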

Conclusion

In this guide, we've taken a look at what the training and testing sets are, as well as the importance of validation sets. Finally, we've explored the Split API of the TensorFlow Datasets library and performed a train/test/validation split.


Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2025 Stack Abuse. All rights reserved.
