# -*- coding: UTF-8 -*-
"""
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@File : knn.py
@Contact : ffzzyy@126.com
@License : (C)Copyright 2017-2019
@Author : ffzzyy
@Version : 0.1
@Modify Time : 2019/3/13 22:51
@Description
"""
import numpy as np
from sklearn import datasets
from collections import Counter
from sklearn.neighbors import KNeighborsClassifier
class KNNClassifier:
    """Brute-force k-nearest-neighbours classifier using Euclidean distance.

    ``fit`` only stores the training data. Each prediction measures the
    distance from the query sample to every stored sample and returns the
    most common label among the ``k`` closest ones.
    """

    def __init__(self, k, *, verbose=True):
        """Create an unfitted classifier.

        :param k: number of neighbours that take part in the vote (>= 1).
        :param verbose: when true (the default), ``_predict`` prints the
            labels of the k nearest neighbours for every query sample.
        """
        # Kept as an assert so callers see the same exception type as before.
        assert k >= 1, 'k must be valid'
        self.k = k
        self.verbose = verbose
        self._x_train = None
        self._y_train = None

    def fit(self, x_train, y_train):
        """Store the training samples and their labels.

        :param x_train: training samples, one sample per entry.
        :param y_train: label of each training sample, in the same order.
        :return: this classifier, so calls can be chained.
        """
        # Convert once here: _predict relies on array broadcasting and on
        # indexing the labels with an index array, neither of which works
        # on plain Python lists or tuples.
        self._x_train = np.asarray(x_train)
        self._y_train = np.asarray(y_train)
        return self

    def _predict(self, x):
        """Predict the label of a single sample.

        :param x: one sample with the same layout as a training sample.
        :return: the most common label among the k nearest training samples;
            on a tie, the label that appears first among the neighbours
            ordered by increasing distance.
        """
        # Euclidean distance from x to every training sample in one
        # vectorised step. Summing over every axis except the first keeps
        # one distance per training sample whatever the sample's rank.
        diffs = self._x_train - np.asarray(x)
        sample_axes = tuple(range(1, diffs.ndim))
        distances = np.sqrt(np.sum(diffs ** 2, axis=sample_axes))
        # Indices of the training samples ordered by increasing distance.
        nearest = np.argsort(distances)
        # Labels of the k closest training samples.
        top_k = self._y_train[nearest[:self.k]]
        if self.verbose:
            print("top_k_nearest=", top_k)
        # Majority vote, e.g. labels [2 1 0 2 2 1] -> 2 (seen three times).
        votes = Counter(top_k)
        return votes.most_common(1)[0][0]

    def predict(self, X_predict):
        """Predict a label for every sample in ``X_predict``.

        :param X_predict: iterable of samples.
        :return: numpy array holding one predicted label per sample.
        """
        return np.array([self._predict(sample) for sample in X_predict])

    def __repr__(self):
        return 'knn(k=%d):' % self.k

    def score(self, x_test, y_test):
        """Return the fraction of samples in ``x_test`` predicted correctly.

        :param x_test: samples to classify.
        :param y_test: expected label of each sample, in the same order.
        :return: number of correct predictions divided by ``len(x_test)``.
        """
        y_predict = self.predict(x_test)
        return np.sum(y_predict == np.asarray(y_test)) / len(x_test)
def main():
    """Demo: classify two hand-written samples against the iris data set.

    Prints the predictions of the home-grown ``KNNClassifier`` with k=6,
    then those of scikit-learn's ``KNeighborsClassifier`` with the same k
    for comparison, and finally the home-grown predictions for every k
    from 2 up to (but excluding) ``int(sqrt(number of samples))``.
    """
    # Training data comes from scikit-learn's bundled iris data set.
    iris = datasets.load_iris()
    features, labels = iris.data, iris.target

    # Samples to classify.
    queries = [(10, 8, 3, 2), (6, 2, 3, 1)]

    # Home-grown implementation.
    own_model = KNNClassifier(6)
    own_model.fit(features, labels)
    print(own_model.predict(queries))

    # scikit-learn implementation, for comparison.
    reference_model = KNeighborsClassifier(6)
    reference_model.fit(features, labels)
    print(reference_model.predict(queries))

    # Show how the prediction changes with different values of k.
    report = "k={0} predict set:[{1}]"
    upper_k = int(np.sqrt(len(features)))
    for k in range(2, upper_k):
        model = KNNClassifier(k)
        model.fit(features, labels)
        print(report.format(k, model.predict(queries)))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()