The k-Nearest Neighbor (KNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. The idea is straightforward: if the majority of the k samples most similar to a given sample (that is, its nearest neighbors in feature space) belong to a particular class, then the sample is assigned to that class. In KNN, the selected neighbors are all objects that have already been correctly labeled, and the class of the sample to be classified is decided solely by the classes of its one or few nearest neighbors. Although KNN also rests on a limit theorem in principle, the classification decision involves only a small number of neighboring samples. Because KNN determines class membership from a limited set of nearby samples rather than by carving out class regions, it is better suited than other methods to data sets whose class regions intersect or overlap heavily.
Algorithm flow:
1. Compute the distance (Euclidean distance here) between the test sample and every training sample.
2. Keep the k training samples with the smallest distances; below this is done with a bounded top-N heap, whose element ordering is arranged so that the k nearest neighbors are the ones retained.
3. Count the class labels of those k neighbors and return the majority class.

First, the min-heap helper used to select the top N elements:
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * Min-heap based priority queue for selecting the top N elements.
 */
public class MinHeapPriorityQueue<T extends Comparable<T>> {

    private PriorityQueue<T> queue;
    private int maxSize;

    /**
     * Uses the elements' natural ordering.
     *
     * @param maxSize maximum number of elements to keep
     */
    public MinHeapPriorityQueue(int maxSize) {
        this.maxSize = maxSize;
        this.queue = new PriorityQueue<>(maxSize);
    }

    /**
     * Uses a custom comparator for the internal heap. Note that insert()
     * still compares elements by their natural ordering.
     */
    public MinHeapPriorityQueue(int maxSize, Comparator<T> comparator) {
        this.maxSize = maxSize;
        this.queue = new PriorityQueue<>(maxSize, comparator);
    }

    // Insert an element; once the queue is full, replace the current minimum
    // whenever the new element is larger, so the N elements that compare as
    // largest are the ones retained.
    public synchronized void insert(T t) {
        if (queue.size() < maxSize) {
            queue.add(t);
        } else {
            T tmp = queue.peek();
            if (t.compareTo(tmp) > 0) {
                queue.poll();
                queue.add(t);
            }
        }
    }

    // Return the retained elements sorted in descending order.
    public synchronized List<T> sortList() {
        List<T> list = new LinkedList<>(queue);
        Collections.sort(list, new Comparator<T>() {
            @Override
            public int compare(T o1, T o2) {
                return o2.compareTo(o1);
            }
        });
        return list;
    }

    public synchronized List<T> getList() {
        return new LinkedList<>(queue);
    }

    // Round d to n decimal places.
    public static double format(double d, int n) {
        double p = Math.pow(10, n);
        return Math.round(d * p) / p;
    }

    public static void main(String[] args) {
        MinHeapPriorityQueue<Double> queue = new MinHeapPriorityQueue<>(3);
        Random r = new Random();
        StringBuffer buf = new StringBuffer();
        for (int i = 0; i < 20; i++) {
            double rd = format(r.nextDouble(), 3);
            queue.insert(rd);
            buf.append(rd);
            if (i != 19) {
                buf.append(", ");
            }
        }
        System.out.println("buff: " + buf.toString());
        System.out.println("list: " + queue.sortList());
    }
}
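Keeping a bounded heap of size N means each of the n candidates costs at most O(log N) to process, so selecting the top N takes O(n log N) time and O(N) extra memory, rather than the O(n log n) time needed to sort every element. In the KNN code below, this heap holds the k best training tuples seen so far.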
KNN algorithm implementation (the last element of each training row is treated as its class label):
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class KNN {

    /**
     * Classify testData against the labeled training rows in datas,
     * where the last element of each training row is its class label.
     */
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {
        // Top-k heap of training tuples; TrainTuple's ordering is arranged so
        // that the k tuples with the smallest distances are retained.
        MinHeapPriorityQueue<TrainTuple> queue = new MinHeapPriorityQueue<>(k);
        for (int i = 0; i < datas.size(); i++) {
            List<Double> t = datas.get(i);
            double distance = calDistance(t, testData);
            queue.insert(new TrainTuple(i, distance, t.get(t.size() - 1).toString()));
        }
        return getMostClass(queue);
    }

    /**
     * Euclidean distance between the test data and a training tuple
     * (the trailing class label of the training row is skipped).
     */
    private double calDistance(List<Double> trainData, List<Double> testData) {
        double sum = 0d;
        for (int i = 0; i < trainData.size() - 1; i++) {
            sum += (trainData.get(i) - testData.get(i)) * (trainData.get(i) - testData.get(i));
        }
        return Math.sqrt(sum);
    }

    /**
     * Return the majority class among the k nearest-neighbor tuples.
     */
    private String getMostClass(MinHeapPriorityQueue<TrainTuple> queue) {
        Map<String, Integer> classCountMap = new HashMap<>();
        List<TrainTuple> arrayList = queue.getList();
        for (int i = 0; i < arrayList.size(); i++) {
            String classify = arrayList.get(i).getClassify();
            if (classCountMap.containsKey(classify)) {
                classCountMap.put(classify, classCountMap.get(classify) + 1);
            } else {
                classCountMap.put(classify, 1);
            }
        }
        int maxIndex = -1;
        int maxCount = 0;
        Object[] classes = classCountMap.keySet().toArray();
        for (int i = 0; i < classes.length; i++) {
            if (classCountMap.get(classes[i]) > maxCount) {
                maxIndex = i;
                maxCount = classCountMap.get(classes[i]);
            }
        }
        return classes[maxIndex].toString();
    }
}
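The TrainTuple class used above is not shown in this post; the complete version is in the repository linked below. As a rough sketch of what the code assumes (the field names, getClassify(), and the reversed ordering on distance are my guesses rather than the author's exact implementation), it could look like this:

// Hypothetical sketch of TrainTuple; the real class is in the linked repository.
// compareTo() is reversed on distance so that the top-k heap, which keeps the
// elements that compare as largest, ends up retaining the k smallest distances.
public class TrainTuple implements Comparable<TrainTuple> {

    private final int index;       // position of the row in the training set
    private final double distance; // distance to the test sample
    private final String classify; // class label of the training row

    public TrainTuple(int index, double distance, String classify) {
        this.index = index;
        this.distance = distance;
        this.classify = classify;
    }

    public int getIndex() {
        return index;
    }

    public double getDistance() {
        return distance;
    }

    public String getClassify() {
        return classify;
    }

    @Override
    public int compareTo(TrainTuple other) {
        // smaller distance compares as "greater", so it survives in the heap
        return Double.compare(other.distance, this.distance);
    }
}

The reversed compareTo matters because MinHeapPriorityQueue retains the elements that compare as largest; flipping the comparison on distance makes those the k nearest neighbors.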
The complete code is available at: https://github.com/yl897958450/knn
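For a quick end-to-end check, here is a hypothetical driver (not part of the original post) using a tiny made-up 2-D data set, where the last value of each training row is the class label encoded as a double:

import java.util.Arrays;
import java.util.List;

public class KNNDemo {
    public static void main(String[] args) {
        // Training rows: two features followed by the class label (0.0 or 1.0).
        List<List<Double>> datas = Arrays.asList(
                Arrays.asList(1.0, 1.1, 0.0),
                Arrays.asList(1.2, 0.9, 0.0),
                Arrays.asList(0.9, 1.0, 0.0),
                Arrays.asList(5.0, 5.2, 1.0),
                Arrays.asList(5.1, 4.9, 1.0),
                Arrays.asList(4.8, 5.0, 1.0));

        // Test sample: only the two features, no label.
        List<Double> testData = Arrays.asList(1.1, 1.0);

        String predicted = new KNN().knn(datas, testData, 3);
        System.out.println("predicted class: " + predicted); // expected: 0.0
    }
}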
Please credit the source when reposting.
Original post: http://www.cnblogs.com/ylcoder/p/6285006.html