Skip to content

Commit d5ee46e

Browse files
committed
Add Gaussian Naive Bayes classifier in machine_learning/
1 parent e3b01ec commit d5ee46e

1 file changed

Lines changed: 321 additions & 0 deletions

File tree

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""
2+
Gaussian Naive Bayes Classifier
3+
4+
A probabilistic classifier based on Bayes' theorem with the assumption that
5+
features follow a Gaussian (normal) distribution within each class.
6+
7+
Despite its simplicity, Gaussian Naive Bayes performs well on many real-world
8+
problems, especially when the number of features is large relative to the
9+
number of training samples.
10+
11+
How it works:
12+
1. Training: Compute the mean and variance of each feature per class,
13+
and the prior probability of each class.
14+
2. Prediction: For each class, compute the log-likelihood of the input
15+
using the Gaussian probability density function, add the
16+
log prior, and pick the class with the highest score.
17+
18+
Bayes' theorem:
19+
P(class | X) ∝ P(X | class) * P(class)
20+
21+
Gaussian PDF:
22+
P(x | mean, var) = exp(-0.5 * ((x - mean)^2 / var)) / sqrt(2 * pi * var)
23+
24+
Time Complexity: O(n * d) for training, O(k * d) per predicted sample
25+
where n = samples, k = classes, d = features
26+
27+
References:
28+
- https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes
29+
- https://en.wikipedia.org/wiki/Bayes%27_theorem
30+
"""
31+
32+
import math
33+
from collections import defaultdict
34+
35+
36+
def separate_by_class(
    data: list[list[float]], labels: list[int]
) -> dict[int, list[list[float]]]:
    """
    Group feature vectors under the class label they were tagged with.

    Args:
        data: Training feature vectors.
        labels: Class label of each feature vector, in the same order as data.

    Returns:
        A plain dict from class label to the feature vectors carrying that
        label. Vectors keep the relative order they had in data.

    Raises:
        ValueError: If data is empty.
        ValueError: If data and labels differ in length.

    >>> data = [[1.0, 2.0], [3.0, 4.0], [1.5, 2.5]]
    >>> labels = [0, 1, 0]
    >>> separated = separate_by_class(data, labels)
    >>> separated[0]
    [[1.0, 2.0], [1.5, 2.5]]
    >>> separated[1]
    [[3.0, 4.0]]
    >>> separate_by_class([], [])
    Traceback (most recent call last):
        ...
    ValueError: Data must not be empty.
    >>> separate_by_class([[1.0, 2.0]], [0, 1])
    Traceback (most recent call last):
        ...
    ValueError: Data and labels must have the same length.
    """
    # The emptiness check comes first so that two empty inputs report
    # "empty" rather than passing the length comparison.
    if not data:
        raise ValueError("Data must not be empty.")
    if len(labels) != len(data):
        raise ValueError("Data and labels must have the same length.")

    grouped: dict[int, list[list[float]]] = {}
    for label, feature_vector in zip(labels, data):
        # setdefault creates the bucket the first time a label is seen.
        grouped.setdefault(label, []).append(feature_vector)
    return grouped
78+
79+
80+
def compute_mean_variance(values: list[float]) -> tuple[float, float]:
    """
    Return the mean and the (floored) population variance of some values.

    The variance divides by n rather than n - 1, i.e. it is the population
    variance, and it is never reported below 1e-9 so that callers can divide
    by it safely.

    Args:
        values: A non-empty list of numerical values.

    Returns:
        A tuple of (mean, variance), with variance at least 1e-9.

    Raises:
        ValueError: If values is empty.

    >>> mean, var = compute_mean_variance([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
    >>> round(mean, 4)
    5.0
    >>> round(var, 4)
    4.0
    >>> compute_mean_variance([5.0])
    (5.0, 1e-09)
    >>> compute_mean_variance([])
    Traceback (most recent call last):
        ...
    ValueError: Values must not be empty.
    """
    if not values:
        raise ValueError("Values must not be empty.")

    variance_floor = 1e-9
    count = len(values)
    mean = sum(values) / count
    squared_deviations = [(value - mean) ** 2 for value in values]
    variance = sum(squared_deviations) / count

    # A constant column has zero variance; lift it to the floor.
    if variance < variance_floor:
        variance = variance_floor
    return mean, variance
116+
117+
118+
def train(
    data: list[list[float]], labels: list[int]
) -> tuple[dict[int, float], dict[int, list[tuple[float, float]]]]:
    """
    Train a Gaussian Naive Bayes classifier.

    For every class this computes the log of its relative frequency in the
    training set and, for every feature, the mean and variance of that
    feature over the class's samples.

    Args:
        data: List of feature vectors (training samples). All vectors must
            have the same number of features.
        labels: List of class labels corresponding to each sample.

    Returns:
        A tuple of:
            - priors: dict mapping class label to its log prior probability.
            - summaries: dict mapping class label to a list of
              (mean, variance) tuples, one per feature.

    Raises:
        ValueError: If the feature vectors do not all have the same length.
        ValueError: If data is empty or lengths mismatch (via helpers).

    >>> data = [[1.0, 2.0], [2.0, 3.0], [10.0, 11.0], [11.0, 12.0]]
    >>> labels = [0, 0, 1, 1]
    >>> priors, summaries = train(data, labels)
    >>> round(priors[0], 4)
    -0.6931
    >>> len(summaries[0])  # two features
    2
    >>> round(summaries[1][0][0], 1)  # mean of feature 0 in class 1
    10.5
    >>> train([[1.0, 2.0], [3.0]], [0, 0])
    Traceback (most recent call last):
        ...
    ValueError: All feature vectors must have the same length.
    """
    n_samples = len(data)

    # Reject ragged input up front. The per-feature transpose below relies on
    # every row having the same length; zip() would otherwise silently drop
    # the trailing values of longer rows.
    if len({len(row) for row in data}) > 1:
        raise ValueError("All feature vectors must have the same length.")

    # Also validates that data is non-empty and matches labels in length.
    separated = separate_by_class(data, labels)

    priors: dict[int, float] = {}
    summaries: dict[int, list[tuple[float, float]]] = {}

    for class_label, class_samples in separated.items():
        priors[class_label] = math.log(len(class_samples) / n_samples)
        # zip(*rows) transposes the class's rows into one tuple per feature.
        summaries[class_label] = [
            compute_mean_variance(list(column)) for column in zip(*class_samples)
        ]

    return priors, summaries
165+
166+
167+
def gaussian_log_probability(x: float, mean: float, variance: float) -> float:
    """
    Return the natural log of the Gaussian density evaluated at x.

    The value is the sum of two terms:
        log P(x | mean, var) = -0.5 * log(2 * pi * var)
                               - 0.5 * ((x - mean)^2 / var)

    Args:
        x: The observed value.
        mean: Mean of the Gaussian distribution.
        variance: Variance of the Gaussian distribution (must be > 0).

    Returns:
        Log probability density as a float.

    Raises:
        ValueError: If variance is zero or negative.

    >>> round(gaussian_log_probability(1.0, 0.0, 1.0), 4)
    -1.4189
    >>> round(gaussian_log_probability(0.0, 0.0, 1.0), 4)
    -0.9189
    >>> gaussian_log_probability(1.0, 0.0, 0.0)
    Traceback (most recent call last):
        ...
    ValueError: Variance must be positive.
    """
    if variance <= 0:
        raise ValueError("Variance must be positive.")

    # Normalisation term: depends only on the variance.
    log_normalizer = -0.5 * math.log(2 * math.pi * variance)
    # Penalty term: grows with the squared distance of x from the mean.
    scaled_squared_error = 0.5 * ((x - mean) ** 2 / variance)
    return log_normalizer - scaled_squared_error
201+
202+
203+
def predict_single(
    feature_vector: list[float],
    priors: dict[int, float],
    summaries: dict[int, list[tuple[float, float]]],
) -> int:
    """
    Predict the class label for a single feature vector.

    Each class is scored with its log prior plus the sum of the per-feature
    Gaussian log densities, and the class with the highest score is returned.
    When two classes tie, the one that appears first in summaries wins.

    Args:
        feature_vector: A list of feature values to classify.
        priors: Log prior probabilities per class (from train()).
        summaries: Per-class (mean, variance) per feature (from train()).

    Returns:
        The predicted class label (integer). If no class reaches a score
        greater than negative infinity (for example when the scores are
        NaN), -1 is returned.

    Raises:
        ValueError: If summaries is empty.
        ValueError: If feature_vector does not contain exactly one value per
            trained feature of every class.

    >>> data = [[1.0, 2.0], [2.0, 3.0], [10.0, 11.0], [11.0, 12.0]]
    >>> labels = [0, 0, 1, 1]
    >>> priors, summaries = train(data, labels)
    >>> predict_single([1.5, 2.5], priors, summaries)
    0
    >>> predict_single([10.5, 11.5], priors, summaries)
    1
    >>> predict_single([1.5], priors, summaries)
    Traceback (most recent call last):
        ...
    ValueError: Feature vector length must match the number of trained features.
    >>> predict_single([1.5, 2.5], {}, {})
    Traceback (most recent call last):
        ...
    ValueError: Summaries must not be empty.
    """
    if not summaries:
        raise ValueError("Summaries must not be empty.")

    # zip() below stops at the shorter input, so a wrong-sized vector would
    # otherwise be scored on a subset of features without any warning.
    n_features = len(feature_vector)
    if any(len(stats) != n_features for stats in summaries.values()):
        raise ValueError(
            "Feature vector length must match the number of trained features."
        )

    best_label = -1
    best_score = float("-inf")

    for class_label, feature_summaries in summaries.items():
        score = priors[class_label]
        for feature_value, (mean, variance) in zip(feature_vector, feature_summaries):
            score += gaussian_log_probability(feature_value, mean, variance)
        # Strict comparison keeps the earliest class on ties.
        if score > best_score:
            best_score = score
            best_label = class_label

    return best_label
241+
242+
243+
def predict(
    data: list[list[float]],
    priors: dict[int, float],
    summaries: dict[int, list[tuple[float, float]]],
) -> list[int]:
    """
    Classify every feature vector in a batch.

    Args:
        data: List of feature vectors to classify.
        priors: Log prior probabilities per class (from train()).
        summaries: Per-class (mean, variance) per feature (from train()).

    Returns:
        The predicted class labels, in the same order as data.

    Raises:
        ValueError: If data is empty.

    >>> data = [[1.0, 2.0], [2.0, 3.0], [10.0, 11.0], [11.0, 12.0]]
    >>> labels = [0, 0, 1, 1]
    >>> priors, summaries = train(data, labels)
    >>> predict([[1.5, 2.5], [10.5, 11.5]], priors, summaries)
    [0, 1]
    >>> predict([[0.5, 1.5], [12.0, 13.0]], priors, summaries)
    [0, 1]
    >>> predict([], priors, summaries)
    Traceback (most recent call last):
        ...
    ValueError: Data must not be empty.
    """
    if not data:
        raise ValueError("Data must not be empty.")

    predicted_labels: list[int] = []
    for vector in data:
        predicted_labels.append(predict_single(vector, priors, summaries))
    return predicted_labels
277+
278+
279+
def accuracy(predictions: list[int], actual: list[int]) -> float:
    """
    Return the share of predictions that equal the true label.

    Args:
        predictions: List of predicted class labels.
        actual: List of true class labels.

    Returns:
        Accuracy as a float between 0.0 and 1.0.

    Raises:
        ValueError: If predictions is empty, or if the two lists differ in
            length.

    >>> accuracy([0, 1, 1, 0], [0, 1, 1, 0])
    1.0
    >>> accuracy([0, 1, 1, 0], [0, 1, 0, 0])
    0.75
    >>> accuracy([0], [1])
    0.0
    >>> accuracy([], [])
    Traceback (most recent call last):
        ...
    ValueError: Inputs must not be empty.
    >>> accuracy([0, 1], [0])
    Traceback (most recent call last):
        ...
    ValueError: Predictions and actual labels must have the same length.
    """
    if not predictions:
        raise ValueError("Inputs must not be empty.")

    total = len(actual)
    if total != len(predictions):
        raise ValueError(
            "Predictions and actual labels must have the same length."
        )

    matches = 0
    for predicted, expected in zip(predictions, actual):
        if predicted == expected:
            matches += 1
    return matches / total
316+
317+
318+
if __name__ == "__main__":
    # When executed as a script, run every doctest embedded in this module's
    # docstrings and print the outcome of each example.
    from doctest import testmod

    testmod(verbose=True)

0 commit comments

Comments
 (0)