K-means++ Algorithm - ML
Clustering is used to group similar data points. K-Means is a commonly used clustering method, but it often gives poor results because the initial cluster centers are chosen randomly. This may lead to empty clusters, overlapping clusters or centroids falling too close to each other.
For example, the outputs below show how K-Means can form incorrect clusters due to weak initialization. To fix this, K-Means++ was introduced. It improves the way initial centroids are selected so that the clustering becomes more stable, accurate and faster.
K-Means++ is an improved version of the K-Means algorithm. Instead of picking all centroids randomly, it chooses the first center randomly and then selects the remaining centers in a spaced-out manner. This ensures:
- Better cluster separation.
- Faster convergence.
- More consistent results compared to K-Means.
How K-mean++ Algorithm Works
K-means++ revolutionizes the initialization step while keeping the rest of the K-means algorithm intact. The key insight is deceptively simple: spread out the initial centers as much as possible.
The Initialization Process
The algorithm follows these steps:
1. First center: Choose the first cluster center uniformly at random from the data points
2. Subsequent centers: For each remaining center:
- Calculate the distance from each data point to its nearest existing center.
- Choose the next center with probability proportional to the square of this distance.
- Points farther from existing centers have a higher chance of being selected.
3. Standard K-means: Once all k centers are initialized, proceed with the standard K-means algorithm
The squared distance weighting is important it ensures that points far from existing centers are much more likely to be chosen, naturally spreading the centers across the data space.
Note: Although the initialization in K-means++ is computationally more expensive than the standard K-means algorithm, the run-time for convergence to optimum is drastically reduced for K-means++. This is because the centroids that are initially chosen are likely to lie in different clusters already.
Mathematical Foundation
Let's formalize the selection probability. When choosing the (i+1)-th center, the probability of selecting point x is:
P(x) = \frac{D(x)^2}{\sum_{x'} D(x')^2}
Where D(x) is the shortest distance from point x to any already-chosen center and the sum runs over all data points. his probability distribution is called D2-weighting and it's the fundamental principle that makes K-means++ work so well.
Implementation in Python
Let's understand how KMeans++ initializes centroids step by step using the following implementation:
1. Dataset Creation
Four separate Gaussian clusters are generated with different means and covariances to simulate different groupings in the data.
import numpy as np
import matplotlib.pyplot as plt
mean_01 = np.array([0.0, 0.0])
cov_01 = np.array([[1, 0.3], [0.3, 1]])
dist_01 = np.random.multivariate_normal(mean_01, cov_01, 100)
mean_02 = np.array([6.0, 7.0])
cov_02 = np.array([[1.5, 0.3], [0.3, 1]])
dist_02 = np.random.multivariate_normal(mean_02, cov_02, 100)
mean_03 = np.array([7.0, -5.0])
dist_03 = np.random.multivariate_normal(mean_03, cov_01, 100)
mean_04 = np.array([2.0, -7.0])
cov_04 = np.array([[1.2, 0.5], [0.5, 1.3]])
dist_04 = np.random.multivariate_normal(mean_04, cov_01, 100)
data = np.vstack((dist_01, dist_02, dist_03, dist_04))
np.random.shuffle(data)
2. Plotting Helper Function
This function is used to visualize the data points and the selected centroids at each step. All data points are shown in gray.
- Previously selected centroids are marked in black.
- The current centroid being added is marked in red.
- This helps visualize the centroid initialization process step by step.
def plot(data, centroids):
plt.scatter(data[:, 0], data[:, 1], marker='.', color='gray', label='Data Points')
if centroids.shape[0] > 1:
plt.scatter(centroids[:-1, 0], centroids[:-1, 1], color='black', label='Selected Centroids')
plt.scatter(centroids[-1, 0], centroids[-1, 1], color='red', label='Next Centroid')
plt.title(f'Select {centroids.shape[0]}th Centroid')
plt.legend()
plt.xlim(-5, 12)
plt.ylim(-10, 15)
plt.show()
3. Euclidean Distance Function
This is a standard formula to compute the distance between two vectors p1 and p2 in 2D space.
def distance(p1, p2):
return np.sqrt(np.sum((p1 - p2)**2))
4. K-Means++ Initialization
This function selects initial centroids using the K-Means++ strategy. The first centroid is chosen randomly from the dataset. For the next centroids:
- It calculates the distance of every point to its nearest existing centroid.
- Chooses the point farthest from the nearest centroid as the next centroid and ensures centroids are spaced far apart initially, giving better cluster separation.
import sys
def initialize(data, k):
centroids = []
centroids.append(data[np.random.randint(data.shape[0])])
plot(data, np.array(centroids))
for _ in range(k - 1):
distances = []
for point in data:
min_dist = min([distance(point, c) for c in centroids])
distances.append(min_dist)
next_centroid = data[np.argmax(distances)]
centroids.append(next_centroid)
plot(data, np.array(centroids))
return np.array(centroids)
# Run initialization
centroids = initialize(data, k=4)
Output:
It shows the dataset with the first randomly selected centroid (in red). No black points are visible since only one centroid is selected.
The second centroid is selected which is the farthest point from the first centroid. The first centroid becomes black and the new centroid is marked in red
The third centroid is selected. The two previously selected centroids are shown in black while the newly selected centroid is in red.
The final centroid is selected completing the initialization. Three previously selected centroids are in black and the last selected centroid is in red.
You can download complete code from here.
Applications of k-means++ algorithm
- Image segmentation: It can be used to segment images into different regions based on their color or texture features. This is useful in computer vision applications, such as object recognition or tracking.
- Customer segmentation: These are used to group customers into different segments based on their purchasing habits, demographic data or other characteristics. This is useful in marketing and advertising applications, as it can help businesses target their marketing efforts more effectively.
- Recommender systems: K-means++ can be used to recommend products or services to users based on their past purchases or preferences. This is useful in e-commerce and online advertising applications.