CommonLounge Archive

Hands-on Project: Digit classification with K-Nearest Neighbors and Data Augmentation

April 17, 2018

In this hands-on project, we’ll apply the K-Nearest Neighbors (KNN) algorithm to handwritten digit classification. Our main objectives are: a) to learn how to experiment with various hyperparameters, b) to introduce the classification accuracy and confusion matrix metrics, c) to develop intuition about how KNN works, and d) to use this intuition and data augmentation to improve classification accuracy further.


This assignment guides you through using KNN for handwritten digit classification. Once you complete it, take the quiz at the bottom, which asks for the results you obtained. The last few quiz questions also walk you through some data augmentation techniques that you can use to further improve the accuracy of your KNN model.

Project Template on Google Colaboratory

Notebook Link

Work on this project directly in-browser via Google Colaboratory. The link above is a starter template that you can save to your own Google Drive and work on. Google Colab is a free tool that lets you run small Machine Learning projects through your web browser. You should read this 1 min tutorial if you’re unfamiliar with Google Colaboratory. Note that, for this project, you’ll have to upload the dataset linked below to Google Colab after saving the notebook to your own system.


The MNIST dataset is widely used in machine learning for the handwritten digit recognition task. Here are some sample images from the dataset.

Samples from MNIST hand-written digit dataset (16 samples are shown for each label)

We’ll work with a smaller subset of the dataset. You can access it at the following links: mnist_10000.pkl.gz and mnist_1000.pkl.gz. The first one consists of 10,000 training samples (plus 2,000 validation and 2,000 test samples), and the second one consists of 1,000 training samples (plus 200 validation and 200 test samples).

Loading the dataset

You can load the dataset with the following code. We are going to use Python 3.6 for this project; the pickle and gzip modules are part of its standard library. We hope you’ve installed numpy in earlier exercises. If not, run the following in the terminal: pip3 install numpy

import pickle, gzip
import numpy as np

# Open the gzip-compressed pickle and unpack the six arrays it contains
with gzip.open('mnist_10000.pkl.gz', 'rb') as f:
    trainData, trainLabels, valData, valLabels, testData, testLabels = pickle.load(f, encoding='latin1')

print("training data points: {}".format(len(trainLabels)))
print("validation data points: {}".format(len(valLabels)))
print("testing data points: {}".format(len(testLabels)))

trainData is a NumPy array with shape (10000, 784). Each row is one data point: an array of 784 values holding the pixels of a 28 x 28 image, arranged row by row. A pixel value of 0.0 denotes white (background), a value of 1.0 denotes black (foreground), and values in between denote intermediate intensities.
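To make the row-by-row layout concrete, here is a small NumPy sketch that uses a hypothetical synthetic row rather than real MNIST data. The pixel at row r, column c of the image sits at index r * 28 + c of the flat array, and reshape((28, 28)) recovers the grid.

```python
import numpy as np

# Hypothetical stand-in for one row of trainData: 784 background pixels (0.0)
# and a single foreground pixel (1.0) at row 5, column 7 of the image.
r, c = 5, 7
row = np.zeros(784)
row[r * 28 + c] = 1.0  # row-major layout: pixel (r, c) lives at flat index r * 28 + c

# Reshaping (row-major by default in NumPy) recovers the 28 x 28 image.
image = row.reshape((28, 28))
print(image[r, c], image.sum())

# Flattening reverses the reshape and gives back the original row.
assert np.array_equal(image.reshape(-1), row)
```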

Looking at the images

You can use the following snippet to look at some specific images. You can install the OpenCV package by running the following:

pip3 install opencv-python

Here’s the code to see the training images:

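import cv2

image = trainData[0]
image = image.reshape((28, 28))  # 784 values -> 28 x 28 grid
cv2.imshow("Image", image)
cv2.waitKey(0)  # keep the window open until a key is pressed
cv2.destroyAllWindows()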

Note that OpenCV launches in a window separate from the terminal, and may take a few seconds to load up before you can see the first image in your dataset.
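If no display window is available (for example, when working on a remote machine or in a hosted notebook), a quick text rendering can stand in for cv2.imshow. This sketch uses a hypothetical helper, ascii_digit, that draws '#' for pixels above a threshold (foreground, since 1.0 is black) and '.' elsewhere; the 0.5 threshold is an arbitrary choice.

```python
import numpy as np

def ascii_digit(pixels, side=28, threshold=0.5):
    """Return a text rendering of one flattened image: '#' for ink, '.' otherwise."""
    image = np.asarray(pixels).reshape((side, side))
    return "\n".join(
        "".join("#" if value > threshold else "." for value in row)
        for row in image
    )
```

With the dataset loaded, print(ascii_digit(trainData[0])) prints the first training image as 28 lines of 28 characters.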

Choosing the best hyperparameters

Next, we will use the scikit-learn (sklearn) package’s KNeighborsClassifier implementation, which is well optimized, on the mnist_10000 dataset. The quiz will ask for the results of the following tasks. Use Euclidean distance for all tasks; this is KNeighborsClassifier’s default (metric='minkowski' with p=2). Before we get started, make sure scikit-learn is installed:

pip3 install scikit-learn

Now, let’s run K-nearest neighbors for several values of K and measure the validation accuracy for each:

Task 1: Try the following values of K, and note the classification accuracy on the validation data for each. K = 1, 3, 5, 9, 15, 25

from sklearn.neighbors import KNeighborsClassifier

for k in [1, 3, 5, 9, 15, 25]:
    # Fit on the training data, then measure accuracy on the validation data
    model = KNeighborsClassifier(n_neighbors=k).fit(trainData, trainLabels)
    score = model.score(valData, valLabels)
    print(k, score)
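Task 2 needs the best K from Task 1. One way to avoid copying numbers by hand, sketched here as an option rather than the project’s required approach, is to store each validation score in a dictionary keyed by K and pick the maximum. The helper name select_best_k is hypothetical, and breaking ties toward the smaller K is a design choice of this sketch.

```python
def select_best_k(val_scores):
    """Return the K with the highest validation accuracy.

    val_scores maps each K to its validation accuracy (for example, the value
    returned by model.score(valData, valLabels)). If several K values tie for
    the best score, the smallest K is returned.
    """
    if not val_scores:
        raise ValueError("val_scores is empty")
    best_score = max(val_scores.values())
    return min(k for k, score in val_scores.items() if score == best_score)
```

In the Task 1 loop, you could start from val_scores = {} and record val_scores[k] = score, then set best_k = select_best_k(val_scores) in Task 2.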

Task 2: For the best performing value of K, calculate and note the classification accuracy on the test data.

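best_k = ...  # fill in the K with the highest validation accuracy from Task 1
model = KNeighborsClassifier(n_neighbors=best_k).fit(trainData, trainLabels)
score = model.score(testData, testLabels)
print("test accuracy:", score)
predictions = model.predict(testData)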
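For a classifier, score reports the mean accuracy: the fraction of samples whose predicted label equals the true label. The sketch below computes the same quantity directly with NumPy (the function name accuracy is our own), so after Task 2, accuracy(testLabels, predictions) should match score.

```python
import numpy as np

def accuracy(true_labels, predicted_labels):
    """Fraction of samples whose predicted label equals the true label."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    return float(np.mean(true_labels == predicted_labels))

# Toy example (not MNIST results): 3 of the 4 predictions are correct.
print(accuracy([1, 2, 3, 4], [1, 2, 3, 3]))
```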

Task 3: Inspect the performance per class, i.e. the precision, recall and F1-score for each digit. (Hint: see sklearn.metrics.classification_report.)

from sklearn.metrics import classification_report
print(classification_report(testLabels, predictions))
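For a single digit d, precision is the fraction of samples predicted as d that really are d, recall is the fraction of true d samples that were predicted as d, and the F1-score is their harmonic mean. The hypothetical helper below computes these from the label arrays, which can help when reading classification_report; returning 0.0 when a ratio is undefined is a choice of this sketch.

```python
import numpy as np

def per_class_scores(true_labels, predicted_labels, label):
    """Precision, recall and F1-score for one class label."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    true_pos = np.sum((predicted_labels == label) & (true_labels == label))
    predicted_pos = np.sum(predicted_labels == label)  # every prediction of `label`
    actual_pos = np.sum(true_labels == label)          # every sample truly `label`
    precision = true_pos / predicted_pos if predicted_pos else 0.0
    recall = true_pos / actual_pos if actual_pos else 0.0
    if precision + recall == 0:
        return float(precision), float(recall), 0.0
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return float(precision), float(recall), float(f1)
```

For example, per_class_scores(testLabels, predictions, 3) should agree with the row for digit 3 in classification_report, up to the report’s rounding.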

Task 4: Inspect the confusion matrix, i.e. when the correct label was digit I, how many times did the model predict digit J? (Hint: see sklearn.metrics.confusion_matrix.)

from sklearn.metrics import confusion_matrix
print(confusion_matrix(testLabels, predictions))
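The confusion matrix is simply a table of counts. Following the convention of sklearn.metrics.confusion_matrix, row i is the true label and column j is the predicted label. This NumPy sketch (confusion_counts is a hypothetical name) assumes integer labels 0 to num_classes - 1; sklearn instead uses the sorted set of labels that appear in the data, which gives the same layout on MNIST when all ten digits are present.

```python
import numpy as np

def confusion_counts(true_labels, predicted_labels, num_classes=10):
    """Count matrix where entry [i, j] = samples with true label i predicted as j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at increments cm[true, pred] once per sample, even for repeated pairs
    np.add.at(cm, (np.asarray(true_labels), np.asarray(predicted_labels)), 1)
    return cm
```

The diagonal holds the correct predictions, so cm.trace() / cm.sum() is the accuracy, and large off-diagonal entries show which pairs of digits the model confuses.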


The full code for the solution is available here: Solution to Hands-on Project: Digit classification with K-Nearest Neighbors and Data Augmentation. We highly encourage you to look at it only if you’re stuck and cannot proceed further.

(Bonus) Implement KNN yourself!

If you’d like to practice implementing KNN yourself (not the main focus of this assignment), use the mnist_1000 file so that you don’t have to wait long for your code to run. You can use scikit-learn’s KNeighborsClassifier to check your implementation by comparing its predictions with those produced by your code.

The pseudocode for KNN is as follows:

  • Compute the distance between the current sample and every sample in the training data (use Euclidean distance).
  • Determine the closest K training samples (use K = 5).
  • Find the most common label among these K training samples. This label is the prediction.
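A minimal NumPy sketch of the pseudocode above, assuming the data arrays are shaped like trainData (one flattened image per row) and the labels are non-negative integers. It ranks neighbors by squared Euclidean distance, which orders them exactly as Euclidean distance does, so the square root is skipped. The name knn_predict is our own, and this is one possible implementation rather than the reference solution.

```python
import numpy as np

def knn_predict(train_data, train_labels, test_data, k=5):
    """Predict a label for each row of test_data by majority vote among its
    k nearest training rows (Euclidean distance)."""
    train_data = np.asarray(train_data, dtype=np.float64)
    test_data = np.asarray(test_data, dtype=np.float64)
    train_labels = np.asarray(train_labels)

    # Step 1: squared distances for all (test, train) pairs at once, using
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2; clip tiny negatives from rounding.
    sq_dists = (
        np.sum(test_data ** 2, axis=1)[:, None]
        - 2.0 * test_data @ train_data.T
        + np.sum(train_data ** 2, axis=1)[None, :]
    )
    np.maximum(sq_dists, 0.0, out=sq_dists)

    # Step 2: indices of the k smallest distances in each row (unordered).
    nearest = np.argpartition(sq_dists, kth=k - 1, axis=1)[:, :k]

    # Step 3: majority vote; assumes non-negative integer labels for bincount.
    predictions = np.empty(len(test_data), dtype=train_labels.dtype)
    for i, neighbor_idx in enumerate(nearest):
        votes = np.bincount(train_labels[neighbor_idx])
        predictions[i] = votes.argmax()  # on a tie, argmax picks the smallest label
    return predictions
```

To check it on mnist_1000, you could compute np.mean(knn_predict(trainData, trainLabels, testData, k=5) == model.predict(testData)), where model is a KNeighborsClassifier(n_neighbors=5) fitted on the same training data. An occasional mismatch may come from how ties between equally distant neighbors or equally common labels are broken.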

(Bonus) Improving accuracy further with Data Augmentation

One simple way to improve accuracy further is to try different distance metrics based on your intuition (see the KNeighborsClassifier documentation).
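KNeighborsClassifier’s default metric, 'minkowski' with p=2, is Euclidean distance; passing p=1 switches to Manhattan distance, for example KNeighborsClassifier(n_neighbors=best_k, p=1). The sketch below (minkowski_distance is our own helper) shows how the two metrics measure the same pair of points differently; whether either one helps on MNIST is something to check on the validation data.

```python
import numpy as np

def minkowski_distance(a, b, p=2):
    """Minkowski distance between two flattened images.

    p=2 is Euclidean (L2) distance; p=1 is Manhattan (L1) distance.
    """
    diff = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))
    return float(np.sum(diff ** p) ** (1.0 / p))

# Toy 2-pixel example: the same pair of points under two metrics.
print(minkowski_distance([0, 0], [3, 4], p=2))  # 5.0
print(minkowski_distance([0, 0], [3, 4], p=1))  # 7.0
```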

The quiz includes questions that guide you through improving the accuracy with data augmentation techniques. (Finish the above tasks before starting the quiz.)
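One common augmentation for digit images, sketched here as an illustration and not necessarily the technique the quiz uses, is to add copies of each training image shifted by one pixel up, down, left and right. Because the background value is 0.0, the pixels uncovered by a shift are filled with zeros. The function names are hypothetical. Augment only the training data and leave the validation and test sets unchanged; with the default shifts, the training set (and therefore KNN prediction time) grows five-fold.

```python
import numpy as np

def shift_images(data, dy, dx, side=28):
    """Shift every flattened side x side image by dy rows and dx columns.

    Positive dy moves content down, positive dx moves it right; pixels shifted
    in from the edge are filled with 0.0 (background).
    """
    images = np.asarray(data).reshape(-1, side, side)
    shifted = np.zeros_like(images)
    src_rows = slice(max(0, -dy), side - max(0, dy))
    dst_rows = slice(max(0, dy), side - max(0, -dy))
    src_cols = slice(max(0, -dx), side - max(0, dx))
    dst_cols = slice(max(0, dx), side - max(0, -dx))
    shifted[:, dst_rows, dst_cols] = images[:, src_rows, src_cols]
    return shifted.reshape(len(images), side * side)


def augment_with_shifts(data, labels, shifts=((1, 0), (-1, 0), (0, 1), (0, -1))):
    """Return the original data plus one shifted copy per (dy, dx) in shifts."""
    data_parts = [np.asarray(data)]
    label_parts = [np.asarray(labels)]
    for dy, dx in shifts:
        data_parts.append(shift_images(data, dy, dx))
        label_parts.append(np.asarray(labels))  # shifting does not change the digit
    return np.concatenate(data_parts), np.concatenate(label_parts)
```

For example, trainAug, trainLabelsAug = augment_with_shifts(trainData, trainLabels) produces a larger training set that can be passed to fit in place of the original arrays.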

Solution on Google Colaboratory

Notebook Link

The complete notebook with all the cells executed is available via Google Colaboratory using the link above. Google Colab is a free tool that lets you run small Machine Learning experiments through your browser. You should read this 1 min tutorial if you’re unfamiliar with Google Colaboratory. Note that, for this project, you’ll have to upload the dataset to Google Colab after saving the notebook to your own system.

© 2016-2022. All rights reserved.