
kNN Classify Handwritten Digits

kNN Intuition

kNN is a common nonparametric learning algorithm, and the intuition behind it is simple. For every unclassified test point, find its k nearest neighbors in the training dataset. Then predict the class of the test point from the classes of those k neighbors. In short, kNN predicts from geometric proximity.
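To make the intuition concrete, here is a minimal brute-force sketch in Python. It measures the Euclidean distance from the test point to every training point, keeps the k closest, and takes a majority vote. The toy points, the labels and the function name `knn_predict` are made up for illustration. This is not the classifier used later in the post.

```python
import math
from collections import Counter


def knn_predict(train_points, train_labels, test_point, k):
    """Majority vote among the k training points closest to test_point
    (Euclidean distance, brute-force linear scan)."""
    # Pair each training point's distance to the test point with its label.
    dists = [(math.dist(p, test_point), label)
             for p, label in zip(train_points, train_labels)]
    # Keep the k closest training points.
    nearest = sorted(dists, key=lambda t: t[0])[:k]
    # Count the neighbors' labels and return the most common one.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


train = [(0, 0), (0, 1), (1, 0), (5, 5), (5, 6), (6, 5)]
labels = ['a', 'a', 'a', 'b', 'b', 'b']
print(knn_predict(train, labels, (0.5, 0.5), k=3))  # a
print(knn_predict(train, labels, (5.5, 5.2), k=3))  # b
```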

kNN Distance Metric

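The most widely used distance metric is the $L_p$ distance. The distance between $x_i$ and $x_j$ is:

$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{\frac{1}{p}}$$

in which $x_i$ and $x_j$ both have $n$ dimensions and $x_i^{(l)}$ is the $l$th component of $x_i$. Specifically, when $p = 2$, it’s the Euclidean distance; when $p = 1$, it’s the Manhattan distance.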
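A direct translation of the $L_p$ formula into Python might look like the following sketch. The function name `lp_distance` is hypothetical. The same pair of points gets different distances for $p = 1$ and $p = 2$.

```python
def lp_distance(x_i, x_j, p=2):
    """L_p distance between two n-dimensional points (p >= 1)."""
    return sum(abs(a - b) ** p for a, b in zip(x_i, x_j)) ** (1.0 / p)


a, b = (0, 0), (3, 4)
print(lp_distance(a, b, p=1))  # 7.0 (Manhattan)
print(lp_distance(a, b, p=2))  # 5.0 (Euclidean)
```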

Because the $L_p$ distance treats every axis the same way and ignores what each axis means, it is not suitable for some datasets. For example, suppose the first axis of a data point has meaning $A$, which ranges from 100 to 1000, while the second axis has meaning $B$, which ranges from 0 to 1. The distance between two points is then almost entirely determined by the first axis, which makes the metric less reasonable. To resolve this issue, we may use the Mahalanobis distance as the distance metric. But why does the Mahalanobis distance work?
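Before answering, here is a quick numeric check of the scale problem, using hypothetical values for axes $A$ and $B$. The two points differ across the entire range of $B$, yet that axis contributes almost nothing to the Euclidean distance.

```python
import math

# Hypothetical points: axis A ranges over [100, 1000], axis B over [0, 1].
x = (200.0, 0.0)
y = (700.0, 1.0)

sq_a = (x[0] - y[0]) ** 2  # squared difference on axis A: 250000.0
sq_b = (x[1] - y[1]) ** 2  # squared difference on axis B: 1.0
print(math.sqrt(sq_a + sq_b))  # about 500.001
print(sq_b / (sq_a + sq_b))    # axis B's share of the squared distance, about 4e-06
```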

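The Mahalanobis distance between $x_i$ and $x_j$ is:

$$d(x_i, x_j) = \sqrt{(x_i - x_j)^T C^{-1} (x_i - x_j)}$$

in which $C$ is the covariance matrix of the dataset. For $m$ samples $x_1, \ldots, x_m$, its entries are $C_{ab} = \frac{1}{m-1}\sum_{s=1}^{m}\left(x_s^{(a)} - \bar{X}^{(a)}\right)\left(x_s^{(b)} - \bar{X}^{(b)}\right)$, where $\bar{X}^{(a)}$ is the mean of the $a$th component over all samples.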

It’s quite abstract. However, its formula resembles that of an ellipse. A typical ellipse (an ellipsoid in $n$ dimensions) can be written as:

$$(x - y)^T A^{-1} (x - y) = 1$$

Here $A$ is any positive definite matrix, and $y = (y_1, y_2, \ldots, y_n)$ is the center of the ellipse. The locus of $x = (x_1, x_2, \ldots, x_n)$ then traces out the ellipse. In particular, if the components $x_i, x_j \;(i \neq j)$ are uncorrelated, $A$ is a diagonal matrix, like $A = diag(a_1^2, a_2^2, \ldots, a_n^2)$. Then we get the familiar form of the ellipse:

$$\frac{(x_1 - y_1)^2}{a_1^2} + \frac{(x_2 - y_2)^2}{a_2^2} + \cdots + \frac{(x_n - y_n)^2}{a_n^2} = 1$$

From this geometric picture, we know that when $A = C$, all points on the ellipse have the same Mahalanobis distance (namely 1) from its center $y$. To some extent, the metric flexibly accounts for the different attributes, since the values of $a_1, a_2, \ldots, a_n$ can be tuned as required. More generally, the covariance matrix $C$ is a real symmetric matrix, so it can be decomposed as

$$C = Q \Lambda Q^T$$

where $Q$ is an orthogonal matrix, $Q^{-1} = Q^{T}$, and $\Lambda = diag(\lambda_1, \lambda_2, \ldots, \lambda_n)$ is a diagonal matrix whose entries are the eigenvalues of $C$. Since $C^{-1} = Q \Lambda^{-1} Q^T$, substituting $C$ into $d(x_i, x_j)$ gives

$$d(x_i, x_j) = \sqrt{(x_i - x_j)^T Q \Lambda^{-1} Q^T (x_i - x_j)} = \sqrt{\left(Q^T x_i - Q^T x_j\right)^T \Lambda^{-1} \left(Q^T x_i - Q^T x_j\right)}$$

As we can see, $d(x_i, x_j)$ now has exactly the form of the axis-aligned ellipse above, with $a_l^2 = \lambda_l$, applied after rotating all points by $Q^T$. Thus, the analogy between the ellipse and the Mahalanobis distance holds.
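This equivalence can be checked numerically. The following sketch assumes NumPy and uses synthetic correlated data. It computes the Mahalanobis distance once directly from $C^{-1}$ and once through the eigendecomposition $C = Q \Lambda Q^T$: it rotates by $Q^T$, divides each coordinate by $\sqrt{\lambda_l}$ and takes the Euclidean norm. `np.cov(..., rowvar=False)` divides by the number of samples minus one, as in the covariance formula above. The function names are illustrative, not from the original code. The sketch uses `np.linalg.inv`, which requires $C$ to be invertible. The classifier code later in the post uses `np.linalg.pinv` instead.

```python
import numpy as np


def mahalanobis(x_i, x_j, cov):
    """sqrt((x_i - x_j)^T C^{-1} (x_i - x_j)); C must be invertible."""
    diff = np.asarray(x_i, dtype=float) - np.asarray(x_j, dtype=float)
    return float(np.sqrt(diff @ np.linalg.inv(cov) @ diff))


def mahalanobis_eig(x_i, x_j, cov):
    """Same distance via C = Q Lambda Q^T: rotate the difference by Q^T,
    divide coordinate l by sqrt(lambda_l), then take the Euclidean norm."""
    eigvals, Q = np.linalg.eigh(cov)  # eigh handles real symmetric matrices
    diff = np.asarray(x_i, dtype=float) - np.asarray(x_j, dtype=float)
    z = (Q.T @ diff) / np.sqrt(eigvals)
    return float(np.linalg.norm(z))


# Synthetic, correlated 2-D data whose two axes have very different scales.
rng = np.random.default_rng(0)
data = rng.normal(size=(500, 2)) @ np.array([[300.0, 0.3], [100.0, 0.4]])
C = np.cov(data, rowvar=False)  # sample covariance, normalized by (samples - 1)

d_direct = mahalanobis(data[0], data[1], C)
d_eig = mahalanobis_eig(data[0], data[1], C)
print(d_direct, d_eig)  # equal up to floating-point rounding
```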

Classification Rules

To use kNN in a classification task, we need to select a proper value of $k$ and define a rule (i.e., a classification function) that says how to use the k nearest neighbors. The distance metric and the classification function determine how sensitive the classifier is to $k$ on a given dataset. Usually, we take the majority class of the k nearest neighbors as the prediction. But it’s also reasonable to weight neighbors close to the test point more heavily than those farther away. This second strategy is called distance-weighted kNN (wkNN). Mathematically, the first strategy can be written as:

$$y = \arg\max_{c} \sum_{i=1}^{k} I(C_i = c)$$

in which $C_i$ is the class of the $i$th nearest neighbor, and $I(\cdot)$ is the indicator function, which equals 1 when its argument is true and 0 otherwise.

To describe the second strategy, we need a weighting function. Here is an example:

$$w_j = \begin{cases} \dfrac{d_k - d_j}{d_k - d_1}, & d_k \neq d_1 \\ 1, & d_k = d_1 \end{cases}$$

in which $d_j$ is the distance from the test point to its $j$th nearest neighbor, so $d_1 \le d_2 \le \cdots \le d_k$. In this way, each of the k nearest neighbors gets a weight $w_j \in [0, 1]$. The predicted class is the one with the largest total weight, $\arg\max_{c} \sum_{j=1}^{k} w_j I(C_j = c)$. Some studies have used Monte Carlo simulation to estimate the probability of error for different values of $k$. Their results showed that wkNN is less sensitive to $k$ than basic kNN, which makes choosing $k$ easier. You can select a fairly large value of $k$ without fear that the probability of error will be significantly greater than the error achievable with the optimum $k$.
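The following Python sketch contrasts the two rules on the same five neighbors. It assumes the weighting $w_j = (d_k - d_j)/(d_k - d_1)$, with all weights set to 1 when $d_k = d_1$. The labels and distances are made-up values chosen so that the plain majority vote and the weighted vote disagree. The function names are hypothetical.

```python
from collections import defaultdict


def distance_weights(distances):
    """w_j = (d_k - d_j) / (d_k - d_1) for distances sorted ascending;
    all weights are 1 when d_k == d_1."""
    d_1, d_k = distances[0], distances[-1]
    if d_k == d_1:
        return [1.0] * len(distances)
    return [(d_k - d_j) / (d_k - d_1) for d_j in distances]


def majority_vote(labels):
    """Plain kNN: each neighbor casts one vote."""
    counts = defaultdict(int)
    for c in labels:
        counts[c] += 1
    return max(counts, key=counts.get)


def weighted_vote(labels, distances):
    """wkNN: each neighbor votes with its weight w_j."""
    scores = defaultdict(float)
    for c, w in zip(labels, distance_weights(distances)):
        scores[c] += w
    return max(scores, key=scores.get)


# Five neighbors sorted by distance: one very close 'a', three farther 'b's.
labels = ['a', 'b', 'b', 'b', 'a']
distances = [0.5, 2.0, 2.2, 2.4, 2.5]
print(majority_vote(labels))             # b (3 votes against 2)
print(weighted_vote(labels, distances))  # a (weight 1.0 against about 0.45)
```

The farthest neighbor always gets weight 0 under this rule, so with $k$ neighbors effectively only $k - 1$ of them vote unless all distances tie.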

kd-tree Based kNN

A linear-scan kNN has to scan the entire training dataset for every test point, which is costly. One optimization is kd-tree based kNN. At a high level, a kd-tree is a generalization of a binary search tree that stores points in k-dimensional space. To generate the kd-tree, I designed a Python class KDNode, which represents a binary search tree by its root node, and a KDTree class, which encapsulates the kd-tree operations. The construction is similar to that of a binary search tree, but at each level the kd-tree compares only the $i$th component of the data points: at the $n$th level of a tree whose points have dimension $k$, $i = n \bmod k$. For details, please check the following code snippets.

class KDNode:
    """A node in a kd-tree.
    A tree is represented by its root, and every node represents its subtree."""

    def __init__(self, data=None, left=None, right=None):
        self.data = data  # store the index of this KDNode in dataset
        self.left = left
        self.right = right

    def set_data(self, data):
        self.data = data

    def set_left(self, left):
        self.left = left

    def set_right(self, right):
        self.right = right

    def preorder(self):
        """Iterate over the nodes in preorder: root, left, right."""
        yield self

        if self.left is not None:
            yield from self.left.preorder()

        if self.right is not None:
            yield from self.right.preorder()

    def __str__(self):
        printer = 'root, left, right pairs are:\n'
        for x in self.preorder():
            # empty placeholder nodes (data is None) are skipped
            if x.data is not None:
                printer += 'index: ' + str(x.data)
            if x.left and x.left.data is not None:
                printer += ' left: ' + str(x.left.data)
            if x.right and x.right.data is not None:
                printer += ' right: ' + str(x.right.data)
            if x.data is not None:
                printer += '\n'
        return printer

To leverage the power of the kd-tree, I implemented a search algorithm that finds the top k nearest neighbors. For low-dimensional data it takes about $O(\log N)$ time on average, although in the worst case it may still visit every node. Suppose that we have a guess for the nearest neighbor. To check it, we draw a hypersphere in $n$-dimensional space, centered at the test point with the distance to the guess as its radius, and check whether any other points lie inside it. If so, the guess isn’t the nearest neighbor; otherwise, the guess is indeed the nearest one. However, it’s hard to find the points inside a given hypersphere directly. Following the divide-and-conquer philosophy, we consider the two partitions on either side of the guess point separately. The two partitions are split by the hyperplane through the guess point in the kd-tree. As illustrated in the following figure, if the hyperplane crosses the hypersphere, the real nearest neighbor can be in either partition. Otherwise, it can only be in the partition that contains the test point. According to this geometric intuition, we can prune the search space whenever possible. In particular, consider a kd-tree node holding point $(a_0, a_1, \ldots, a_k)$ that splits on the $i$th axis, and a hypersphere of radius $r$ centered at $(b_0, b_1, \ldots, b_k)$. The hyperplane crosses the hypersphere only if $|b_i - a_i| < r$.

[Figure: the splitting hyperplane crossing the search hypersphere in a kd-tree]

Moreover, to maintain the top k nearest neighbors, we use a bounded priority queue (BPQ), which keeps the k points with the smallest distances to the test point.
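The post does not show the BPQ class itself. Below is a hypothetical sketch that matches how KDTree uses it: `BPQ(k)`, `add(item, priority)`, `is_full()`, `max_priority()`, and a `queue` attribute holding `(item, priority)` pairs in ascending order of priority. It keeps a sorted list and uses `bisect`; a heap would work as well.

```python
import bisect


class BPQ:
    """Bounded priority queue keeping the k items with the smallest priority
    (here: distance). `queue` holds (item, priority) pairs, ascending."""

    def __init__(self, k):
        self.k = k
        self.queue = []        # (item, priority) pairs sorted by priority
        self._priorities = []  # the priorities alone, kept parallel for bisect

    def add(self, item, priority):
        # A full queue only accepts items that beat its current worst entry.
        if self.is_full() and priority >= self.max_priority():
            return
        pos = bisect.bisect_right(self._priorities, priority)
        self._priorities.insert(pos, priority)
        self.queue.insert(pos, (item, priority))
        if len(self.queue) > self.k:  # over capacity: drop the worst entry
            self._priorities.pop()
            self.queue.pop()

    def is_full(self):
        return len(self.queue) >= self.k

    def max_priority(self):
        """Largest stored priority, or infinity while the queue is empty."""
        return self._priorities[-1] if self._priorities else float('inf')


bpq = BPQ(3)
for item, dist in [('a', 5.0), ('b', 1.0), ('c', 3.0), ('d', 4.0), ('e', 0.5)]:
    bpq.add(item, dist)
print(bpq.queue)           # [('e', 0.5), ('b', 1.0), ('c', 3.0)]
print(bpq.max_priority())  # 3.0
```

Returning infinity from `max_priority()` on an empty queue is a choice made here, so that a pruning test such as `plane_dist < results.max_priority()` never fails on an empty queue.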

Algorithm: search k-nn
Input: kd-tree, the test point P = (x0, x1, ..., xk)
Output: the BPQ which maintains the k nearest neighbors and their distances

starting at the root, execute the following procedure on the current node curr, which splits on axis i:
    if curr == NULL:
        return

    /* Add the current node to the BPQ */
    enqueue curr into bpq with priority distance(curr, P)

    /* Recursively search the half of the tree that contains the test point */
    if x_i < curr_i:
        recursively search the left subtree on the next axis
    else:
        recursively search the right subtree on the next axis

    /* If the hyperplane crosses the hypersphere, */
    /* look on the other side of the plane by examining the other subtree */
    if bpq isn't full or |curr_i - x_i| is less than the priority of the max-priority element of bpq:
        recursively search the other subtree on the next axis

return bpq
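For readers who want to run the procedure end to end without the helper classes, here is a compact, self-contained Python sketch of the same idea. It stores nodes as tuples instead of KDNode objects and uses a `heapq` max-heap (negated distances) in place of the BPQ. These substitutions and the function names are choices made for illustration. The median split and the pruning test $|x_i - curr_i| < r$ follow the description above.

```python
import heapq
import math
import random


def build_kdtree(points, indices=None, depth=0):
    """Build a kd-tree over `points`; each node is a tuple
    (point_index, split_axis, left_subtree, right_subtree)."""
    if indices is None:
        indices = list(range(len(points)))
    if not indices:
        return None
    axis = depth % len(points[0])  # cycle through the axes: i = level mod dimension
    indices = sorted(indices, key=lambda i: points[i][axis])
    mid = (len(indices) - 1) // 2  # median point becomes the splitting node
    return (indices[mid], axis,
            build_kdtree(points, indices[:mid], depth + 1),
            build_kdtree(points, indices[mid + 1:], depth + 1))


def search_knn(tree, points, target, k):
    """Return the k nearest points as (distance, index) pairs, nearest first."""
    heap = []  # max-heap of the best k so far, stored as (-distance, index)

    def visit(node):
        if node is None:
            return
        idx, axis, left, right = node
        d = math.dist(points[idx], target)
        if len(heap) < k:
            heapq.heappush(heap, (-d, idx))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, idx))  # evict the current farthest
        # First search the half-space that contains the target.
        diff = target[axis] - points[idx][axis]
        near, far = (left, right) if diff < 0 else (right, left)
        visit(near)
        # Search the other half only if the heap is not full yet or the
        # splitting hyperplane crosses the hypersphere of radius r = -heap[0][0].
        if len(heap) < k or abs(diff) < -heap[0][0]:
            visit(far)

    visit(tree)
    return sorted((-neg_d, idx) for neg_d, idx in heap)


# Cross-check against a brute-force linear scan on random 3-D points.
random.seed(1)
pts = [(random.random(), random.random(), random.random()) for _ in range(200)]
query = (0.3, 0.6, 0.2)
tree = build_kdtree(pts)
got = search_knn(tree, pts, query, k=5)
brute = sorted((math.dist(p, query), i) for i, p in enumerate(pts))[:5]
print([d for d, _ in got] == [d for d, _ in brute])  # True
```

Pruning is safe because every point in the far subtree lies at least $|x_i - curr_i|$ away from the target along axis $i$, so it can never beat a current radius that is no larger than that.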

The whole KDTree class implementation:

import math

import numpy as np

# BPQ, EUC_DM and distance() are helpers from the rest of the original
# code; they are not shown in this post.


class KDTree:
    def __init__(self, dataset=None):
        self.dataset = dataset
        self.dim_size = 1
        self.root = KDNode()

    def generate_by(self, dataset):
        """Generate kd-tree by given dataset,
        store the kd-tree in self.root"""
        self.dataset = dataset
        dataset_size = dataset.shape[0]

        self.dim_size = dataset.shape[1]
        # start with every data entry as a candidate for the tree
        selected_indices = np.arange(dataset_size)
        self._rec_generate_by(self.root, selected_indices, 0)

    def _rec_generate_by(self, root, indices, level):
        """Recursively generate kd-tree"""
        if indices.shape[0] == 0:
            return  # leave an empty placeholder node (data is None)

        # slice and get the kth dimension feature
        kth = level % self.dim_size
        selected_entries = self.dataset[indices][:, kth]

        # relative sorted indices to absolute indices; the median becomes this node
        rel_sorted_indices = selected_entries.argsort()
        mid = math.ceil(rel_sorted_indices.shape[0] / 2.0) - 1
        mid_idx = indices[rel_sorted_indices[mid]]
        root.data = mid_idx

        # left indices
        left_indices = indices[rel_sorted_indices[:mid]]
        root.left = KDNode()
        self._rec_generate_by(root.left, left_indices, level + 1)

        # right indices
        right_indices = indices[rel_sorted_indices[mid+1:]]
        root.right = KDNode()
        self._rec_generate_by(root.right, right_indices, level + 1)

    def search_knn(self, point, k=1, dmetric=EUC_DM):
        """Return the k nearest neighbors of given points
        and their distances"""
        results = BPQ(k)
        self._search_node(self.root, point, k, results, 0, dmetric)
        sorted_indices = [i[0] for i in results.queue]
        distances = [i[1] for i in results.queue]
        return sorted_indices, distances

    def _search_node(self, curr, point, k, results, level, dmetric):
        if curr.data is None:
            return

        dist = self._get_distance(curr, point, dmetric)
        results.add(curr.data, dist)

        # get the splitting plane
        kth = level % self.dim_size
        split_plane = self.dataset[curr.data][kth]

        # get the distance between the point and the splitting plane
        plane_dist = abs(point[kth] - split_plane)

        # search the side of the splitting plane that the point is in
        if point[kth] < split_plane:
            near, far = curr.left, curr.right
        else:
            near, far = curr.right, curr.left
        if near is not None:
            self._search_node(near, point, k, results, level + 1, dmetric)

        # search the other side of the splitting plane if
        # the splitting plane crosses the hypersphere
        if not results.is_full() or plane_dist < results.max_priority():
            if far is not None:
                self._search_node(far, point, k, results, level + 1, dmetric)

    def _get_distance(self, node, point, dmetric):
        node_idx = node.data
        node_data = self.dataset[node_idx]
        return distance(node_data, point, dmetric)

Handwritten Digits Classifier

This handwritten digits dataset is from Machine Learning in Action. You can download it from here. Each digit image is stored as a text file holding a 32×32 matrix of 0s and 1s, and each image is labeled by its filename: the digit before the underscore is the label, as in the following example:

[Figure: an example digit image]

Read the dataset into a large matrix:

import os
import numpy as np

def img2vector(filename):
    """convert a 32x32 0/1 text image file to a 1x1024 feature vector"""
    feature_vect = np.zeros((1, 32 * 32))
    with open(filename) as fr:  # close the file even if parsing fails
        for i in range(32):
            line_str = fr.readline()
            for j in range(32):
                feature_vect[0, 32*i+j] = int(line_str[j])
    return feature_vect

def read_data():
    """read training data and test data from folder `digits/`,
    return a tuple (training-data, test-data)"""
    folder = 'digits'
    training_folder = os.path.join(folder, 'trainingDigits')
    test_folder = os.path.join(folder, 'testDigits')

    training_imgs = os.listdir(training_folder)
    test_imgs = os.listdir(test_folder)

    training_data = np.zeros((len(training_imgs), 32 * 32 + 1))
    test_data = np.zeros((len(test_imgs), 32 * 32 + 1))

    # read training data from training folder
    for k, v in enumerate(training_imgs):
        img = os.path.join(training_folder, v)
        feature_vect = img2vector(img)
        label = np.array([[int(v.split('_')[0])]])
        # data item format: label, feature_vect
        training_data[k, :] = np.concatenate((label, feature_vect), axis=1)

    # read test data from test folder
    for k, v in enumerate(test_imgs):
        img = os.path.join(test_folder, v)
        feature_vect = img2vector(img)
        label = np.array([[int(v.split('_')[0])]])
        test_data[k, :] = np.concatenate((label, feature_vect), axis=1)

    return training_data, test_data

kNN classifier:

def knn_classify(in_X, dataset, labels, k, dmetric=EUC_DM):
    """kNN classification through linearly scanning,
    kNN method 1"""
    if dmetric == EUC_DM:
        sorted_indices, _ = euclidean_metric(in_X, dataset)
    elif dmetric == MAH_DM:
        # cache the (pseudo-)inverse covariance matrix across calls
        if 'cov_invmat' not in knn_cache:
            print('computing covariance matrix...')
            cov = cov_mat(dataset)
            print('get cov!')
            knn_cache['cov_invmat'] = np.linalg.pinv(cov)

        # the remaining arguments of this call were cut off in the original post
        sorted_indices, _ = mahalanobis_metric(in_X, ...)

    # find the majority class
    class_count = {}
    for i in range(k):
        vote_label = labels[sorted_indices[i]]
        class_count[vote_label] = class_count.get(vote_label, 0) + 1

    sorted_class_count = sorted(class_count.items(),
                                key=lambda x: x[1], reverse=True)
    return sorted_class_count[0][0]

kd-tree based kNN classifier:

def kd_knn_classify(in_X, dataset, labels, k, dmetric=EUC_DM):
    """kNN classification using kd-tree,
    kNN method 3"""
    # the tree must be built before it can be searched; building it once
    # and reusing it across test points would avoid rebuilding it per call
    kdtree = KDTree()
    kdtree.generate_by(dataset)
    indices, distances = kdtree.search_knn(in_X, k, dmetric)

    # find the majority class
    class_count = {}
    for i in range(k):
        vote_label = labels[indices[i]]
        class_count[vote_label] = class_count.get(vote_label, 0) + 1

    sorted_class_count = sorted(class_count.items(),
                                key=lambda x: x[1], reverse=True)
    return sorted_class_count[0][0]

After implementing the kNN classifier, I tested it on the test dataset with $k = 3$ and got 98.78% accuracy. Pretty remarkable!
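The post does not show how the accuracy was computed. One possible harness is sketched below with NumPy. It splits the arrays returned by read_data() into features (columns 1 onward) and labels (column 0) and counts correct predictions. `evaluate` is a hypothetical name, and any classifier with the `(in_X, dataset, labels, k)` signature used above can be passed in. The toy 1-NN classifier and arrays at the bottom only demonstrate the call. They are not the digit data.

```python
import numpy as np


def evaluate(classify, training_data, test_data, k=3):
    """Accuracy of `classify` on test_data. Both arrays follow the read_data()
    layout: column 0 is the label, the remaining columns are pixel features."""
    train_X, train_y = training_data[:, 1:], training_data[:, 0]
    test_X, test_y = test_data[:, 1:], test_data[:, 0]
    predictions = np.array([classify(x, train_X, train_y, k) for x in test_X])
    return float(np.mean(predictions == test_y))


# Toy demonstration with a hypothetical 1-NN classifier and tiny 2-feature arrays.
def nearest_neighbor(in_X, dataset, labels, k):
    return labels[np.argmin(np.linalg.norm(dataset - in_X, axis=1))]


train = np.array([[0, 0.0, 0.0], [1, 5.0, 5.0]])
test = np.array([[0, 0.1, 0.0], [1, 4.0, 5.0], [0, 5.0, 5.0]])
print(evaluate(nearest_neighbor, train, test))  # 0.666... (2 of 3 correct)
```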