scikit-learn: Save and Restore Models - Stack Abuse

On many occasions, while working with the scikit-learn library, you'll need to save your prediction models to a file and then restore them in order to reuse your previous work: to test your model on new data, compare multiple models, or anything else. This saving procedure is also known as object serialization, which means representing an object as a stream of bytes so that it can be stored on disk, sent over a network, or saved to a database. The restoring procedure is known as deserialization. In this article, we look at three possible ways to do this in Python and scikit-learn, each presented with its pros and cons.
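To make the idea of a "stream of bytes" concrete, here is a minimal sketch using Pickle's in-memory functions `pickle.dumps` and `pickle.loads` on an ordinary dictionary. The dictionary and its contents are made up purely for illustration; any picklable Python object behaves the same way.

```python
import pickle

# Hypothetical object to serialize; any picklable Python object works
settings = {"C": 0.1, "solver": "liblinear", "classes": [0, 1, 2]}

# Serialization: object -> bytes (these bytes could go to a file, socket or database)
payload = pickle.dumps(settings)
print(type(payload))  # <class 'bytes'>

# Deserialization: bytes -> a new object equal to, but distinct from, the original
restored = pickle.loads(payload)
print(restored == settings, restored is settings)  # True False
```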

Tools to Save and Restore Models

The first tool we describe is Pickle, the standard Python tool for object (de)serialization. Afterwards, we look at the Joblib library, which offers easy (de)serialization of objects containing large data arrays, and finally we present a manual approach for saving and restoring objects to/from JSON (JavaScript Object Notation). None of these approaches is optimal in every case, so choose the one that best fits the needs of your project.

Model Initialization

First, let's create a scikit-learn model. In our example we'll use a Logistic Regression model and the Iris dataset. Let's import the needed libraries, load the data, and split it into training and test sets.

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load and split data
data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)

Now let's create the model with some non-default parameters and fit it to the training data. We assume that you have previously found the optimal parameters of the model, i.e. the ones which produce the highest estimated accuracy.

# Create a model and fit it to the training data
model = LogisticRegression(C=0.1, max_iter=20, fit_intercept=True,
                           n_jobs=3, solver='liblinear')
model.fit(Xtrain, Ytrain)

And our resulting model:

LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True,
    intercept_scaling=1, max_iter=20, multi_class='ovr', n_jobs=3,
    penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
    verbose=0, warm_start=False)

Using the fit method, the model has learned its coefficients, which are stored in model.coef_. The goal is to save the model's parameters and coefficients to a file, so you don't need to repeat the model training and parameter optimization steps on new data.
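To see what fit actually stores, the following sketch repeats the Iris setup (with default max_iter and n_jobs, an assumption made only to keep it short) and prints the shapes of the learned attributes coef_ and intercept_, plus the class labels in classes_. These fitted arrays are what any save/restore method has to preserve.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)

model = LogisticRegression(C=0.1, solver='liblinear')
model.fit(Xtrain, Ytrain)

# One row of weights per class, one column per input feature
print(model.coef_.shape)       # (3, 4)
# One intercept per class
print(model.intercept_.shape)  # (3,)
# Labels that predict() maps the winning row back to
print(model.classes_)          # [0 1 2]
```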

Pickle Module

In the following few lines of code, the model we created in the previous step is saved to a file and then loaded as a new object called pickle_model. The loaded model is then used to calculate the accuracy score and predict outcomes on new, unseen (test) data.

import pickle

# Create your model here (same as above)

# Save to file in the current working directory
pkl_filename = "pickle_model.pkl"
with open(pkl_filename, 'wb') as file:
    pickle.dump(model, file)

# Load from file
with open(pkl_filename, 'rb') as file:
    pickle_model = pickle.load(file)
# Calculate the accuracy score and predict target values
score = pickle_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)

Running this code should yield your score and save the model via Pickle:

$ python
Test score: 91.11 %
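Before relying on a saved file, it can be worth confirming that the restored model behaves exactly like the original. A minimal check, assuming the same Iris setup as above and using an in-memory round trip with pickle.dumps and pickle.loads instead of a file, compares the learned coefficients and the predictions:

```python
import pickle

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, solver='liblinear').fit(Xtrain, Ytrain)

# Round trip through bytes in memory rather than through a file on disk
restored = pickle.loads(pickle.dumps(model))

print(restored is model)                                              # False
print(np.array_equal(model.coef_, restored.coef_))                    # True
print(np.array_equal(model.predict(Xtest), restored.predict(Xtest)))  # True
```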

The great thing about using Pickle to save and restore our models is that it's quick: you can do it in two lines of code. It is useful if you have optimized the model's parameters on the training data, so you don't need to repeat this step. However, it doesn't save the test results or any data. You can still do this by saving a tuple, or a list, of multiple objects (and remembering which object goes where), as follows:

tuple_objects = (model, Xtrain, Ytrain, score)

# Save tuple
with open("tuple_model.pkl", 'wb') as file:
    pickle.dump(tuple_objects, file)

# Restore tuple
with open("tuple_model.pkl", 'rb') as file:
    pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(file)
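If keeping track of the tuple's positional order feels fragile, a small variation is to pickle a dictionary with named keys instead. The key names and the file name bundle_model.pkl below are hypothetical choices for this sketch, and the file is written to a temporary directory:

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, solver='liblinear').fit(Xtrain, Ytrain)
score = model.score(Xtest, Ytest)

# Named keys replace "remember which object goes where"
bundle = {"model": model, "Xtrain": Xtrain, "Ytrain": Ytrain, "score": score}

bundle_path = os.path.join(tempfile.mkdtemp(), "bundle_model.pkl")
with open(bundle_path, "wb") as file:
    pickle.dump(bundle, file)

with open(bundle_path, "rb") as file:
    restored = pickle.load(file)

print(sorted(restored))            # ['Xtrain', 'Ytrain', 'model', 'score']
print(restored["score"] == score)  # True
```

Objects are then looked up by name (restored["model"]), so adding a new entry later does not shift the others.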

Joblib Module

The Joblib library is intended to be a replacement for Pickle for objects containing large data, such as NumPy arrays. We'll repeat the same save and restore procedure as with Pickle.

# joblib is installed as a scikit-learn dependency; sklearn.externals.joblib was removed in scikit-learn 0.23
import joblib

# Save to file in the current working directory
joblib_file = "joblib_model.pkl"
joblib.dump(model, joblib_file)

# Load from file
joblib_model = joblib.load(joblib_file)

# Calculate the accuracy and predictions
score = joblib_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = joblib_model.predict(Xtest)

$ python
Test score: 91.11 %

As seen from the example, the Joblib library offers a slightly simpler workflow than Pickle. While Pickle requires a file object to be passed as an argument, Joblib works with both file objects and string filenames. Joblib is also efficient with models that contain large NumPy arrays: older versions stored each such array in a separate file, while recent versions write everything to a single file, and in either case the save and restore procedure stays the same. Joblib also supports different compression methods, such as 'zlib', 'gzip', and 'bz2', and different levels of compression.
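Both of these points can be sketched in a few lines. This assumes the same Iris model as above; the file names are hypothetical and live in a temporary directory. The compress argument of joblib.dump accepts a (method, level) tuple, and joblib.load detects the compression from the file itself:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, solver="liblinear").fit(Xtrain, Ytrain)

out_dir = tempfile.mkdtemp()
gz_path = os.path.join(out_dir, "joblib_model.pkl.gz")
fobj_path = os.path.join(out_dir, "joblib_model_fileobj.pkl")

# Explicit (method, level) tuple: gzip at compression level 3
joblib.dump(model, gz_path, compress=("gzip", 3))

# Joblib also accepts an already opened file object instead of a filename
with open(fobj_path, "wb") as file:
    joblib.dump(model, file)

# No need to tell joblib.load which compression was used
gz_model = joblib.load(gz_path)
with open(fobj_path, "rb") as file:
    fobj_model = joblib.load(file)

print(np.array_equal(model.predict(Xtest), gz_model.predict(Xtest)))    # True
print(np.array_equal(model.predict(Xtest), fobj_model.predict(Xtest)))  # True
```

Higher compression levels trade save/load speed for smaller files; for a model as small as this one the difference is negligible.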

Manual Save and Restore to JSON

Depending on your project, you may often find Pickle and Joblib unsuitable. Some of the reasons are discussed later in the Compatibility Issues section. In any case, whenever you want full control over the save and restore process, the best way is to build your own functions manually.

The following shows an example of manually saving and restoring objects using JSON. This approach allows us to select the data which needs to be saved, such as the model parameters, coefficients, training data, and anything else we need.

Since we want to save all of this data in a single object, one possible way to do it is to create a new class which inherits from the model class, which in our example is LogisticRegression. The new class, called MyLogReg, then implements the methods save_json and load_json for saving and restoring to/from a JSON file, respectively.

For simplicity, we'll save only three model parameters and the training data. Additional data we could store with this approach includes, for example, a cross-validation score on the training set, the test data, the accuracy score on the test data, etc.

import json
import numpy as np

class MyLogReg(LogisticRegression):
    # Override the class constructor
    def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train
    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
        # Create JSON and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)

    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)
        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None

Now let's try the MyLogReg class. First we create an object mylogreg, pass the training data to it, and save it to a file. Then we create a new object json_mylogreg and call the load_json method to load the data from the file.

filepath = "mylogreg.json"

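# Create a model, train it, and save it to a JSON file
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.fit(Xtrain, Ytrain)
mylogreg.save_json(filepath)

# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
print(json_mylogreg)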

Printing out the new object, we can see our parameters and training data as needed.

MyLogReg(C=1.0,
     X_train=array([[ 4.3,  3. ,  1.1,  0.1],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.7,  2.8,  6.7,  2. ]]),
     Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False,
     fit_intercept=True, intercept_scaling=1, max_iter=100,
     multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
     solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
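Note that MyLogReg stores hyperparameters and training data but not the learned coefficients, so the restored object must be fit again before it can predict. If you want to skip retraining, one option is to also store the fitted attributes coef_, intercept_ and classes_. The sketch below uses a plain LogisticRegression; the helpers save_fitted_json and load_fitted_json are hypothetical, and the approach relies on scikit-learn treating an estimator as fitted once such attributes are set, which is internal behavior that could change between versions.

```python
import json
import os
import tempfile

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def save_fitted_json(estimator, filepath):
    # Hypothetical helper: a few hyperparameters plus the learned attributes
    state = {
        "params": {"C": estimator.C, "solver": estimator.solver, "max_iter": estimator.max_iter},
        "coef_": estimator.coef_.tolist(),
        "intercept_": estimator.intercept_.tolist(),
        "classes_": estimator.classes_.tolist(),
    }
    with open(filepath, "w") as file:
        json.dump(state, file, indent=4)


def load_fitted_json(filepath):
    # Rebuild an unfitted estimator, then attach the stored learned attributes
    with open(filepath) as file:
        state = json.load(file)
    estimator = LogisticRegression(**state["params"])
    estimator.coef_ = np.asarray(state["coef_"])
    estimator.intercept_ = np.asarray(state["intercept_"])
    estimator.classes_ = np.asarray(state["classes_"])
    estimator.n_features_in_ = estimator.coef_.shape[1]
    return estimator


data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, solver="liblinear").fit(Xtrain, Ytrain)

path = os.path.join(tempfile.mkdtemp(), "logreg_fitted.json")
save_fitted_json(model, path)
restored = load_fitted_json(path)

# No call to fit(): predictions come straight from the stored coefficients
print(np.array_equal(model.predict(Xtest), restored.predict(Xtest)))  # True
```

Because Python's json module writes floats with enough digits to round-trip exactly, the restored coefficients match the originals bit for bit.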

Since JSON serialization saves the object as text rather than as a byte stream, the 'mylogreg.json' file can be opened and modified with a text editor. Although this is convenient for the developer, it is less secure, since an intruder can view and amend the contents of the JSON file. Moreover, this approach is best suited to objects with a small number of instance variables, such as scikit-learn models, because adding any new variable requires changes to the save and restore methods.

Compatibility Issues

While some of the pros and cons of each tool have been covered so far, probably the biggest drawback of Pickle and Joblib is their compatibility across different models and Python versions.

Python version compatibility - The documentation of both tools states that it is not recommended to (de)serialize objects across different Python versions, although it might work across minor version changes.

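Model compatibility - One of the most frequent mistakes is saving your model with Pickle or Joblib, then changing the model's code before trying to restore it from the file. The internal structure of the model needs to stay unchanged between saving and reloading. The same applies to the library itself: scikit-learn does not guarantee that a model saved with one scikit-learn version can be loaded with another.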
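One way to catch these problems early is to record the Python and scikit-learn versions next to the saved model and compare them before loading. The sketch below writes them to a JSON "sidecar" file so the check can run before unpickling, which may itself fail on a large version jump. The helpers save_with_versions and load_with_versions and the ".json" suffix are hypothetical conventions, not part of Pickle or scikit-learn.

```python
import json
import os
import pickle
import platform
import tempfile
import warnings

import numpy as np
import sklearn
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def current_versions():
    # The versions that most often break binary model files when they change
    return {"python": platform.python_version(), "sklearn": sklearn.__version__}


def save_with_versions(obj, filepath):
    # Hypothetical helper: pickle the model and write a JSON sidecar with versions
    with open(filepath, "wb") as file:
        pickle.dump(obj, file)
    with open(filepath + ".json", "w") as file:
        json.dump(current_versions(), file)


def load_with_versions(filepath):
    # Read the sidecar first, so a mismatch is reported before unpickling
    with open(filepath + ".json") as file:
        saved = json.load(file)
    for name, running in current_versions().items():
        if saved.get(name) != running:
            warnings.warn(f"{name} version mismatch: saved with {saved.get(name)}, running {running}")
    with open(filepath, "rb") as file:
        return pickle.load(file)


data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
model = LogisticRegression(C=0.1, solver="liblinear").fit(Xtrain, Ytrain)

path = os.path.join(tempfile.mkdtemp(), "model.pkl")
save_with_versions(model, path)
restored = load_with_versions(path)  # same environment, so no mismatch warning
```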

One last issue with both Pickle and Joblib is security. A pickled file can contain malicious code that runs as soon as the file is loaded, so you should not restore data from untrusted or unauthenticated sources.
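When you do need to exchange pickled models between systems you control, one common way to authenticate them is to sign the bytes with an HMAC and verify the signature before calling pickle.loads. This is a minimal sketch using only the standard library; the helper names and the key are hypothetical, and it only protects against tampering by someone who does not know the key. It does not make a pickle from an untrusted author safe to load.

```python
import hashlib
import hmac
import pickle

DIGEST_SIZE = hashlib.sha256().digest_size  # 32 bytes


def dumps_signed(obj, key: bytes) -> bytes:
    # Hypothetical helper: prefix the pickled bytes with an HMAC-SHA256 signature
    payload = pickle.dumps(obj)
    signature = hmac.new(key, payload, hashlib.sha256).digest()
    return signature + payload


def loads_signed(blob: bytes, key: bytes):
    signature, payload = blob[:DIGEST_SIZE], blob[DIGEST_SIZE:]
    expected = hmac.new(key, payload, hashlib.sha256).digest()
    # Constant-time comparison; unpickle only if the signature matches
    if not hmac.compare_digest(signature, expected):
        raise ValueError("signature mismatch: refusing to unpickle")
    return pickle.loads(payload)


key = b"replace-with-a-real-secret"  # hypothetical key; keep real keys out of source control
blob = dumps_signed({"C": 0.1, "solver": "liblinear"}, key)
print(loads_signed(blob, key))  # {'C': 0.1, 'solver': 'liblinear'}

# Flip one bit of the payload to simulate tampering
tampered = blob[:-1] + bytes([blob[-1] ^ 1])
try:
    loads_signed(tampered, key)
except ValueError as err:
    print(err)  # signature mismatch: refusing to unpickle
```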


In this post we described three tools for saving and restoring scikit-learn models. The Pickle and Joblib libraries are quick and easy to use, but have compatibility issues across different Python versions and changes in the learning model. On the other hand, the manual approach is more difficult to implement and needs to be modified with any change in the model structure, but on the plus side it can easily be adapted to various needs and avoids the version-compatibility issues of binary serialization.

Last Updated: October 10th, 2017


Mihajlo Pavloski, Author

A research PhD student and coding enthusiast working in Data Science and Machine Learning.

    © 2013-2021 Stack Abuse. All rights reserved.