K-Nearest Neighbors (KNN)



What Is K-Nearest Neighbors?

K-Nearest Neighbors (KNN) is one of the most intuitive machine learning algorithms. Its prediction logic is elegantly simple: find the k training examples most similar to the new input, and let them vote on the answer.

No equations are learned. No parameters are fit. KNN is a lazy learner — it memorizes training data and defers all computation to prediction time.

Core Principle: "Birds of a feather flock together." — Similar inputs should produce similar outputs.


Architecture: How KNN Actually Works

┌─────────────────────────────────────────────────────────────────┐
│                    KNN: TRAINING PHASE                          │
│                                                                 │
│   Input: X_train (n×d matrix), y_train (n labels)              │
│                                                                 │
│   Action: STORE EVERYTHING. That's it.                          │
│   Cost: O(1) time, O(n×d) space                                 │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                  KNN: PREDICTION PHASE (for new point x)        │
│                                                                 │
│  Step 1 ─ MEASURE                                               │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │  Compute distance(x, xᵢ) for ALL n training points      │   │
│  │  → n distance computations, each O(d)                   │   │
│  │  Total: O(n × d)                                         │   │
│  └──────────────────────────────────────────────────────────┘   │
│           ↓                                                     │
│  Step 2 ─ SORT                                                  │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │  Sort distances, select the k smallest                  │   │
│  │  O(n log n) sort, or O(n + k log n) with heap           │   │
│  └──────────────────────────────────────────────────────────┘   │
│           ↓                                                     │
│  Step 3 ─ DECIDE                                                │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │  CLASSIFICATION: majority vote among k neighbors         │   │
│  │  REGRESSION:     mean (or weighted mean) of k neighbors  │   │
│  └──────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────┘
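Step 2 does not need a full sort. The sketch below uses random numbers as a stand-in for Step 1's distances and selects the same k indices in three ways. `heapq.nsmallest` keeps a heap of size k, so it runs in O(n log k). That is a slightly different heap strategy from the heapify-then-pop bound in the box. `np.argpartition` is a selection algorithm with O(n) average cost.

```python
import heapq

import numpy as np

rng = np.random.default_rng(42)
distances = rng.random(1_000)   # stand-in for the n distances from Step 1
k = 5

# Option A: full sort, O(n log n)
idx_sort = np.argsort(distances)[:k]

# Option B: bounded heap of size k, O(n log k)
idx_heap = np.array(heapq.nsmallest(k, range(len(distances)), key=distances.__getitem__))

# Option C: selection (introselect), O(n) on average; the k winners come back unordered
idx_part = np.argpartition(distances, k - 1)[:k]
idx_part = idx_part[np.argsort(distances[idx_part])]  # order the k winners, O(k log k)

print(idx_sort, idx_heap, idx_part)
```

All three return the same neighbors. For small n the difference is negligible, but selection avoids sorting the whole array when only k items are needed.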

Visual Decision Boundary

        Feature B ▲
                  │
             4    │    ■   ■
                  │  ■   ●?   ← New point. What class?
             3    │    ○   ○
                  │  ○
             2    │
                  └────────────────── Feature A ▶
                     2     3     4

Distance from ●? to each neighbor:
  d(●?, ■₁) = 1.0   ← nearest
  d(●?, ○₁) = 1.0   ← nearest
  d(●?, ■₂) = 1.4
  d(●?, ○₂) = 1.4
  d(●?, ■₃) = 2.0

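k=1 → ■₁ and ○₁ tie at 1.0                → result depends on tie-breaking
k=3 → ■₁, ○₁ + one of ■₂/○₂ (tied at 1.4)  → result depends on tie-breaking
k=5 → 5 nearest: ■,○,■,○,■                → 3 votes ■ → PREDICT ■

An odd k prevents vote ties in binary problems, but it does not prevent distance ties. scikit-learn warns that when neighbors are tied in distance but have different labels, the result depends on the order of the training data.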

■ = Class A    ○ = Class B
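You can check the Step 3 vote directly on the five distances listed above. This sketch tallies a plain majority vote and, for comparison, a distance-weighted vote (weight = 1/distance, the scheme scikit-learn calls `weights='distance'`). The labels "A" and "B" stand in for ■ and ○.

```python
from collections import Counter

# The five (distance, label) pairs listed above; "A" = ■ (Class A), "B" = ○ (Class B)
neighbors = [(1.0, "A"), (1.0, "B"), (1.4, "A"), (1.4, "B"), (2.0, "A")]

# Uniform vote: every neighbor counts once
uniform_votes = Counter(label for _, label in neighbors)

# Distance-weighted vote: each neighbor contributes 1 / distance
weighted_votes = Counter()
for dist, label in neighbors:
    weighted_votes[label] += 1.0 / dist

print(uniform_votes.most_common(1)[0][0], dict(uniform_votes))
print(weighted_votes.most_common(1)[0][0], {k: round(v, 3) for k, v in weighted_votes.items()})
```

With k=5, both schemes pick ■. Weighting only changes the outcome when a few very close neighbors disagree with a larger number of distant ones.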

Distance Metrics: Defining "Similarity"

The distance metric is the heart of KNN — it defines what "nearest" means.

Euclidean Distance (L2 Norm)

$$d_{euclidean}(p, q) = \sqrt{\sum_{i=1}^{d}(p_i - q_i)^2}$$

Example: p = (1, 2), q = (4, 6)
d = √((4-1)² + (6-2)²) = √(9 + 16) = √25 = 5.0

Geometric interpretation: straight-line distance
Best for: continuous features, low-dimensional spaces, isotropic data

Manhattan Distance (L1 Norm)

$$d_{manhattan}(p, q) = \sum_{i=1}^{d}|p_i - q_i|$$

Example: p = (1, 2), q = (4, 6)
d = |4-1| + |6-2| = 3 + 4 = 7.0

Geometric interpretation: city-block distance (movement only along the axes, like streets on a grid)
Best for: high-dimensional data, when outliers are present, sparse features

Minkowski Distance (Generalized)

$$d_{minkowski}(p, q) = \left(\sum_{i=1}^{d}|p_i - q_i|^r\right)^{1/r}$$

r = 1 → Manhattan distance
r = 2 → Euclidean distance
r → ∞ → Chebyshev (maximum difference across any dimension)
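Since each named metric is a special case of Minkowski, one small function covers all of them. Here is a minimal NumPy sketch that reproduces the worked examples above; the helper name `minkowski` is only for illustration:

```python
import numpy as np

def minkowski(p, q, r):
    """Minkowski distance of order r; r = np.inf gives the Chebyshev limit."""
    diff = np.abs(np.asarray(p, dtype=float) - np.asarray(q, dtype=float))
    if np.isinf(r):
        return float(diff.max())              # largest single-coordinate difference
    return float(np.sum(diff ** r) ** (1.0 / r))

p, q = (1, 2), (4, 6)
euclid = minkowski(p, q, 2)          # sqrt(3^2 + 4^2) = 5.0
manhattan = minkowski(p, q, 1)       # 3 + 4 = 7.0
chebyshev = minkowski(p, q, np.inf)  # max(3, 4) = 4.0
r3 = minkowski(p, q, 3)              # (27 + 64)^(1/3), about 4.498
print(euclid, manhattan, chebyshev, round(r3, 3))
```

As r grows, the largest coordinate difference dominates the sum. That is why the values shrink toward the Chebyshev distance of 4.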

Comparing Metrics Visually

The set of points at distance 3 from the origin (each metric's "circle") takes a different shape under each metric:

    Euclidean  → a circle of radius 3
    Manhattan  → a diamond with corners at (±3, 0) and (0, ±3)
    Chebyshev  → an axis-aligned square with corners at (±3, ±3)

Distance Metric Selection Guide

Data Type                         | Recommended Metric | Why
──────────────────────────────────┼────────────────────┼──────────────────────────────────
Continuous features, low d        | Euclidean          | Natural geometry
High-dimensional (> 50 features)  | Manhattan          | Less sensitive to outliers
Text / NLP embeddings             | Cosine similarity  | Direction matters, not magnitude
Binary / categorical              | Hamming            | Counts differing positions
Mixed types                       | Gower distance     | Handles mixed data natively
Images (pixel space)              | Euclidean          | Pixel-wise comparison
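To see why cosine distance suits embeddings, compare it with Euclidean distance on a toy query. The two helpers below are illustrative sketches, not library functions. In scikit-learn you would instead pass `metric='cosine'` or `metric='hamming'` to `KNeighborsClassifier`. Note that its `'hamming'` metric, like SciPy's, returns the fraction of differing positions rather than the count.

```python
import numpy as np

def cosine_distance(a, b):
    """1 - cosine similarity: 0 for the same direction, up to 2 for opposite."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return 1.0 - a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))

def hamming_count(a, b):
    """Number of positions where two equal-length vectors differ."""
    return int(np.sum(np.asarray(a) != np.asarray(b)))

query = np.array([1.0, 2.0])
same_direction = np.array([10.0, 20.0])  # 10x the magnitude, same direction
nearby_point = np.array([2.0, 1.0])      # close in space, different direction

# Euclidean distance prefers the nearby point...
euclid_same = np.linalg.norm(query - same_direction)  # about 20.12
euclid_near = np.linalg.norm(query - nearby_point)    # about 1.41
# ...while cosine distance prefers the point pointing the same way
cos_same = cosine_distance(query, same_direction)     # 0.0 (up to rounding)
cos_near = cosine_distance(query, nearby_point)       # 0.2

diff = hamming_count([1, 0, 1, 1, 0], [1, 1, 0, 1, 0])  # positions 1 and 2 differ -> 2
print(euclid_same, euclid_near, cos_same, cos_near, diff)
```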

Choosing K: The Most Important Hyperparameter

Effect of k on Decision Boundary:

k = 1 (overfit):           k = 7 (balanced):         k = 51 (underfit):
  ┌─────────────┐            ┌─────────────┐            ┌─────────────┐
  │●│■│●│■│●│■  │            │             │            │             │
  │■│●│■│●│■│●  │            │  ■  ■  ■    │            │             │
  │●│■│●│■│●│■  │            │     ■       │            │ ──── ────── │
  │            ■│            │─ ─ ─ ─ ─ ─ │            │             │
  └─────────────┘            │  ○  ○  ○   │            │  ○  ○  ○   │
 Jagged boundary             └─────────────┘            └─────────────┘
 Memorizes noise             Smooth boundary            Too smooth
 High Variance               Balanced                   High Bias

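How accuracy typically changes with k:
  Training accuracy: usually 1.0 at k=1 (each training point is its own nearest neighbor), then falls as k grows.
  Test accuracy: rises from k=1, peaks at an intermediate k (the sweet spot), then falls as the model underfits.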
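This pattern is easy to reproduce. The sketch below uses scikit-learn's `make_moons` to generate noisy synthetic data, an illustrative choice rather than anything from the article. Exact accuracies will vary with the data. Still, k=1 should score perfectly on the training set, while test accuracy peaks somewhere in between.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic two-class data with overlapping classes (illustrative assumption)
X, y = make_moons(n_samples=600, noise=0.3, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

train_acc, test_acc = {}, {}
for k in [1, 5, 15, 51, 151, 301]:
    knn = KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr)
    train_acc[k] = knn.score(X_tr, y_tr)   # accuracy on the data it memorized
    test_acc[k] = knn.score(X_te, y_te)    # accuracy on held-out data
    print(f"k={k:<4} train={train_acc[k]:.3f}  test={test_acc[k]:.3f}")

best_test_k = max(test_acc, key=test_acc.get)
print("best k on this test split:", best_test_k)
```

Picking k by test accuracy, as in the last line, is only for illustration. In practice, choose k by cross-validation on the training set, as described next.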

Rules for Picking k

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

n_train = 1000  # training samples

# Placeholder data for illustration; substitute your own X_train, y_train
X_train, y_train = make_classification(n_samples=n_train, random_state=0)

# Rule of thumb 1: square root of n
k_sqrt = int(np.sqrt(n_train))         # int(31.6) = 31

# Rule of thumb 2: use an odd k for binary classification to avoid vote ties
k_odd = k_sqrt if k_sqrt % 2 == 1 else k_sqrt + 1

# Method 3: cross-validation (the most reliable approach)
k_candidates = list(range(1, 51, 2))  # odd values 1, 3, ..., 49
cv_scores = []
for k in k_candidates:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())

best_k = k_candidates[np.argmax(cv_scores)]
print(f"Best k by cross-validation: {best_k}")

Complete Implementation: Iris Classification

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
    accuracy_score, classification_report,
    confusion_matrix
)
from sklearn.pipeline import Pipeline

# ─── 1. Load and Explore Data ─────────────────────────────────────
iris = load_iris()
X, y = iris.data, iris.target
feature_names = iris.feature_names  # sepal length, sepal width, petal length, petal width
class_names   = iris.target_names   # setosa, versicolor, virginica

df = pd.DataFrame(X, columns=feature_names)
df['species'] = [class_names[i] for i in y]

print("Dataset Overview:")
print(f"  Samples:  {X.shape[0]}")
print(f"  Features: {X.shape[1]}")
print(f"  Classes:  {list(class_names)}")
print(f"\nClass distribution:\n{df['species'].value_counts()}")
print(f"\nFeature statistics:\n{df.drop('species',axis=1).describe().round(2)}")

# ─── 2. Split Data ────────────────────────────────────────────────
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y         # maintain class proportions in both splits
)
print(f"\nTrain size: {X_train.shape[0]} | Test size: {X_test.shape[0]}")

# ─── 3. Build Pipeline (scale + KNN) ─────────────────────────────
# The Pipeline is critical: the scaler is refit on each CV training fold only,
# which prevents data leakage during cross-validation
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('knn', KNeighborsClassifier())
])

# ─── 4. Hyperparameter Search ────────────────────────────────────
param_grid = {
    'knn__n_neighbors': [1, 3, 5, 7, 9, 11, 15],
    'knn__weights':     ['uniform', 'distance'],
    'knn__metric':      ['euclidean', 'manhattan'],
}

grid_search = GridSearchCV(
    pipeline,
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    refit=True,
    verbose=0
)
grid_search.fit(X_train, y_train)

print(f"\nBest hyperparameters: {grid_search.best_params_}")
print(f"Best cross-validation accuracy: {grid_search.best_score_:.4f}")

# ─── 5. Evaluate on Test Set ─────────────────────────────────────
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)

print("\n" + "="*55)
print("       FINAL MODEL EVALUATION ON TEST SET")
print("="*55)
print(f"Test Accuracy: {accuracy_score(y_test, y_pred):.4f}\n")
print("Per-Class Performance:")
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion Matrix (formatted)
cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(f"                Setosa  Versicolor  Virginica")
for i, row in enumerate(cm):
    print(f"  {class_names[i]:<14} {row[0]:>6}  {row[1]:>10}  {row[2]:>9}")

# ─── 6. Predict New Flower ───────────────────────────────────────
new_flower = np.array([[5.1, 3.5, 1.4, 0.2]])   # likely Setosa
pred_class = best_model.predict(new_flower)
pred_proba = best_model.predict_proba(new_flower)

print(f"\nNew flower: sepal={new_flower[0,:2]}, petal={new_flower[0,2:]}")
print(f"Predicted species: {class_names[pred_class[0]]}")
print("Class probabilities:")
for name, prob in zip(class_names, pred_proba[0]):
    bar = "█" * int(prob * 30)
    print(f"  {name:<12} {prob:.3f}  {bar}")

Sample output (abridged):

Test Accuracy: 1.0000

Per-Class Performance:
              precision    recall  f1-score   support
      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00        10
   virginica       1.00      1.00      1.00        10

New flower: sepal=[5.1 3.5], petal=[1.4 0.2]
Predicted species: setosa
Class probabilities:
  setosa       1.000  ██████████████████████████████
  versicolor   0.000
  virginica    0.000

KNN From Scratch (Educational)

import numpy as np
from collections import Counter

class KNNClassifier:
    """
    Pure Python KNN implementation — no sklearn.
    Exposes the internals for learning purposes.
    """

    def __init__(self, k=5, metric='euclidean', weights='uniform'):
        self.k       = k
        self.metric  = metric
        self.weights = weights

    def fit(self, X, y):
        """Lazy learning — just store the data."""
        self.X_train = np.asarray(X, dtype=float)
        self.y_train = np.asarray(y)
        self.classes_ = np.unique(y)
        print(f"KNN fitted | n={len(X)} samples | k={self.k} | metric={self.metric}")
        return self

    def _distance(self, a, b):
        if self.metric == 'euclidean':
            return np.sqrt(np.sum((a - b) ** 2))
        elif self.metric == 'manhattan':
            return np.sum(np.abs(a - b))
        elif self.metric == 'chebyshev':
            return np.max(np.abs(a - b))
        raise ValueError(f"Unknown metric: {self.metric}")

    def _predict_one(self, x):
        # Compute all distances
        distances = np.array([
            self._distance(x, x_train)
            for x_train in self.X_train
        ])

        # Get k nearest neighbors
        k_idx    = np.argsort(distances)[:self.k]
        k_labels = self.y_train[k_idx]
        k_dists  = distances[k_idx]

        if self.weights == 'uniform':
            # Simple majority vote; ties go to the label that appears first
            # among the distance-sorted neighbors
            return Counter(k_labels).most_common(1)[0][0]

        elif self.weights == 'distance':
            # Weighted vote: closer neighbors have more influence.
            # Weight = 1/distance; clamping at 1e-10 gives exact matches a huge
            # weight without triggering a divide-by-zero warning.
            weights = 1.0 / np.maximum(k_dists, 1e-10)
            weighted_votes = {}
            for label, weight in zip(k_labels, weights):
                weighted_votes[label] = weighted_votes.get(label, 0) + weight
            return max(weighted_votes, key=weighted_votes.get)

        raise ValueError(f"Unknown weights: {self.weights}")

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        return np.array([self._predict_one(x) for x in X])

    def score(self, X, y):
        return np.mean(self.predict(X) == y)

    def predict_explain(self, x):
        """Show exactly which neighbors drove the prediction."""
        distances = np.array([self._distance(x, xt) for xt in self.X_train])
        k_idx = np.argsort(distances)[:self.k]

        print(f"\nPredicting for: {x}")
        print(f"Top {self.k} nearest neighbors:")
        print(f"  {'Rank':<6} {'Distance':<12} {'Label':<10}")
        print(f"  {'-'*30}")
        for rank, idx in enumerate(k_idx, 1):
            print(f"  {rank:<6} {distances[idx]:<12.4f} {self.y_train[idx]:<10}")

        prediction = self._predict_one(x)
        vote_counts = Counter(self.y_train[k_idx])
        print(f"\nVotes: {dict(vote_counts)}")
        print(f"Final prediction: {prediction}")
        return prediction
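The class above computes distances one training point at a time in a Python loop. That is clear but slow. A common alternative, sketched here as a hypothetical helper `knn_predict_vectorized` (uniform weights and Euclidean distance only), computes every query-to-training distance at once with NumPy broadcasting. It then selects neighbors with `np.argpartition`. The intermediate array has shape (m, n, d), so for large inputs you would process queries in chunks. Vote ties here go to the class that sorts first, which can differ from the Counter-based tie-breaking above.

```python
import numpy as np

def knn_predict_vectorized(X_train, y_train, X_query, k=5):
    """Hypothetical helper: uniform-vote KNN using NumPy broadcasting."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train)

    # (m, 1, d) - (1, n, d) broadcasts to (m, n, d); summing over d gives (m, n)
    diffs = X_query[:, None, :] - X_train[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=2))

    # argpartition moves the k smallest distances in each row to the front, O(n) average
    k_idx = np.argpartition(dists, kth=k - 1, axis=1)[:, :k]
    k_labels = y_train[k_idx]                                   # shape (m, k)

    # Encode labels as 0..C-1, count votes per class, take the arg-max per row
    classes, encoded = np.unique(k_labels, return_inverse=True)
    encoded = encoded.reshape(k_labels.shape)
    votes = (encoded[:, :, None] == np.arange(len(classes))).sum(axis=1)  # (m, C)
    return classes[votes.argmax(axis=1)]


X_toy = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
y_toy = ["a", "a", "a", "b", "b", "b"]
preds = knn_predict_vectorized(X_toy, y_toy, [[0.2, 0.2], [5.5, 5.5]], k=3)
print(preds)  # ['a' 'b']
```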

Why Feature Scaling Is Non-Negotiable

Without Scaling:

  Feature 1 (age):    range 18 - 80     → max diff = 62
  Feature 2 (income): range 20000-150000 → max diff = 130000

  Euclidean distance is DOMINATED by income.
  Age is effectively ignored: income drives almost every prediction.

  Query:    age=25, income=50000
  Person A: age=26, income=52000   →  1 year and 2,000 apart
  Person B: age=75, income=50500   →  50 years and 500 apart

  d(query, A) = √(1² + 2000²)  ≈ 2,000
  d(query, B) = √(50² + 500²)  ≈ 502

  B is "closer" despite being 50 years older, because in raw units a 500 income gap outweighs a 50-year age gap.
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

# StandardScaler: z-score normalization — mean=0, std=1
# Best for: normally distributed continuous features
std_scaler = StandardScaler()
X_std = std_scaler.fit_transform(X_train)
print(f"StandardScaler — mean: {X_std.mean(axis=0).round(3)}, std: {X_std.std(axis=0).round(3)}")

# MinMaxScaler: scale to [0, 1]
# Best for: bounded features, when you know the min/max
mm_scaler = MinMaxScaler()
X_mm = mm_scaler.fit_transform(X_train)
print(f"MinMaxScaler — min: {X_mm.min(axis=0).round(3)}, max: {X_mm.max(axis=0).round(3)}")

# RobustScaler: uses median and IQR — resistant to outliers
# Best for: data with significant outliers
rb_scaler = RobustScaler()
X_rb = rb_scaler.fit_transform(X_train)

# Rule: fit scaler on TRAINING data only, then transform both train and test
# NEVER fit on test data — that's data leakage!
X_test_std = std_scaler.transform(X_test)  # No fit_transform here!

KNN for Regression

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error, r2_score

# ─── Generate non-linear regression data ─────────────────────────
np.random.seed(42)
X_reg = np.linspace(0, 10, 200).reshape(-1, 1)
y_reg = np.sin(X_reg.ravel()) + 0.2 * np.random.randn(200)

X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

# Compare uniform vs distance weighting
results = []
for k in [1, 3, 5, 10, 20]:
    for weights in ['uniform', 'distance']:
        model = KNeighborsRegressor(n_neighbors=k, weights=weights)
        model.fit(X_train_r, y_train_r)
        y_pred_r = model.predict(X_test_r)
        r2   = r2_score(y_test_r, y_pred_r)
        rmse = np.sqrt(mean_squared_error(y_test_r, y_pred_r))
        results.append({'k': k, 'weights': weights, 'R2': r2, 'RMSE': rmse})

df_results = pd.DataFrame(results).sort_values('R2', ascending=False)
print(df_results.to_string(index=False))
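A KNN regression prediction is just the mean of the neighbors' targets, optionally weighted by inverse distance. This small sketch uses hand-made 1-D data chosen only for illustration. It recomputes both variants by hand and compares them with `KNeighborsRegressor`.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

X = np.array([[0.0], [1.0], [2.0], [3.0], [10.0]])
y = np.array([0.0, 2.0, 4.0, 6.0, 100.0])
x_new = np.array([[1.4]])

uniform = KNeighborsRegressor(n_neighbors=3).fit(X, y)
distance = KNeighborsRegressor(n_neighbors=3, weights="distance").fit(X, y)

# Manual version: find the 3 nearest training points to 1.4 (x = 1, 2, 0)
dist = np.abs(X.ravel() - x_new[0, 0])
nn = np.argsort(dist)[:3]
manual_uniform = y[nn].mean()                       # (2 + 4 + 0) / 3 = 2.0
w = 1.0 / dist[nn]                                  # inverse-distance weights
manual_distance = np.sum(w * y[nn]) / np.sum(w)     # pulled toward the closest targets

print(uniform.predict(x_new)[0], manual_uniform)
print(distance.predict(x_new)[0], manual_distance)
```

The outlier at x = 10 never enters the prediction because it is not among the 3 nearest points. This locality is what lets KNN fit the sine curve above without any explicit model.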

Computational Complexity and Scaling

Algorithm     | Training   | Prediction (per sample)    | Memory
──────────────┼────────────┼────────────────────────────┼────────────
Brute Force   | O(1)       | O(n × d)                   | O(n × d)
KD-Tree       | O(n log n) | O(log n × d)  [d < 20]     | O(n × d)
Ball Tree     | O(n log n) | O(log n × d)  [d ≥ 20]     | O(n × d)
FAISS (approx)| O(n)       | O(√n × d)     [any d]      | O(n × d)

n = number of training samples
d = number of features

Practical thresholds:
  n < 10,000  and d < 20  → KD-Tree  (sklearn's algorithm='auto' picks kd_tree for dense data only when d ≤ 15 and k < n/2; otherwise it uses brute force)
  n < 50,000  and d ≥ 20  → Ball Tree
  n > 100,000             → Approximate methods (FAISS, Annoy, ScaNN)
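Brute force, KD-tree and Ball tree are all exact. They return the same neighbors and differ only in speed, whereas approximate libraries such as FAISS trade some accuracy for speed. Here is a sketch on random data, using scikit-learn's `NearestNeighbors` to confirm that the exact methods agree:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 8))       # random training data, d = 8
queries = rng.standard_normal((5, 8))

results = {}
for algo in ["brute", "kd_tree", "ball_tree"]:
    nn = NearestNeighbors(n_neighbors=5, algorithm=algo).fit(X)
    dist, idx = nn.kneighbors(queries)   # (5, 5) distances and indices
    results[algo] = (dist, idx)

for algo, (dist, idx) in results.items():
    print(algo, idx[0])
```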
import time

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

algorithms = ['brute', 'kd_tree', 'ball_tree', 'auto']
n_samples  = [100, 1000, 5000]

X_bench = np.random.randn(max(n_samples), 10)
y_bench = np.random.randint(0, 2, max(n_samples))

print(f"{'Algorithm':<12} {'n=100':>8} {'n=1000':>8} {'n=5000':>8}  (prediction time, ms)")
print("-" * 50)
for algo in algorithms:
    times = []
    for n in n_samples:
        X_b, y_b = X_bench[:n], y_bench[:n]
        knn = KNeighborsClassifier(n_neighbors=5, algorithm=algo)
        knn.fit(X_b, y_b)
        t = time.perf_counter()
        knn.predict(X_b[:10])
        times.append((time.perf_counter() - t) * 1000)
    print(f"{algo:<12} {times[0]:>7.2f}  {times[1]:>7.2f}  {times[2]:>7.2f}")

When to Use KNN vs Alternatives

┌─────────────────────────────────────────────────────────────────┐
│                   DECISION GUIDE                                │
│                                                                 │
│  Dataset is SMALL (n < 5,000)?                                  │
│  Features are NUMERIC and SCALED?                               │
│  Classes have IRREGULAR decision boundaries?                    │
│  You need INTERPRETABLE predictions?                            │
│  Data changes frequently (online learning)?                     │
│                                                                 │
│  YES to most → KNN is a strong candidate ✅                     │
│                                                                 │
│  BUT if:                                                        │
│  → n > 100,000 samples        → Try Random Forest or XGBoost   │
│  → Many categorical features  → Try Decision Tree              │
│  → Need probability estimates → Try Logistic Regression        │
│  → Very high dimensions (d>100) → Try SVM or Neural Net        │
│  → Real-time serving needed    → Avoid brute-force KNN         │
└─────────────────────────────────────────────────────────────────┘
Property               | KNN               | Decision Tree | SVM            | Neural Net
───────────────────────┼───────────────────┼───────────────┼────────────────┼──────────────
Training speed         | ⚡ Instant         | 🔶 Fast        | 🔴 Slow         | 🔴🔴 Very slow
Prediction speed       | 🔴 Slow (large n)  | ⚡ Fast        | ⚡ Fast         | ⚡ Fast
Memory usage           | 🔴 High            | 🟢 Low         | 🟢 Low          | 🔶 Medium
Handles non-linearity  | ✅ Naturally       | ✅ Yes         | ✅ With kernel  | ✅ Yes
Interpretability       | ✅ High            | ✅ High        | ❌ Low          | ❌ Very low
Works with small data  | ✅ Yes             | ✅ Yes         | ✅ Yes          | ❌ Needs lots
Sensitive to noise     | 🔴 Yes (k=1)       | 🔶 Moderate    | 🟢 Robust       | 🔶 Moderate

Key Takeaways

  1. No training — KNN memorizes data and computes at prediction time (lazy learning)
  2. Always scale features — unscaled data makes distance meaningless
  3. k controls the bias-variance tradeoff — low k overfits, high k underfits; cross-validate to find the sweet spot
  4. Distance metric matters — Euclidean for low-d continuous data, Manhattan for high-d, Cosine for text
  5. Use a Pipeline — wraps scaler and KNN to prevent data leakage in cross-validation
  6. Weighted KNN (weights='distance') often outperforms uniform by giving more influence to closer neighbors
  7. KD-tree or Ball-tree structures can cut per-query cost from O(n × d) toward O(d log n) when the number of features is small (roughly d < 20)
  8. KNN makes a good baseline: try it early; even if it doesn't win, how it fails tells you something about your data
