K-Means Clustering Algorithm

Reference article: Principles of the K-Means Clustering Algorithm (K-Means聚类算法原理)

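Let the sample set be D = {x1, x2, …, xm}, the maximum number of iterations be N, and the number of clusters be k. The output is k clusters {C1, C2, …, Ck}.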

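  1. Randomly select k samples from the dataset D as the initial k centroid vectors: {μ1, μ2, …, μk}
  2. For n = 1, 2, …, N:
    • For i = 1, 2, …, m, compute the distance between sample xi and each centroid vector μj, and assign xi to cluster j, where μj is the nearest centroid
    • For each of the k new clusters, compute the mean of its samples to obtain the new centroid vector
    • If none of the k centroid vectors has changed, go to step 3
  3. Output the cluster partition C = {C1, C2, …, Ck}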
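
To make step 2 concrete, here is a minimal sketch of a single pass (one assignment followed by one centroid update) on six sample 2-D points. The helper name `kmeans_step` is hypothetical, and using the first three samples as the initial centroids is an assumption made only so the result is reproducible; step 1 would pick them at random.

```python
import numpy as np


def kmeans_step(points, centers):
    """Run one assignment + update pass (hypothetical helper)."""
    # Pairwise Euclidean distances between samples and centroids, shape (m, k)
    dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
    # Assignment: index of the nearest centroid for every sample
    labels = dists.argmin(axis=1)
    # Update: mean of each cluster; an empty cluster keeps its old centroid
    new_centers = centers.copy()
    for j in range(len(centers)):
        members = points[labels == j]
        if len(members) > 0:
            new_centers[j] = members.mean(axis=0)
    return labels, new_centers


points = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)], dtype=float)
# Assumption: fixed initial centroids instead of a random draw
centers = points[:3].copy()
labels, new_centers = kmeans_step(points, centers)
print(labels)
print(new_centers)
```

In this pass the second centroid moves from (5, 4) to (6, 3.5), the mean of the four samples assigned to it, while the other two centroids stay put because each holds only its own sample. Repeating the pass until no centroid moves gives the loop described in step 2.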

The implementation is as follows:

import numpy as np


class MyKMeans(object):
    def __init__(self, k=2, tolerance=0.0001, max_iter=300):
        self.k_ = k
        self.tolerance_ = tolerance
        self.max_iter_ = max_iter
        self.centers_ = {}
        self.clf_ = {}

    def fit_predict(self, data):
        p_labels = []
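        # Initialize the centroids with k distinct samples chosen at random (step 1)
        data = np.asarray(data, dtype=float)
        self.centers_ = {}
        init_idx = np.random.choice(len(data), self.k_, replace=False)
        for i, idx in enumerate(init_idx):
            self.centers_[i] = data[idx]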

        for i in range(self.max_iter_):
            # Assign every sample to the cluster of its nearest centroid
            self.clf_ = {}
            for j in range(self.k_):
                self.clf_[j] = []
            for feature in data:
                classification = self.predict(feature)
                self.clf_[classification].append(feature)
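            # Recompute each centroid as the mean of its cluster
            prev_centers = dict(self.centers_)
            for c in self.clf_:
                # An empty cluster keeps its previous centroid (the mean of no points is undefined)
                if len(self.clf_[c]) > 0:
                    self.centers_[c] = np.mean(self.clf_[c], axis=0)
            # Stop once no centroid has moved farther than the tolerance
            optimized = True
            for center in self.centers_:
                org_center = prev_centers[center]
                cur_center = self.centers_[center]
                # Euclidean shift: unlike a signed relative change, it cannot
                # cancel out across coordinates or divide by zero
                if np.linalg.norm(cur_center - org_center) > self.tolerance_:
                    optimized = False
            if optimized:
                break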
        for feature in data:
            classification = self.predict(feature)
            p_labels.append(classification)
        return p_labels

    def predict(self, p_data):
        # Return the index of the centroid nearest to p_data
        p_data = np.asarray(p_data, dtype=float)
        distances = [np.linalg.norm(p_data - self.centers_[center]) for center in self.centers_]
        return int(np.argmin(distances))


data = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
kmeans2 = MyKMeans(k=3)
labels2 = kmeans2.fit_predict(data)
print(labels2, kmeans2.predict((2, 3.1)))
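
Because step 1 picks the initial centroids at random, separate runs may end in different partitions. One common way to compare them is the within-cluster sum of squared distances, where a lower value means tighter clusters. The sketch below is an addition rather than part of the original implementation: the helper name `within_cluster_sse` is hypothetical, and the labels and centroids are illustrative values, not the output of a particular run.

```python
import numpy as np


def within_cluster_sse(points, labels, centers):
    """Sum of squared distances from each sample to its assigned centroid (hypothetical helper)."""
    points = np.asarray(points, dtype=float)
    centers = np.asarray(centers, dtype=float)
    labels = np.asarray(labels)
    # centers[labels] looks up the assigned centroid for every sample
    diffs = points - centers[labels]
    return float(np.sum(diffs ** 2))


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
# Illustrative partition and centroids, not the result of a specific run
labels = [0, 1, 2, 1, 1, 1]
centers = [(2, 3), (6, 3.5), (9, 6)]
sse = within_cluster_sse(points, labels, centers)
print(sse)
```

To score a fitted `MyKMeans` instance the same way, one could pass the labels returned by `fit_predict` together with `[kmeans2.centers_[i] for i in range(kmeans2.k_)]` as the centroids.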
This work is licensed under the CC license; reposts must credit the author and link to this article.