How to Save and Load XGBoost Models

Models are usually trained so that they can be deployed to production and make predictions on new input. To move a trained model out of your training environment, you save it there and load it in a different one.

XGBoost is a flexible and fast gradient boosting library, and its scikit-learn-style estimators, XGBRegressor and XGBClassifier, are the usual entry points.

Let's train a simple regressor on a toy dataset:

import xgboost as xgb

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)


xbg_reg = xgb.XGBRegressor().fit(X_train_scaled, y_train)

Now, let's take a look at how we can save and load the models:

model.save_model() and model.load_model()

The XGBoost documentation recommends the save_model() and load_model() methods for saving and loading models.

Note: dump_model() writes a human-readable description of the trees for interpretability and visualization. It does not save the trained state, and its output cannot be loaded back into a model.

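save_model() and load_model() are available on the low-level Booster class as well as on the scikit-learn wrappers such as XGBRegressor. Here, the model is saved from the XGBRegressor trained above and loaded into a blank Booster: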

# Save as JSON file
xbg_reg.save_model("model.json")
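# The format is chosen from the file extension. ".txt" is not a recognized one,
# so this file is not plain text: XGBoost falls back to a version-dependent binary format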
xbg_reg.save_model("model.txt")

# Blank new instance to be loaded into
xbg_reg = xgb.Booster()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Alternatively, you can load the model into a scikit-learn wrapper such as XGBRegressor, in which case predict() accepts NumPy arrays directly rather than requiring them to be wrapped in a DMatrix:

xbg_reg.save_model("model.json")
xbg_reg.save_model("model.txt")

xbg_reg = xgb.XGBRegressor()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Note: The format written by save_model() is designed to remain loadable by later XGBoost versions. Serializing the Python object with external libraries such as joblib and pickle can instead cause compatibility issues when a newer version of XGBoost loads a model that was serialized with an older one. You can work around this by pinning versions, or by rolling back and serializing again, but using the official API avoids the problem altogether.

Joblib

Joblib is a serialization library with a simple API for persisting Python objects, including trained models, to disk, optionally compressed:

import joblib

joblib.dump(xbg_reg, "xgb_reg.sav")
xgb_reg = joblib.load("xgb_reg.sav")

preds = xgb_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Pickle

Pickle is Python's built-in serialization module. It serializes models just as easily, but you work with file objects directly:

import pickle

# Context managers make sure the files are closed after use
with open("xgb_reg.sav", "wb") as f:
    pickle.dump(xbg_reg, f)
with open("xgb_reg.sav", "rb") as f:
    xgb_reg = pickle.load(f)

preds = xgb_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]
Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2024 Stack Abuse. All rights reserved.
