Life's monolog

用python实现朴素贝叶斯分类算法

Word count: 713 / Reading time: 3 min
2018/04/24 Share

简介

朴素贝叶斯算法主要用于分类问题,原理十分简单,主要采用后验概率最大化的方法来判定测试样本的类别。对于某个给定的测试样本$ x = \{x_1,…,x_n\} $,其类别为$y$的概率可以通过贝叶斯公式计算:

其中$P(y)$是类别$y$的先验概率,$P(x_1,…,x_n|y)$是条件概率。后验概率最大化就是找出能够使得$P(y|x_1,…,x_n)$最大的类别$y$。

由于朴素贝叶斯算法假设各个特征之间相互独立,所以可以得到下面的公式:

因为$P(x_1,…,x_n)$是个不变量,所以只要考虑分子,最后朴素贝叶斯算法分类的公式可以表示为:

代码实现

《统计学习方法》中举了一个关于朴素贝叶斯算法分类的例子,但这个例子中特征是离散特征,所以需要分别对于每一类特征的每一个值计算条件概率$P(x_i | y)$。对于连续型的数值特征,比较常见的做法是假设特征符合某项概率分布,例如高斯分布、伯努利分布等。本文的实现针对连续型的数值特征,假设其符合高斯分布。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class NaiveBayes:
def fit(self, X, y):
"""朴素贝叶斯算法
假设条件概率p(x|y)符合高斯分布
为每一类的样本中特征的每一个维度拟合一个高斯分布

Args:
X: 训练特征
y: 训练标签
"""
self.X, self.y = X, y
self.classes = np.unique(y)
self.parameters = []

for i, c in enumerate(self.classes):
tmp_X = X[np.where(y == c)]
self.parameters.append([])
for col in tmp_X.T:
parameters = {"mean": col.mean(), "var": col.var()}
self.parameters[i].append(parameters)

def _calculate_priori(self, c):
"""根据训练数据计算每一类的先验概率"""
tmp_X = self.X[np.where(self.y == c)]
return len(tmp_X) / len(self.X)

def _calculate_likelihood(self, mean, var, x):
"""计算条件概率"""
eps = 1e-4
coef = 1.0 / math.sqrt(2.0 * math.pi * var + eps)
exponent = math.exp(-math.pow(x - mean, 2) / (2 * var + eps))
return coef * exponent

def _classify(self, sample):
"""采用最大后验概率来分类"""
posteriors = []
for i, c in enumerate(self.classes):
posterior = self._calculate_priori(c)
for feature_value, params in zip(sample, self.parameters[i]):
likelihood = self._calculate_likelihood(params['mean'], params['var'], feature_value)
posterior *= likelihood
posteriors.append(posterior)
return self.classes[np.argmax(posteriors)]

def predict(self, X):
pred = [self._classify(x) for x in X]
return pred

下面采用sklearn自带的iris数据集测试实现的模型,可以看到分类的accuracy在0.96。

1
2
3
4
5
6
model = NaiveBayes()
iris = datasets.load_iris()

model.fit(iris.data, iris.target)
y_pred = model.predict(iris.data)
print(f'accuracy = {accuracy_score(iris.target, y_pred)}') # 输出 accuracy = 0.96

完整代码在这里

参考

  1. 《统计学习方法》
  2. ML-From-Scratch
CATALOG
  1. 1. 简介
  2. 2. 代码实现
  3. 3. 参考