How to Save and Load XGBoost Models

Models are usually trained so that they can be deployed to production and give meaningful predictions for new input. To move a model out of your training environment, you'll want to save the trained model and load it in a different one.

XGBoost is a flexible and fast gradient boosting library, and its flagship XGBRegressor and XGBClassifier estimators follow the familiar Scikit-Learn API.

Let's train a simple regressor on a toy dataset:

import xgboost as xgb

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)


xbg_reg = xgb.XGBRegressor().fit(X_train_scaled, y_train)

Now, let's take a look at how we can save and load the models:

model.save_model() and model.load_model()

It's officially recommended to use the save_model() and load_model() functions to save and load models.

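Note: dump_model() writes a human-readable description of the trees for interpretability and visualization. The dump can't be loaded back, so it isn't a way to save a trained model.

Both methods are available on a Booster instance, as well as on the Scikit-Learn wrappers such as XGBRegressor. Here, we save from the fitted XGBRegressor and load into a blank Booster: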

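# Save as a JSON file
xbg_reg.save_model("model.json")
# The format is chosen from the file extension: a ".txt" file is not
# plain text, it holds one of XGBoost's binary formats (which one
# depends on the XGBoost version)
xbg_reg.save_model("model.txt")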

# Blank new instance to be loaded into
xbg_reg = xgb.Booster()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Alternatively, you can load the model into a new XGBRegressor instance, in which case you can pass NumPy arrays to predict() directly, rather than wrapping them in a DMatrix():

xbg_reg.save_model("model.json")
xbg_reg.save_model("model.txt")

xbg_reg = xgb.XGBRegressor()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Note: This approach is designed to keep models loadable across XGBoost versions. Using general-purpose serialization libraries such as joblib and pickle might lead to compatibility issues when a newer version of XGBoost tries to load a model that was serialized with an older one. You can work around this by pinning versions, or by rolling back and serializing again, but you can avoid the issue altogether by using the official API.

Joblib

Joblib is a serialization library with a beautifully simple API. It pickles arbitrary Python objects, handles large NumPy arrays efficiently, and can optionally compress the output file:

import joblib

joblib.dump(xbg_reg, "xgb_reg.sav")
xgb_reg = joblib.load("xgb_reg.sav")

preds = xgb_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Pickle

Pickle is Python's built-in serialization module. It serializes models just as easily, but you open and close the files yourself:

import pickle

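# Context managers make sure the files are closed after writing and reading
with open("xgb_reg.sav", "wb") as f:
    pickle.dump(xbg_reg, f)
with open("xgb_reg.sav", "rb") as f:
    xgb_reg = pickle.load(f)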

preds = xgb_reg.predict(X_test_scaled)
print(preds[:10]) # [138.68327  73.31986 212.1568  138.29817 277.7322   74.42521 152.05441 83.62783 131.63977 145.00095]

Author: David Landup

© 2013-2022 Stack Abuse. All rights reserved.
