The Naive Bayes Algorithm in Python with Scikit-Learn

When studying Probability & Statistics, one of the first and most important theorems students learn is Bayes' Theorem. This theorem is the foundation of Bayesian inference, which focuses on determining the probability of an event occurring based on prior knowledge of conditions that might be related to the event.

The Naive Bayes Classifier brings the power of this theorem to Machine Learning, building a very simple yet powerful classifier. In this article, we will see an overview of how this classifier works, what applications it is suited for, and how to use it in just a few lines of Python with the Scikit-Learn library.

Theory Behind Bayes' Theorem

If you studied Computer Science, Mathematics, or any other field involving statistics, it is very likely that at some point you stumbled upon the following formula:

P(H|E) = (P(E|H) * P(H)) / P(E)


  • P(H|E) is the probability of hypothesis H given the event E, a posterior probability.
  • P(E|H) is the probability of event E given that the hypothesis H is true.
  • P(H) is the probability of hypothesis H being true (regardless of any related event), or the prior probability of H.
  • P(E) is the probability of the event occurring (regardless of the hypothesis).

This is Bayes' Theorem. At first glance it might be hard to make sense of it, but it is very intuitive if we explore it through an example:

Let's say that we are interested in knowing whether an e-mail that contains the word sex (event) is spam (hypothesis). If we go back to the theorem description, this problem can be formulated as:

P(class=SPAM|contains="sex") = (P(contains="sex"|class=SPAM) * P(class=SPAM)) / P(contains="sex")

which in plain English is: the probability that an e-mail containing the word sex is spam equals the proportion of SPAM e-mails that contain the word sex, multiplied by the proportion of e-mails that are spam, divided by the proportion of e-mails that contain the word sex.

Let's dissect this piece by piece:

  • P(class=SPAM|contains="sex") is the probability of an e-mail being SPAM given that this e-mail contains the word sex. This is what we are interested in predicting.
  • P(contains="sex"|class=SPAM) is the probability of an e-mail containing the word sex given that this e-mail has been recognized as SPAM, also called the likelihood. It is estimated from our training data and captures how strongly the word sex is associated with SPAM e-mails.
  • P(class=SPAM) is the probability of an e-mail being SPAM (without any prior knowledge of the words it contains). This is simply the proportion of SPAM e-mails in our entire training set. We multiply by this value to account for how common SPAM is in the first place: if SPAM is rare, the posterior probability of SPAM is pulled down accordingly, even when the word is a strong indicator.
  • P(contains="sex") is the probability of an e-mail containing the word sex. This is simply the proportion of e-mails containing the word sex in our entire training set. We divide by this value because the rarer the word sex is overall, the more informative it is when it does appear. Thus, if this number is low (the word appears very rarely), its presence can be a strong indicator and a relevant feature to analyze.
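To make the arithmetic concrete, here is a small sketch that plugs counts into the formula. The numbers (1,000 e-mails, 200 of them spam, 60 containing the word, 50 of those spam) are hypothetical, chosen only for illustration, and do not come from any dataset.

```python
from fractions import Fraction

# Hypothetical counts, chosen only to illustrate the arithmetic
total_emails = 1000
spam_emails = 200          # count(class=SPAM)
emails_with_word = 60      # count(contains="sex")
spam_with_word = 50        # count(class=SPAM & contains="sex")

likelihood = Fraction(spam_with_word, spam_emails)     # P(contains="sex"|class=SPAM)
prior = Fraction(spam_emails, total_emails)            # P(class=SPAM)
evidence = Fraction(emails_with_word, total_emails)    # P(contains="sex")

# Bayes' Theorem: posterior = likelihood * prior / evidence
posterior = likelihood * prior / evidence              # P(class=SPAM|contains="sex")
print(posterior, float(posterior))                     # 5/6 0.8333333333333334
```

The result, 5/6, equals the direct ratio spam_with_word / emails_with_word = 50/60, as expected when all four quantities are estimated from the same counts.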

In summary, Bayes' Theorem allows us to reason about the probability of events happening in the real world based on prior knowledge and on observations that may be related to them. To apply this theorem to any problem, we need to compute the two types of probabilities that appear in the formula.

Class Probabilities

In the theorem, P(H) represents the prior probability of each hypothesis. In the Naive Bayes Classifier, these Class Probabilities are simply the number of training instances belonging to each class divided by the total number of instances. For example, in the spam detection example above, P(class=SPAM) is the number of e-mails classified as spam divided by the total number of e-mails (that is, spam + not spam):

P(class=SPAM) = count(class=SPAM) / (count(class=notSPAM) + count(class=SPAM))
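A minimal sketch of this count-based estimate, using the label counts of the SMS Spam Collection described later in this article (4,827 legitimate and 747 spam messages). The helper name class_priors is our own, not part of any library.

```python
from collections import Counter

def class_priors(labels):
    """Return P(class) for every class as count(class) / total count."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}

# Label counts of the SMS Spam Collection (4,827 ham, 747 spam)
labels = ["ham"] * 4827 + ["spam"] * 747
priors = class_priors(labels)
print(round(priors["spam"], 3))  # 0.134
print(round(priors["ham"], 3))   # 0.866
```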

Conditional Probabilities

In the theorem, P(H|E) represents the conditional probability of a hypothesis H given an event E. In the Naive Bayes Classifier, this is the posterior probability of a class given the observed evidence, and for a single feature it can be estimated directly from counts in the training data.

For the spam example, P(class=SPAM|contains="sex") can be estimated as the number of e-mails that are spam and contain the word sex, divided by the total number of e-mails that contain the word sex:

P(class=SPAM|contains="sex") = count(class=SPAM & contains="sex") / count(contains="sex")
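The estimate above covers a single word. With several words as evidence, Naive Bayes instead estimates P(contains=w|class) for each word separately and multiplies these together with the class prior, assuming the words are conditionally independent given the class; this is the "naive" assumption that gives the classifier its name. The denominator P(E) is the same for every class, so it can be dropped when we only need the most probable class. The sketch below illustrates this on a tiny, hypothetical training set. For simplicity it only considers the words present in a message and applies no smoothing, so it is not a complete implementation of any particular Naive Bayes variant; the function names fit, score and predict are our own.

```python
from collections import Counter

# Hypothetical toy training set: (set of words in the e-mail, label)
train = [
    ({"free", "win", "prize"}, "spam"),
    ({"free", "call", "now"}, "spam"),
    ({"win", "cash", "now"}, "spam"),
    ({"meeting", "call", "tomorrow"}, "ham"),
    ({"lunch", "tomorrow"}, "ham"),
    ({"call", "me", "now"}, "ham"),
]

def fit(data):
    """Count e-mails per class and, per class, the e-mails containing each word."""
    class_counts = Counter(label for _, label in data)
    word_counts = {label: Counter() for label in class_counts}
    for words, label in data:
        word_counts[label].update(words)
    return class_counts, word_counts

def score(words, label, class_counts, word_counts):
    """Unnormalized posterior: P(class) * product of P(contains=w | class)."""
    p = class_counts[label] / sum(class_counts.values())
    for w in words:
        p *= word_counts[label][w] / class_counts[label]
    return p

def predict(words, class_counts, word_counts):
    """Return the class with the highest unnormalized posterior."""
    return max(class_counts, key=lambda c: score(words, c, class_counts, word_counts))

class_counts, word_counts = fit(train)
print(predict({"free", "now"}, class_counts, word_counts))       # spam
print(predict({"call", "tomorrow"}, class_counts, word_counts))  # ham
```

Notice that the ham score for {"free", "now"} is exactly 0, because "free" never appears in a ham e-mail; this is the zero-frequency problem discussed in the limitations below.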

Applications

The Naive Bayes Classifier has proven successful in a variety of scenarios. A classical use case is document classification: determining which category or categories a given document belongs to. Nonetheless, this technique has its advantages and limitations.

Advantages

  • Naive Bayes is a simple, easy-to-implement algorithm. Because it has relatively few parameters to estimate, it can outperform more complex models when the amount of data is limited.
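  • Naive Bayes works well with both numerical and categorical data. Continuous features can be handled with Gaussian Naive Bayes, which models each feature as normally distributed within each class; like the other variants, it is a classification method, not a regression one.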
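As an illustration of the continuous case, the sketch below fits scikit-learn's GaussianNB on a small, hypothetical one-feature dataset in which the two classes cluster around different values. The output is a class label, not a number on a continuous scale.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# Hypothetical continuous feature: class 0 clusters near 1, class 1 near 5
X = np.array([[1.0], [1.2], [0.8], [5.0], [5.2], [4.8]])
y = np.array([0, 0, 0, 1, 1, 1])

model = GaussianNB().fit(X, y)

# theta_ holds the per-class mean of each feature (about 1 and 5 here)
print(model.theta_.ravel())
print(model.predict([[1.1], [4.9]]))  # [0 1]
```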

Disadvantages

  • Given the construction of the theorem, it does not work well when certain combinations of values are missing from your training data. In other words, if a class label and a certain attribute value never occur together (e.g. class="spam", contains="$$$"), the frequency-based probability estimate will be zero. Because Naive Bayes' conditional independence assumption leads it to multiply all the per-feature probabilities together, a single zero makes the whole product zero, no matter what the other features say.

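  • Naive Bayes works best when the features can reasonably be treated as independent. For instance, it works well for problems involving keywords as features (e.g. spam detection), but it tends to struggle when the relationships between words are important (e.g. sentiment analysis, where "not good" means something quite different from "good"), because the independence assumption ignores word order and interactions.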
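The zero-frequency problem is commonly handled with additive (Laplace) smoothing: a small pseudo-count alpha is added to every count, so no estimate is ever exactly zero. scikit-learn's MultinomialNB, used later in this article, exposes this as its alpha parameter (default 1.0). The sketch below uses a hypothetical count matrix in which the word "$$$" never appears in a ham message, and checks the smoothed estimates for the ham class by hand.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

# Hypothetical word-count matrix; columns are the words ["free", "call", "$$$"]
X = np.array([[2, 1, 1],   # spam
              [1, 0, 2],   # spam
              [0, 2, 0],   # ham
              [0, 1, 0]])  # ham
y = np.array([1, 1, 0, 0])  # 1 = spam, 0 = ham

model = MultinomialNB(alpha=1.0).fit(X, y)
estimated = np.exp(model.feature_log_prob_)  # P(word | class), one row per class in model.classes_

# Manual Laplace estimate for ham (class 0): (count + alpha) / (total + alpha * n_words)
ham_counts = X[y == 0].sum(axis=0)           # [0, 3, 0]
alpha, n_words = 1.0, X.shape[1]
manual = (ham_counts + alpha) / (ham_counts.sum() + alpha * n_words)

print(manual)                                # [0.16666667 0.66666667 0.16666667]
print(np.allclose(estimated[0], manual))     # True
```

Even though "$$$" was never seen in a ham message, its smoothed probability for ham is 1/6 rather than 0.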

Demo in Scikit-Learn

It's demo time! We will use Python 3 together with Scikit-Learn to build a very simple SPAM detector for SMS messages (for those of you who are youngsters, this is what we used for messaging back in the middle ages). You can find and download the dataset from this link.

We will need three libraries that will make our coding much easier: scikit-learn, pandas and nltk. You can use pip or conda to install these.

Loading the Data

The SMS Spam Collection v.1 is a set of tagged SMS messages that have been collected for SMS spam research. It contains 5,574 English SMS messages, each tagged as either ham (legitimate) or spam: a total of 4,827 legitimate messages (86.6%) and 747 spam messages (13.4%).

If we open the dataset, we will see that it has the format [label] [tab] [message], which looks something like this:

ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...

ham	Ok lar... Joking wif u oni...

spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's

ham	U dun say so early hor... U c already then say...

To load the data, we can use the pandas read_table function. It allows us to define the separator (in this case, a tab) and name the columns accordingly:

import pandas as pd

# read_table splits on tabs by default; sep is passed explicitly for clarity
df = pd.read_table('SMSSpamCollection',
                   sep='\t',
                   names=['label', 'message'])

Preprocessing

Once we have our data ready, it is time to do some preprocessing. We will focus on removing variation in the text that is irrelevant to our task. First, we have to convert the labels from strings to binary values for our classifier:

df['label'] = df['label'].map({'ham': 0, 'spam': 1})

Second, convert all characters in the message to lower case:

df['message'] = df['message'].str.lower()

Third, remove any punctuation:

# regex=True is required in pandas >= 2.0, where literal matching is the default
df['message'] = df['message'].str.replace(r'[^\w\s]', '', regex=True)
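To see the first three preprocessing steps in isolation, the following sketch runs them on two of the sample lines shown earlier, loaded from an in-memory string with io.StringIO instead of the real file. Passing regex=True matters here: since pandas 2.0, Series.str.replace treats the pattern as a literal string by default, so the punctuation would otherwise be left untouched.

```python
import io
import pandas as pd

# Two sample lines from the dataset, in [label] [tab] [message] format
raw = (
    "ham\tOk lar... Joking wif u oni...\n"
    "spam\tFree entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. "
    "Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's\n"
)

df = pd.read_table(io.StringIO(raw), sep='\t', names=['label', 'message'])

df['label'] = df['label'].map({'ham': 0, 'spam': 1})                    # labels -> 0/1
df['message'] = df['message'].str.lower()                               # lower case
df['message'] = df['message'].str.replace(r'[^\w\s]', '', regex=True)  # strip punctuation

print(df['label'].tolist())   # [0, 1]
print(df['message'][0])       # ok lar joking wif u oni
```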

Fourth, tokenize the messages into single words using nltk. First, we have to import nltk and download the tokenizer from the console:

import nltk
nltk.download()

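An installation window will appear. Go to the "Models" tab and select "punkt" from the "Identifier" column, then click "Download" to install the necessary files. Alternatively, calling nltk.download('punkt') fetches the tokenizer directly without opening the window (recent NLTK releases may ask for 'punkt_tab' instead). Now we can apply the tokenization: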

df['message'] = df['message'].apply(nltk.word_tokenize)

Fifth, we will perform some word stemming. Stemming reduces the different forms of a word (e.g. "calls", "calling", "called") to a common stem ("call"), so that variations carrying the same meaning are counted as the same feature. One of the most popular stemming algorithms is the Porter Stemmer:

from nltk.stem import PorterStemmer

stemmer = PorterStemmer()
df['message'] = df['message'].apply(lambda x: [stemmer.stem(y) for y in x])

Finally, we will transform the data into word counts, which will be the features that we feed into our model:

from sklearn.feature_extraction.text import CountVectorizer

# This converts the list of words into space-separated strings
df['message'] = df['message'].apply(lambda x: ' '.join(x))

count_vect = CountVectorizer()
counts = count_vect.fit_transform(df['message'])

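We could leave it as the simple word count per message, but it is usually better to use Term Frequency-Inverse Document Frequency, better known as tf-idf, which downweights words that appear in many messages and therefore carry little information for telling the classes apart: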

from sklearn.feature_extraction.text import TfidfTransformer

transformer = TfidfTransformer().fit(counts)

counts = transformer.transform(counts)

Training the Model

Now that we have performed feature extraction from our data, it is time to build our model. We will start by splitting our data into training and test sets:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(counts, df['label'], test_size=0.1, random_state=69)

Then, all that we have to do is initialize the Naive Bayes Classifier and fit the data. For text classification problems, the Multinomial Naive Bayes Classifier is well-suited:

from sklearn.naive_bayes import MultinomialNB

model = MultinomialNB().fit(X_train, y_train)
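One caveat about the steps above: CountVectorizer and TfidfTransformer were fitted on all messages before the split, so the vocabulary and idf weights already "know" about the test set, which can make test scores slightly optimistic. A common way to avoid this is to wrap feature extraction and the classifier in a scikit-learn Pipeline and fit it on the training split only. The sketch below shows the idea on a few hypothetical messages; with the real data you would first pass the preprocessed df['message'] and df['label'] to train_test_split and then fit the pipeline on the training split.

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical training messages (1 = spam, 0 = ham)
train_msgs = [
    "free entry win cash prize",
    "win free tickets now",
    "call now to claim your prize",
    "are we meeting for lunch",
    "ok see you at home",
    "call me when you get home",
]
train_labels = [1, 1, 1, 0, 0, 0]

# Every step is fitted on the training messages only
pipeline = Pipeline([
    ('counts', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('nb', MultinomialNB()),
])
pipeline.fit(train_msgs, train_labels)

# New messages are transformed with the vocabulary and idf learned above
print(pipeline.predict(["free prize", "see you at home"]))  # [1 0]
```

scikit-learn's TfidfVectorizer combines CountVectorizer and TfidfTransformer into a single step, if you prefer a shorter pipeline.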

Evaluating the Model

Once we have put together our classifier, we can evaluate its performance on the test set:

import numpy as np

predicted = model.predict(X_test)

print(np.mean(predicted == y_test))

Congratulations! Our simple Naive Bayes Classifier achieves 98.2% accuracy on this specific test set! However, accuracy alone is not enough, since our dataset is imbalanced when it comes to the labels (86.6% legitimate versus 13.4% spam). A classifier could score well simply by favoring the legitimate class while missing much of the spam. To check for this, let's have a look at the confusion matrix:

from sklearn.metrics import confusion_matrix

print(confusion_matrix(y_test, predicted))

The confusion_matrix function will print something like this (rows are the true labels and columns the predicted ones, in the order ham, spam):

[[478   4]
 [  6  70]]

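As we can see, the number of errors is similar for both classes: 4 legitimate messages were classified as spam and 6 spam messages were classified as legitimate. Relative to class size, though, spam fares worse: 70 of the 76 spam messages (about 92%) were caught, versus 478 of the 482 legitimate messages (over 99%). Overall, these are very good results for our simple classifier.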
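Per-class metrics summarize the matrix more compactly. The sketch below rebuilds label arrays that reproduce the confusion matrix above (the arrays are synthetic; only the four counts come from the result shown) and computes spam precision and recall, both by hand and with scikit-learn. In your own run you would simply pass y_test and predicted.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Synthetic labels reproducing the confusion matrix [[478, 4], [6, 70]]
y_true = np.array([0] * 482 + [1] * 76)                      # 482 ham, 76 spam
y_pred = np.array([0] * 478 + [1] * 4 + [0] * 6 + [1] * 70)  # predictions in the same order

cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()    # binary case: rows = true label, columns = predicted label

print(cm)
print(tp / (tp + fp))          # spam precision: 70 / 74
print(tp / (tp + fn))          # spam recall: 70 / 76
print(precision_score(y_true, y_pred), recall_score(y_true, y_pred))
```

sklearn.metrics.classification_report prints precision, recall and F1 for both classes in one table.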

Conclusion

In this article, we have seen a crash course on both the theory and practice of the Naive Bayes Classifier. We have put together a simple Multinomial Naive Bayes Classifier that achieves 98.2% accuracy on spam detection for SMS messages.

Last Updated: July 10th, 2018

Author: Daniyal Shahrokhian

© 2013-2024 Stack Abuse. All rights reserved.