Avoiding underflows in Gaussian Naive Bayes


There are mainly two ways to avoid numerical instability when implementing Gaussian Naive Bayes (GNB): either apply the log-sum-exp trick, or take the logarithm and increase the variance by a small number $\epsilon$. In this blog post, I follow the second approach.
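For comparison, here is a minimal sketch of the first approach. It assumes we already have the per-class log joint likelihoods $\log p(\mathbf{x}, y = j)$ for one sample (the values below are made up for illustration) and want to normalize them into posterior probabilities. Exponentiating them directly underflows to zero, while subtracting the maximum first (the log-sum-exp trick) does not.

```python
import numpy as np


def log_sum_exp(a):
    """Return log(sum(exp(a))), computed stably by factoring out max(a)."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


# Made-up log joint likelihoods log p(x, y=j) for three classes.
# exp(-1000) underflows to 0.0 in float64.
log_joint = np.array([-1000.0, -1001.0, -1005.0])

# Naive normalization: 0.0 / 0.0 gives nan for every class.
with np.errstate(invalid="ignore"):
    naive = np.exp(log_joint) / np.sum(np.exp(log_joint))

# Log-sum-exp: log p(y=j | x) = log p(x, y=j) - log sum_k p(x, y=k)
posterior = np.exp(log_joint - log_sum_exp(log_joint))

print(naive)  # [nan nan nan]
print(posterior)
```

When only the predicted class is needed, the normalization can be skipped entirely: the arg max of the log joint likelihoods is already the MAP decision, which is why taking the logarithm is sufficient for classification.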

Derivation 1

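Let us first derive GNB via Bayes' theorem. For a feature vector $\mathbf{x} = (x_1, \dots, x_d)$ and a class $y$,

$$p(y \mid \mathbf{x}) = \frac{p(\mathbf{x} \mid y)\, p(y)}{p(\mathbf{x})}$$

The evidence $p(\mathbf{x})$ can be removed because it does not depend on $y$:

$$p(y \mid \mathbf{x}) \propto p(\mathbf{x} \mid y)\, p(y)$$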

Assume the $d$ features $x_1, \dots, x_d$ are conditionally independent given the class $y$. Then

$$p(\mathbf{x} \mid y) = \prod_{i=1}^{d} p(x_i \mid y)$$

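Next, we need a decision rule, either the likelihood ratio test or maximum a posteriori (MAP) estimation. We use MAP and obtain

$$\hat{y} = \arg\max_{j}\; p(y = j) \prod_{i=1}^{d} p(x_i \mid y = j)$$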

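Assume that each feature is normally distributed within each class, $p(x_i \mid y = j) = \mathcal{N}(x_i;\, \mu_{ij}, \sigma_{ij}^2)$.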

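The result is GNB:

$$\hat{y} = \arg\max_{j}\; p(y = j) \prod_{i=1}^{d} \frac{1}{\sqrt{2\pi\sigma_{ij}^2}} \exp\!\left(-\frac{(x_i - \mu_{ij})^2}{2\sigma_{ij}^2}\right)$$

Multiplying many densities quickly underflows to zero in floating-point arithmetic. To avoid this, we apply the logarithm, which turns the product into a sum. Because the logarithm is strictly increasing, the arg max does not change:

$$\hat{y} = \arg\max_{j}\; \log p(y = j) - \sum_{i=1}^{d} \left[ \frac{1}{2}\log(2\pi\sigma_{ij}^2) + \frac{(x_i - \mu_{ij})^2}{2\sigma_{ij}^2} \right]$$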
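To see the underflow that the logarithm avoids, here is a small sketch on synthetic data (the 1,000 features and the two sets of class parameters are assumptions for illustration only). The product of per-feature densities becomes exactly 0.0 for both classes, so the arg max carries no information, while the sum of log densities stays finite and still picks the correct class.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 1000  # number of features (hypothetical)

x = rng.normal(loc=0.0, scale=1.0, size=d)     # one sample drawn from class 0
mu = np.array([np.zeros(d), np.full(d, 0.5)])  # per-class means mu_ij, shape (2, d)
var = np.ones((2, d))                          # per-class variances sigma_ij^2, shape (2, d)

# Log of each per-feature density log N(x_i; mu_ij, sigma_ij^2), shape (2, d)
log_dens = -0.5 * np.log(2 * np.pi * var) - (x - mu) ** 2 / (2 * var)

product = np.exp(log_dens).prod(axis=1)  # product of densities: underflows to 0.0
log_sum = log_dens.sum(axis=1)           # sum of log densities: finite

print(product)           # [0. 0.]
print(log_sum.argmax())  # 0, i.e. class 0 wins
```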

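After removing the constant $-\frac{d}{2}\log 2\pi$ and multiplying by $-2$, which turns the arg max into an arg min, we get

$$\hat{y} = \arg\min_{j}\; -2\log p(y = j) + \sum_{i=1}^{d} \left[ \log\sigma_{ij}^2 + \frac{(x_i - \mu_{ij})^2}{\sigma_{ij}^2} \right]$$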
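As a quick numerical check of this step, the sketch below compares the arg max of the full log posterior, computed with `scipy.stats.norm.logpdf`, with the arg min of the simplified score. The random parameters and the non-uniform prior are assumptions for illustration only.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)
K, d, n = 3, 5, 200  # classes, features, samples (hypothetical sizes)

mu = rng.normal(size=(K, d))              # mu_ij, shape (K, d)
var = rng.uniform(0.5, 2.0, size=(K, d))  # sigma_ij^2, shape (K, d)
prior = np.array([0.2, 0.3, 0.5])         # p(y = j)
X = rng.normal(size=(n, d))

# Full log posterior up to log p(x): log p(y=j) + sum_i log N(x_i; mu_ij, sigma_ij^2)
log_post = np.stack([
    np.log(prior[j]) + norm.logpdf(X, loc=mu[j], scale=np.sqrt(var[j])).sum(axis=1)
    for j in range(K)
])

# Simplified score: -2 log p(y=j) + sum_i [log sigma_ij^2 + (x_i - mu_ij)^2 / sigma_ij^2]
score = np.stack([
    -2 * np.log(prior[j]) + np.sum(np.log(var[j]) + (X - mu[j]) ** 2 / var[j], axis=1)
    for j in range(K)
])

# Both rules pick the same class for every sample
assert np.array_equal(log_post.argmax(axis=0), score.argmin(axis=0))
```

The two quantities differ only by the affine map $\text{score} = -2\,\text{log\_post} - d\log 2\pi$, so the decisions agree exactly apart from floating-point near-ties.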

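You could use this formula as it is. However, I will go one step further and assume a uniform prior $p(y = j) = 1/K$ over the $K$ classes. The prior term is then the same for every class and can be dropped, which simplifies the equation even more:

$$\hat{y} = \arg\min_{j}\; \sum_{i=1}^{d} \left[ \log\sigma_{ij}^2 + \frac{(x_i - \mu_{ij})^2}{\sigma_{ij}^2} \right] \tag{1}$$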

Derivation 2

Before implementing GNB, I will provide an alternative derivation via Quadratic Discriminant Analysis (QDA).

Like GNB, QDA assumes normally distributed features within each class, but it models them with a full covariance matrix $\Sigma_j$ per class and does not assume conditional independence.

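Let us derive GNB from QDA. With class mean $\boldsymbol{\mu}_j$ and covariance matrix $\Sigma_j$, QDA is given by the following equation:

$$\hat{y} = \arg\max_{j}\; \log p(y = j) - \frac{d}{2}\log 2\pi - \frac{1}{2}\log|\Sigma_j| - \frac{1}{2}(\mathbf{x} - \boldsymbol{\mu}_j)^\top \Sigma_j^{-1} (\mathbf{x} - \boldsymbol{\mu}_j)$$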

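Removing the constants and the prior (again assuming a uniform prior) and multiplying by $-2$ to use arg min, we get

$$\hat{y} = \arg\min_{j}\; \log|\Sigma_j| + (\mathbf{x} - \boldsymbol{\mu}_j)^\top \Sigma_j^{-1} (\mathbf{x} - \boldsymbol{\mu}_j)$$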

Assume conditional independence. Recall that when two random variables are independent, their covariance is zero.

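Hence, this assumption turns $\Sigma_j$ into a diagonal matrix, $\Sigma_j = \operatorname{diag}(\sigma_{1j}^2, \dots, \sigma_{dj}^2)$. Its determinant becomes $|\Sigma_j| = \prod_{i=1}^{d} \sigma_{ij}^2$, so $\log|\Sigma_j| = \sum_{i=1}^{d} \log\sigma_{ij}^2$. Replacing the second term as well, $(\mathbf{x} - \boldsymbol{\mu}_j)^\top \Sigma_j^{-1} (\mathbf{x} - \boldsymbol{\mu}_j) = \sum_{i=1}^{d} (x_i - \mu_{ij})^2 / \sigma_{ij}^2$, we obtain GNB, i.e., equation (1).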
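A small numerical check of this argument, for a single class $j$: with a diagonal covariance matrix, the simplified QDA score equals the GNB score from equation (1). The random sample, mean, and variances are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4
x = rng.normal(size=d)
mu = rng.normal(size=d)              # mu_j for one class j
var = rng.uniform(0.5, 2.0, size=d)  # sigma_ij^2 on the diagonal

Sigma = np.diag(var)  # diagonal covariance implied by conditional independence
diff = x - mu

# QDA score: log|Sigma_j| + (x - mu_j)^T Sigma_j^{-1} (x - mu_j)
sign, logdet = np.linalg.slogdet(Sigma)
qda = logdet + diff @ np.linalg.solve(Sigma, diff)

# GNB score from equation (1): sum_i [log sigma_ij^2 + (x_i - mu_ij)^2 / sigma_ij^2]
gnb = np.sum(np.log(var) + diff ** 2 / var)

print(np.isclose(qda, gnb))  # True
```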


The following code is a direct implementation of equation (1).

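```python
from sklearn.datasets import load_breast_cancer
import numpy as np

X, y = load_breast_cancer(return_X_y=True)
classes = np.unique(y)

# Per-class mean and variance of each feature, shape (n_classes, n_features)
Mu = np.array([X[y == c, :].mean(axis=0) for c in classes])
# Add a small epsilon to every variance (1e-9 times the largest feature variance)
epsilon = 1e-9 * np.var(X, axis=0).max()
Sigma = np.array([X[y == c, :].var(axis=0) + epsilon for c in classes])

# Equation (1): sum_i log(sigma_ij^2) + sum_i (x_i - mu_ij)^2 / sigma_ij^2 per class j
summed = np.sum(np.log(Sigma), axis=1)
scores = [summed[j] + np.sum((X - Mu[j]) ** 2 / Sigma[j], axis=1) for j in range(len(classes))]
y_hat = classes[np.argmin(scores, axis=0)]  # map the arg min index back to a class label
print("accuracy:", np.mean(y_hat == y))
```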

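There is one difference with respect to (1): we have to increase each variance by a small number $\epsilon$ to avoid numerical errors. These errors come from the fraction $(x_i - \mu_{ij})^2 / \sigma_{ij}^2$: if a feature is constant within a class, its variance is zero, so the fraction divides by zero (and $\log\sigma_{ij}^2 = -\infty$). If you read the scikit-learn code, you will find that `GaussianNB` uses the same trick, controlled by its `var_smoothing` parameter.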
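A minimal sketch of the failure that $\epsilon$ prevents, on a tiny made-up dataset for one class in which the second feature is constant. Without smoothing, the score becomes `nan` ($-\infty + \infty$); with the same epsilon formula as in the code above, applied to this toy data, it stays finite.

```python
import numpy as np

# Made-up training data for one class: the second feature is constant.
X_c = np.array([[1.0, 3.0], [2.0, 3.0], [3.0, 3.0]])
x = np.array([2.0, 3.5])  # a test sample

mu = X_c.mean(axis=0)
var = X_c.var(axis=0)  # the second variance is exactly 0.0

# log(0) = -inf and 0.25 / 0 = inf, so their sum is nan
with np.errstate(divide="ignore", invalid="ignore"):
    raw = np.sum(np.log(var) + (x - mu) ** 2 / var)

# Variance smoothing: 1e-9 times the largest feature variance
epsilon = 1e-9 * np.var(X_c, axis=0).max()
smoothed = np.sum(np.log(var + epsilon) + (x - mu) ** 2 / (var + epsilon))

print(raw)       # nan
print(smoothed)  # large but finite
```

The smoothed score is very large because the test sample deviates from a feature that never varied in training, which is the desired behavior: the class is heavily penalized instead of breaking the computation.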

This code gives the same accuracy as scikit-learn’s implementation of GNB. However, I didn’t test this extensively.
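One way to check this claim more directly is to compare the predictions with scikit-learn's `GaussianNB`. Note that `GaussianNB` estimates the class priors from the data by default, whereas equation (1) assumes a uniform prior, so this sketch passes `priors=[0.5, 0.5]` to make both models solve the same problem; `var_smoothing=1e-9` is the scikit-learn default and matches the epsilon used above.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.naive_bayes import GaussianNB

X, y = load_breast_cancer(return_X_y=True)
classes = np.unique(y)

# Equation (1), as in the code above
Mu = np.array([X[y == c].mean(axis=0) for c in classes])
epsilon = 1e-9 * np.var(X, axis=0).max()
Sigma = np.array([X[y == c].var(axis=0) + epsilon for c in classes])
summed = np.sum(np.log(Sigma), axis=1)
scores = [summed[j] + np.sum((X - Mu[j]) ** 2 / Sigma[j], axis=1) for j in range(len(classes))]
ours = classes[np.argmin(scores, axis=0)]

# scikit-learn with a uniform prior, so it matches equation (1)
clf = GaussianNB(priors=[0.5, 0.5], var_smoothing=1e-9).fit(X, y)
theirs = clf.predict(X)

print("agreement:", np.mean(ours == theirs))
print("accuracy (ours, sklearn):", np.mean(ours == y), np.mean(theirs == y))
# For reference: the default GaussianNB with priors estimated from the data
print("accuracy (sklearn, empirical priors):", GaussianNB().fit(X, y).score(X, y))
```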