scikit-learn: Save and Restore Models

On many occasions, while working with the scikit-learn library, you'll need to save your prediction models to a file and then restore them, so you can reuse your previous work: to test a model on new data, to compare multiple models, and so on. Saving is also known as object serialization: representing an object as a stream of bytes so that it can be stored on disk, sent over a network, or saved to a database. Restoring is known as deserialization. In this article, we look at three ways to do this in Python and scikit-learn, and discuss the pros and cons of each.

Tools to Save and Restore Models

The first tool we describe is Pickle, the standard Python tool for object (de)serialization. Next, we look at the Joblib library, which makes it easy to (de)serialize objects containing large data arrays. Finally, we present a manual approach for saving and restoring objects to and from JSON (JavaScript Object Notation). None of these approaches is the best choice in every situation, so pick the one that fits the needs of your project.

Model Initialization

First, let's create a scikit-learn model. In our example we'll use a Logistic Regression model and the Iris dataset. Let's import the needed libraries, load the data, and split it into training and test sets.

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load and split data
data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)

Now let's create the model with some non-default parameters and fit it to the training data. We assume that you have previously found the optimal parameters of the model, i.e. the ones that produce the highest estimated accuracy.

# Create a model
model = LogisticRegression(C=0.1, 
                           max_iter=20, 
                           fit_intercept=True, 
                           n_jobs=3, 
                           solver='liblinear')
model.fit(Xtrain, Ytrain)

And our resulting model:

LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True,
    intercept_scaling=1, max_iter=20, multi_class='ovr', n_jobs=3,
    penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
    verbose=0, warm_start=False)

Using the fit method, the model has learned its coefficients, which are stored in model.coef_. The goal is to save the model's parameters and coefficients to a file, so you don't need to repeat the training and parameter optimization steps when working with new data.
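To see exactly what a fitted model carries, and therefore what any save/restore method has to preserve, the sketch below retrains a comparable model and prints both its hyperparameters and its learned attributes. It is a minimal illustration, not the article's exact model: it assumes a recent scikit-learn and uses the default lbfgs solver with a larger max_iter so it runs cleanly across versions.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Same split as above; default lbfgs solver and max_iter=1000 are assumptions
data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, max_iter=1000).fit(Xtrain, Ytrain)

# Hyperparameters: set by you in the constructor, available even before fit()
params = model.get_params()
print("C =", params["C"], "| max_iter =", params["max_iter"])

# Learned state: attributes ending in "_" exist only after fit()
print("classes_:      ", model.classes_)
print("coef_ shape:   ", model.coef_.shape)   # (n_classes, n_features)
print("intercept_:    ", model.intercept_)
print("n_features_in_:", model.n_features_in_)
```

Saving only the hyperparameters lets you retrain later; saving the learned attributes as well is what lets a restored model predict immediately.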

The pickle Module

In the following few lines of code, the model we created in the previous step is saved to a file and then loaded as a new object called pickle_model. The loaded model is then used to calculate the accuracy score and predict outcomes on new, unseen (test) data.

import pickle

#
# Create your model here (same as above)
#

# Save to file in the current working directory
pkl_filename = "pickle_model.pkl"
with open(pkl_filename, 'wb') as file:
    pickle.dump(model, file)

# Load from file
with open(pkl_filename, 'rb') as file:
    pickle_model = pickle.load(file)
    
# Calculate the accuracy score and predict target values
score = pickle_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)

Running this code should yield your score and save the model via pickle:

$ python save_model_pickle.py
Test score: 91.11 %
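Because serialization produces a stream of bytes, pickle can also work entirely in memory through pickle.dumps and pickle.loads, which is useful when the model is going into a database column or over a network rather than into a file. The sketch below assumes the same data split but the default solver, and checks that the restored copy predicts exactly like the original.

```python
import pickle

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, max_iter=1000).fit(Xtrain, Ytrain)

# Serialize to a bytes object instead of a file
model_bytes = pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL)
print(type(model_bytes), len(model_bytes), "bytes")

# Deserialize into a new, independent object
restored = pickle.loads(model_bytes)

# Same learned coefficients and same predictions, but a separate object
print(restored is model)  # False
print(np.array_equal(restored.coef_, model.coef_))
print(np.array_equal(restored.predict(Xtest), model.predict(Xtest)))
```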

The great thing about using pickle to save and restore our learning models is that it's quick: you can do it in two lines of code. It is useful if you have optimized the model's parameters on the training data, so you don't need to repeat this step. However, it doesn't save the test results or any data. You can still do this by saving a tuple, or a list, of multiple objects (and remembering which object goes where), as follows:

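tuple_objects = (model, Xtrain, Ytrain, score)

# Save tuple (the with block closes the file when done)
with open("tuple_model.pkl", 'wb') as file:
    pickle.dump(tuple_objects, file)

# Restore tuple, unpacking the objects in the same order they were saved
with open("tuple_model.pkl", 'rb') as file:
    pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(file)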
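Keeping track of tuple positions is easy to get wrong. One alternative, sketched below, is to pickle a dictionary so each object is retrieved by name instead of position; the file name model_bundle.pkl and the key names are hypothetical choices for illustration.

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, max_iter=1000).fit(Xtrain, Ytrain)
score = model.score(Xtest, Ytest)

# Bundle related objects under descriptive keys (key names are illustrative)
bundle = {"model": model, "Xtrain": Xtrain, "Ytrain": Ytrain, "test_score": score}

with open("model_bundle.pkl", "wb") as file:
    pickle.dump(bundle, file)

with open("model_bundle.pkl", "rb") as file:
    restored = pickle.load(file)

# Access items by name rather than by position
print(sorted(restored))
print("Stored test score: {0:.2f} %".format(100 * restored["test_score"]))
```

Adding another item later only means adding a key; code that reads the existing keys keeps working.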

The joblib Module

The joblib library is intended to be a replacement for pickle for objects containing large data, in particular large NumPy arrays. We'll repeat the same save and restore procedure as with pickle.

# joblib is installed as a scikit-learn dependency;
# sklearn.externals.joblib was removed in scikit-learn 0.23
import joblib

# Save to file in the current working directory
joblib_file = "joblib_model.pkl"
joblib.dump(model, joblib_file)

# Load from file
joblib_model = joblib.load(joblib_file)

# Calculate the accuracy and predictions with the restored model
score = joblib_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = joblib_model.predict(Xtest)

$ python save_model_joblib.py
Test score: 91.11 %

As seen from the example, the joblib library offers a slightly simpler workflow than pickle. While pickle requires a file object to be passed as an argument, joblib works with both file objects and string filenames. Older versions of joblib stored each large array of a model in a separate file, while current versions write everything to a single file; in both cases the save and restore procedure stays the same. joblib also supports several compression methods, such as zlib, gzip, bz2, and lzma, with different levels of compression.
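The sketch below illustrates those options on a small model: an explicit (method, level) compression tuple, compression inferred from a ".gz" file extension, and an open file object in place of a filename. The file names are hypothetical, and for a model this small the size savings from compression will be modest.

```python
import os

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, max_iter=1000).fit(Xtrain, Ytrain)

# 1) Explicit compression given as a (method, level) tuple
joblib.dump(model, "model_zlib.joblib", compress=("zlib", 3))

# 2) Compression inferred from the file extension (".gz" -> gzip)
joblib.dump(model, "model.joblib.gz")

# 3) An open file object instead of a filename, as with pickle
with open("model_plain.joblib", "wb") as file:
    joblib.dump(model, file)

paths = ["model_zlib.joblib", "model.joblib.gz", "model_plain.joblib"]
for path in paths:
    # joblib.load detects the compression format by itself
    restored = joblib.load(path)
    same = np.array_equal(restored.predict(Xtest), model.predict(Xtest))
    print(path, os.path.getsize(path), "bytes, same predictions:", same)
```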

Manual Save and Restore to JSON

Depending on your project, you may find pickle and joblib unsuitable. Some of the reasons are discussed later in the Compatibility Issues section. Whenever you want full control over the save and restore process, the best option is to write your own save and restore functions.

The following shows an example of manually saving and restoring objects using JSON. This approach allows us to select the data which needs to be saved, such as the model parameters, coefficients, training data, and anything else we need.

Since we want to save all of this data in a single object, one possible way to do it is to create a new class which inherits from the model class, which in our example is LogisticRegression. The new class, called MyLogReg, then implements the methods save_json and load_json for saving and restoring to/from a JSON file, respectively.

For simplicity, we'll save only three model parameters and the training data. Some additional data we could store with this approach is, for example, a cross-validation score on the training set, test data, accuracy score on the test data, etc.

import json
import numpy as np

class MyLogReg(LogisticRegression):
    
    # Override the class constructor
    def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train
        
    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
        
        # Create json and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)
    
    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)
            
        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None
        

Now let's try the MyLogReg class. First we create an object mylogreg, pass the training data to it, and save it to a file. Then we create a new object json_mylogreg and call the load_json method to load the data from the file.

filepath = "mylogreg.json"

# Create a model with the training data attached (the model is not fitted here)
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)

# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
print(json_mylogreg)

Printing out the new object, we can see our parameters and training data as needed.


MyLogReg(C=1.0,
     X_train=array([[ 4.3,  3. ,  1.1,  0.1],
       [ 5.7,  4.4,  1.5,  0.4],
       ...,
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.7,  2.8,  6.7,  2. ]]),
     Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False,
     fit_intercept=True, intercept_scaling=1, max_iter=100,
     multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
     solver='liblinear', tol=0.0001, verbose=0, warm_start=False)

Since serialization with JSON saves the object as human-readable text rather than a byte stream, the mylogreg.json file can be opened and modified with a text editor. Although this is convenient for the developer, it also means that anyone with access to the file can easily read and change its contents. Moreover, this approach is best suited to objects with a small number of instance variables, such as scikit-learn models, because adding a new variable requires changes to both the save and restore methods.
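The MyLogReg example stores hyperparameters and training data, so a restored object still has to be refitted before it can predict. The sketch below shows what extending a JSON approach to the learned state might look like: it also saves coef_, intercept_, classes_, and n_features_in_, so the restored model can predict without retraining. It uses plain functions with hypothetical names (save_fitted_json, load_fitted_json) instead of a subclass, and it relies on an assumption rather than a documented API: that setting these attributes by hand is enough for LogisticRegression.predict in current scikit-learn releases. Only predict is checked here.

```python
import json

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def save_fitted_json(model, filepath):
    """Hypothetical helper: store a few hyperparameters plus the learned state."""
    dict_ = {
        "params": {"C": model.C, "max_iter": model.max_iter,
                   "solver": model.solver, "fit_intercept": model.fit_intercept},
        "coef_": model.coef_.tolist(),
        "intercept_": model.intercept_.tolist(),
        "classes_": model.classes_.tolist(),
        "n_features_in_": int(model.n_features_in_),
    }
    with open(filepath, "w") as file:
        json.dump(dict_, file, indent=4)


def load_fitted_json(filepath):
    """Hypothetical helper: rebuild a LogisticRegression without calling fit()."""
    with open(filepath, "r") as file:
        dict_ = json.load(file)
    model = LogisticRegression(**dict_["params"])
    # Restore the attributes that fit() would normally set (assumption: enough for predict)
    model.coef_ = np.asarray(dict_["coef_"])
    model.intercept_ = np.asarray(dict_["intercept_"])
    model.classes_ = np.asarray(dict_["classes_"])
    model.n_features_in_ = dict_["n_features_in_"]
    return model


data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(
    data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, max_iter=1000).fit(Xtrain, Ytrain)

save_fitted_json(model, "fitted_logreg.json")
json_model = load_fitted_json("fitted_logreg.json")
print(np.array_equal(json_model.predict(Xtest), model.predict(Xtest)))
```

Python's json module writes floats with enough digits to round-trip them exactly, so the restored coefficients match the originals bit for bit.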

Compatibility Issues

While some of the pros and cons of each tool have been covered so far, probably the biggest drawback of pickle and joblib is their compatibility across different models and Python versions.

Python version compatibility - The documentation of both tools states that it is not recommended to (de)serialize objects across different Python versions, although it might work across minor version changes.

Model compatibility - One of the most frequent mistakes is saving your model with pickle or joblib and then changing the model's class definition (or the library that provides it) before restoring it from the file. The internal structure of the model needs to stay unchanged between saving and reloading.
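One way to catch such mismatches early is to store the relevant versions next to the model and compare them on load. The sketch below does this with joblib; save_with_versions and load_with_versions are hypothetical helper names, and warning (rather than refusing to load) is an assumed policy you may want to tighten. Recent scikit-learn releases also emit their own InconsistentVersionWarning when an estimator saved with a different scikit-learn version is unpickled.

```python
import platform
import warnings

import joblib
import sklearn
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression


def current_versions():
    # Versions that matter when unpickling a scikit-learn model
    return {"python": platform.python_version(), "sklearn": sklearn.__version__}


def save_with_versions(model, filepath):
    """Hypothetical helper: persist the model together with version metadata."""
    joblib.dump({"model": model, "versions": current_versions()}, filepath)


def load_with_versions(filepath):
    """Hypothetical helper: load the model and warn about version mismatches."""
    bundle = joblib.load(filepath)
    saved, now = bundle["versions"], current_versions()
    for name in saved:
        if saved[name] != now[name]:
            warnings.warn(f"{name} version mismatch: saved with {saved[name]}, "
                          f"loading with {now[name]}")
    return bundle["model"], saved


X, y = load_iris(return_X_y=True)
model = LogisticRegression(C=0.1, max_iter=1000).fit(X, y)

save_with_versions(model, "model_with_versions.joblib")
restored, saved_versions = load_with_versions("model_with_versions.joblib")
print(saved_versions)
```

Because the metadata lives in the same file as the model, it can only be checked after unpickling succeeds; keeping it in a separate small JSON file would let you check versions before touching the pickle at all.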

One last issue with both pickle and joblib is security. Loading a pickle or joblib file can execute arbitrary code embedded in it, so you should never restore data from untrusted or unauthenticated sources.
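If you do need to exchange pickle files between your own systems, one common safeguard is to sign each file with a secret key and verify the signature before unpickling, using Python's standard hmac module. The sketch below assumes a shared secret (SECRET_KEY is a placeholder) and uses hypothetical helper names. It only shows that the file was produced by someone holding the key; it does not make unpickling a file from an untrusted source safe.

```python
import hashlib
import hmac
import pickle

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Placeholder: in practice, load the key from a secrets store, never hard-code it
SECRET_KEY = b"replace-with-a-real-secret-key"


def save_signed(obj, filepath):
    """Hypothetical helper: write an HMAC-SHA256 signature followed by the pickle."""
    payload = pickle.dumps(obj)
    signature = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    with open(filepath, "wb") as file:
        file.write(signature + payload)


def load_signed(filepath):
    """Hypothetical helper: verify the signature before calling pickle.loads."""
    with open(filepath, "rb") as file:
        data = file.read()
    signature, payload = data[:32], data[32:]  # SHA-256 digests are 32 bytes
    expected = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    if not hmac.compare_digest(signature, expected):
        raise ValueError("Signature mismatch: refusing to unpickle " + filepath)
    return pickle.loads(payload)


X, y = load_iris(return_X_y=True)
model = LogisticRegression(max_iter=1000).fit(X, y)
save_signed(model, "signed_model.pkl")
restored = load_signed("signed_model.pkl")

# Flip the last byte of the file to simulate tampering
with open("signed_model.pkl", "rb") as file:
    tampered = bytearray(file.read())
tampered[-1] ^= 0xFF
with open("signed_model.pkl", "wb") as file:
    file.write(bytes(tampered))

try:
    load_signed("signed_model.pkl")
except ValueError as err:
    print(err)
```

hmac.compare_digest is used instead of == so that the comparison time does not leak how many leading bytes of the signature matched.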

Conclusions

In this post we described three tools for saving and restoring scikit-learn models. The pickle and joblib libraries are quick and easy to use, but have compatibility issues across Python versions and changes in the learning model. The manual approach, on the other hand, takes more effort to implement and must be updated with any change in the model structure, but it can easily be adapted to various needs and avoids the Python-version compatibility issues of pickle and joblib.

Last Updated: July 13th, 2023


Author: Mihajlo Pavloski

A research PhD student and coding enthusiast working in Data Science and Machine Learning.

© 2013-2024 Stack Abuse. All rights reserved.
