Linear Regression — Complete Guide
Linear regression is one of the simplest and most widely used ML algorithms. It models the relationship between input features and a continuous target as a straight line (or, with several features, a hyperplane).
Simple Linear Regression
The Goal:
Find the best line that fits the data:
y = mx + b
Where:
y = predicted value (target)
x = input feature
m = slope (weight)
b = y-intercept (bias)
Example: House prices
y = price
x = square footage
m = price per sq ft
b = base price
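As a quick illustration of the house-price reading of y = mx + b, here is a minimal sketch. The function name `predict_price` and the numbers (150 per sq ft, 50,000 base price) are made up for the example and are not fitted from any data.

```python
def predict_price(sqft, price_per_sqft=150.0, base_price=50_000.0):
    """Evaluate y = m*x + b with m = price per sq ft and b = base price."""
    return price_per_sqft * sqft + base_price

# Hypothetical 2,000 sq ft house: 150 * 2000 + 50000
price = predict_price(2000)
print(price)  # 350000.0
```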
Finding the Best Line
Method 1: Ordinary Least Squares (OLS)
Minimize: L = Σ(yᵢ - ŷᵢ)² = Σ(yᵢ - (mxᵢ + b))²
Taking the partial derivatives of L with respect to m and b and setting them to 0 gives:
m = Σ(xᵢ - x̄)(yᵢ - ȳ) / Σ(xᵢ - x̄)²
b = ȳ - mx̄
Where x̄ and ȳ are the means of x and y
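The closed-form expressions for m and b translate almost directly into NumPy. The sketch below assumes 1-D arrays and uses a tiny made-up dataset that lies exactly on y = 2x + 1, so the recovered values are easy to check by hand. The helper name `fit_simple_ols` is illustrative only.

```python
import numpy as np

def fit_simple_ols(x, y):
    """Return (m, b) from the closed-form OLS formulas for one feature."""
    x_bar, y_bar = x.mean(), y.mean()
    m = np.sum((x - x_bar) * (y - y_bar)) / np.sum((x - x_bar) ** 2)
    b = y_bar - m * x_bar
    return m, b

# Toy data (made up): points lie exactly on y = 2x + 1
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([3.0, 5.0, 7.0, 9.0, 11.0])

m, b = fit_simple_ols(x, y)
print(f"m = {m:.2f}, b = {b:.2f}")  # m = 2.00, b = 1.00
```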
Method 2: Gradient Descent
Initialize: m=0, b=0; choose a learning rate α > 0
Repeat:
∂L/∂m = -2 Σ xᵢ(yᵢ - (mxᵢ + b))
∂L/∂b = -2 Σ (yᵢ - (mxᵢ + b))
m = m - α × ∂L/∂m
b = b - α × ∂L/∂b
Until convergence
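The update loop above can be sketched in a few lines of NumPy. This version runs a fixed number of iterations instead of testing for convergence, which is a simplification. The function name, the step size 0.01, and the iteration count are assumptions chosen for this toy data. Because the gradients are sums over all samples, the usable step size shrinks as the dataset grows, so a larger dataset would typically need a smaller α (or gradients averaged over n).

```python
import numpy as np

def gradient_descent(x, y, alpha=0.01, n_iters=5000):
    """Minimize L = sum((y - (m*x + b))**2) with plain gradient descent."""
    m, b = 0.0, 0.0
    for _ in range(n_iters):
        residual = y - (m * x + b)
        grad_m = -2.0 * np.sum(x * residual)  # dL/dm
        grad_b = -2.0 * np.sum(residual)      # dL/db
        m -= alpha * grad_m
        b -= alpha * grad_b
    return m, b

# Toy data (made up): points lie exactly on y = 2x + 1
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = 2.0 * x + 1.0

m, b = gradient_descent(x, y)
print(f"m = {m:.3f}, b = {b:.3f}")
```

On this data the result should agree with the closed-form OLS solution (m close to 2, b close to 1). If α is set too large, the iterates diverge instead of converging.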
Multiple Linear Regression
y = w₁x₁ + w₂x₂ + ... + wₙxₙ + b
In matrix form:
y = Xw + b
Where:
X = feature matrix (n_samples × n_features)
w = weight vector (n_features × 1)
y = target vector (n_samples × 1)
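Normal equation (closed-form). To include the bias, append a column of ones to X so that b becomes the last entry of w; XᵀX must be invertible:
w = (XᵀX)⁻¹Xᵀy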
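A minimal NumPy sketch of the normal equation follows. It assumes the bias is handled by appending a column of ones to X, and it solves the linear system with `np.linalg.solve` rather than forming the inverse explicitly, which is generally the more numerically stable choice. The data and the "true" weights are synthetic and noise-free so the recovered values can be checked.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic, noise-free data: y = 1.5*x1 - 2.0*x2 + 0.5
X = rng.normal(size=(50, 2))
true_w = np.array([1.5, -2.0])
true_b = 0.5
y = X @ true_w + true_b

# Append a column of ones so the last coefficient plays the role of b
X_aug = np.hstack([X, np.ones((X.shape[0], 1))])

# Solve (X^T X) theta = X^T y instead of inverting X^T X
theta = np.linalg.solve(X_aug.T @ X_aug, X_aug.T @ y)
w, b = theta[:-1], theta[-1]

print("w =", np.round(w, 3), "b =", round(float(b), 3))
```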
Python Implementation
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# Generate data
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# Fit model
model = LinearRegression()
model.fit(X, y)
print(f"Slope: {model.coef_[0][0]:.2f}")
print(f"Intercept: {model.intercept_[0]:.2f}")
# Predict
X_new = np.array([[0], [2]])
y_pred = model.predict(X_new)
# Plot
plt.scatter(X, y, alpha=0.5)
plt.plot(X_new, y_pred, 'r-', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression')
plt.show()
Evaluation Metrics
Mean Squared Error (MSE):
MSE = (1/n) Σ(yᵢ - ŷᵢ)²
Root Mean Squared Error (RMSE):
RMSE = √MSE
Mean Absolute Error (MAE):
MAE = (1/n) Σ|yᵢ - ŷᵢ|
R-squared (R²):
R² = 1 - SS_res/SS_tot
where SS_res = Σ(yᵢ - ŷᵢ)²
SS_tot = Σ(yᵢ - ȳ)²
R² = 1.0 → perfect prediction
R² = 0.0 → no better than always predicting the mean ȳ
R² < 0.0 → worse than predicting mean
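The four metrics can be computed directly from their definitions. The sketch below uses a small made-up pair of arrays so each value can be verified by hand.

```python
import numpy as np

# Made-up targets and predictions for illustration
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.5, 5.0, 7.5, 9.0])

residuals = y_true - y_pred                      # [0.5, 0.0, -0.5, 0.0]

mse = np.mean(residuals ** 2)                    # 0.125
rmse = np.sqrt(mse)                              # about 0.354
mae = np.mean(np.abs(residuals))                 # 0.25

ss_res = np.sum(residuals ** 2)                  # 0.5
ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # 20.0
r2 = 1.0 - ss_res / ss_tot                       # 0.975

print(f"MSE={mse:.3f} RMSE={rmse:.3f} MAE={mae:.3f} R2={r2:.3f}")
```

RMSE and MAE are in the same units as y, which usually makes them easier to interpret than MSE.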
Assumptions
1. Linearity
Relationship between X and y is linear
Check: Scatter plots, residual plots
2. Independence
Observations are independent
Check: Durbin-Watson test
3. Homoscedasticity
Constant variance of residuals
Check: Residual vs fitted plot
4. Normality
Residuals are normally distributed
Check: Q-Q plot, histogram
5. No multicollinearity
Features are not highly correlated with each other
Check: VIF (Variance Inflation Factor)
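Two of the checks listed above are simple enough to compute by hand. The sketch below implements the Durbin-Watson statistic (values near 2 suggest no first-order autocorrelation in the residuals) and the VIF, defined for feature j as 1 / (1 - R²ⱼ), where R²ⱼ comes from regressing feature j on the remaining features. The helper names and the synthetic data are assumptions for illustration. Libraries such as statsmodels provide tested versions of both.

```python
import numpy as np

def durbin_watson(residuals):
    """Sum of squared successive differences over sum of squared residuals."""
    diff = np.diff(residuals)
    return np.sum(diff ** 2) / np.sum(residuals ** 2)

def vif(X):
    """Variance inflation factor for each column of X."""
    n_samples, n_features = X.shape
    out = np.empty(n_features)
    for j in range(n_features):
        target = X[:, j]
        others = np.delete(X, j, axis=1)
        # Regress feature j on the other features plus an intercept
        A = np.hstack([others, np.ones((n_samples, 1))])
        coef, *_ = np.linalg.lstsq(A, target, rcond=None)
        resid = target - A @ coef
        r2 = 1.0 - np.sum(resid ** 2) / np.sum((target - target.mean()) ** 2)
        out[j] = 1.0 / (1.0 - r2)
    return out

# Synthetic features: x3 is nearly a copy of x1, x2 is independent
rng = np.random.default_rng(0)
x1 = rng.normal(size=200)
x2 = rng.normal(size=200)
x3 = x1 + 0.1 * rng.normal(size=200)
X = np.column_stack([x1, x2, x3])

vifs = vif(X)
print("VIF:", np.round(vifs, 1))

# Perfectly alternating residuals give a statistic of 3 (negative autocorrelation)
dw = durbin_watson(np.array([1.0, -1.0, 1.0, -1.0]))
print("Durbin-Watson:", dw)
```

A common rule of thumb treats VIF values above roughly 5 to 10 as a sign of problematic multicollinearity. Here x1 and x3 should land far above that range, while x2 stays close to 1.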
Polynomial Regression
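When the relationship is non-linear, add powers of x as extra features. The model stays linear in the weights. For example, a cubic model:
y = w₁x + w₂x² + w₃x³ + b
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
# Expand X into [x, x²] (degree=2 here; raise degree for higher-order terms).
# include_bias=False because LinearRegression already fits an intercept.
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly.fit_transform(X)
# Fit an ordinary linear model on the expanded features
model = LinearRegression()
model.fit(X_poly, y)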
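To make the "linear model on expanded features" idea concrete without scikit-learn, the sketch below builds the polynomial design matrix by hand and solves an ordinary least-squares problem with `np.linalg.lstsq`. The data is synthetic and noise-free (y = 1 + 2x + 3x²), chosen so the recovered coefficients can be checked.

```python
import numpy as np

# Synthetic quadratic data: y = 1 + 2x + 3x^2 (no noise)
x = np.linspace(-2.0, 2.0, 21)
y = 1.0 + 2.0 * x + 3.0 * x ** 2

# Design matrix with columns [x, x^2, 1]; the last column carries the bias b
A = np.column_stack([x, x ** 2, np.ones_like(x)])

# Ordinary least squares on the expanded features
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
w1, w2, b = coef

print(f"w1 = {w1:.2f}, w2 = {w2:.2f}, b = {b:.2f}")  # w1 = 2.00, w2 = 3.00, b = 1.00
```

Higher degrees fit the training data more closely but are more prone to overfitting, so the degree is usually chosen with held-out data.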
Key Takeaways
- Linear regression finds the best straight line through data
- The least-squares objective has a closed-form solution; gradient descent reaches the same minimum iteratively
- R² measures the fraction of variance in y that the model explains
- Check assumptions before using linear regression
- Polynomial regression extends to non-linear relationships
- Regularization (Ridge, Lasso) helps reduce overfitting
- Linear regression is fast, interpretable, and a great baseline
- Always visualize data before fitting a line