
I tried to implement the k-means algorithm for the MNIST data set. But since the results are far from good, there might be a bug (or several) that I don't see at the moment. The code is really straightforward. Here is what I have so far:

import numpy as np

# Load images
I = np.load("mnist_test_images.npy").astype(float) # (10000,784)
L = np.load("mnist_test_labels.npy").astype(int) # (10000,1)

# Scale
I = 2.0*(I/255.0-0.5)

images = len(I)

# Random initialization of centers for k=10 clusters
M = np.random.randn(10,28*28)

guess = np.zeros((len(I),1))
step = 0
while (True):
    # Compute distance of every image i to the center of every cluster k
    # image i belongs to cluster with smallest distance
    for i in range(images):
        d = np.sum((M-I[i])**2,axis=1)
        guess[i] = np.argmin(d)

    # Update the centers for all clusters
    # New center is the mean of all images i which belong to cluster k
    for k in range(10):
        idx, _ = np.where(guess == k)
        if len(idx) > 0:
            M[k] = np.mean(I[idx],axis=0)

    # Test how good the algorithm works
    # Very similar to first step
    if (step % 10 == 0):
        fitness = 0
        for i in range(images):
            dist = np.sum((M-I[i])**2,axis=1)
            if L[i] == np.argmin(dist):
                fitness += 1
        print("%d" % fitness, flush=True)

    step += 1

The code looks really simple. But there is probably a bug somewhere. When I test it, the accuracy drops from about 10-20% down to 5-10%, or it converges almost instantly without ever reaching more than 30%. I cannot recognize any learning. Could the random initialization of the cluster centers cause this behavior?

Thank you!

1 Answer


The problem is that you are treating this as a supervised learning problem, but it is unsupervised. In my opinion the whole "unsupervised learning" terminology should be avoided, because it can be very misleading. In fact, I wouldn't call most "unsupervised" methods "learning" at all.

Clustering is not just "unsupervised classification". It is a very different and much, much harder task. The task is so difficult that we do not yet even know how to properly evaluate it.

In your case there are several issues:

  1. You assume that k-means will find the digits 0 to 9. Since it is unsupervised, it most likely will not. Instead, it may discover that there are slanted digits, different line widths, different kinds of ones, etc.
  2. You evaluate it assuming that cluster 0 corresponds to digit 0. It doesn't. The cluster labels are meaningless. MNIST is a bad choice here, because by coincidence its classes are also the digits 0 to 9. But k-means will always use labels 0 to k-1, even for apples vs. bananas.
  3. You assume that the evaluation must become better with each iteration. But this is unsupervised!
  4. A class may contain multiple clusters.
  5. Classes may be inseparable without labels, and thus form one cluster.
  6. Methods like k-means are sensitive to outliers. You probably have some very tiny clusters that just model a few bad data points.
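Point 2 is the crux for your score: since cluster ids are arbitrary, a common way to evaluate is to first map each cluster to the majority true label among its members, and only then compute accuracy. A minimal sketch (the helper name `cluster_accuracy` is mine, not from your code):

```python
import numpy as np

def cluster_accuracy(guess, labels, k=10):
    # Map each cluster id to the majority true label among its members,
    # then score accuracy under that mapping.
    guess = np.asarray(guess).ravel().astype(int)
    labels = np.asarray(labels).ravel().astype(int)
    mapping = np.zeros(k, dtype=int)
    for c in range(k):
        members = labels[guess == c]
        if len(members) > 0:
            mapping[c] = np.bincount(members).argmax()
    return float(np.mean(mapping[guess] == labels))

# Toy check: two clean clusters whose ids disagree with the labels.
print(cluster_accuracy([0, 0, 0, 1, 1, 1], [7, 7, 7, 3, 3, 3], k=2))  # 1.0
```

With your raw `L[i] == np.argmin(dist)` check, the same perfect clustering would score 0%, which matches the numbers you are seeing.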

3 Comments

Thank you for the answer! But is there any way to classify MNIST with the k-means algorithm? What would be an approach? If I can't use it for MNIST, for what data is k-means a proper choice?
It's not "wrong" to experiment with k-means on MNIST; it's just easily misleading, because the true classes happen to be the digits 0 to 9. It is easier to see why classification != clustering when the true classes are "red" and "blue".
But don't use it for classification.
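If you do want a classification-style number out of k-means, one standard trick (following the majority-label idea from the answer) is: fit k-means on training data, tag each center with the majority training label of its cluster, and classify new points by the nearest center's label. On MNIST you would typically use far more than 10 clusters, per point 4 above. A rough, self-contained sketch on toy 2-D data (all helper names are mine):

```python
import numpy as np

def kmeans_fit(X, k, iters=20, seed=0):
    # Plain k-means; centers initialized from k random data points.
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(iters):
        d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        assign = d.argmin(axis=1)
        for c in range(k):
            if np.any(assign == c):
                centers[c] = X[assign == c].mean(axis=0)
    return centers

def kmeans_classifier(X_train, y_train, k, **kw):
    # Tag each center with the majority training label of its cluster;
    # classify new points by the nearest center's label.
    centers = kmeans_fit(X_train, k, **kw)
    d = ((X_train[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    assign = d.argmin(axis=1)
    center_labels = np.array(
        [np.bincount(y_train[assign == c]).argmax() if np.any(assign == c) else 0
         for c in range(k)])
    def predict(X):
        d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return center_labels[d.argmin(axis=1)]
    return predict

# Two tiny, well-separated blobs; any distinct initialization converges here.
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
y = np.array([0, 0, 0, 1, 1, 1])
predict = kmeans_classifier(X, y, k=2)
print((predict(X) == y).mean())  # 1.0 on this toy data
```

This is still a weak classifier compared to supervised methods; it only works to the extent that the clusters happen to align with the classes.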
