R Machine Learning with caret — Predictive Modeling
Learning Objectives
By the end of this tutorial, you will be able to:
- Prepare data for machine learning (splitting, preprocessing)
- Train and tune models using the caret package
- Evaluate model performance with cross-validation
- Compare multiple algorithms
- Build prediction pipelines
The caret Package
library(caret)
# List the algorithms caret can train, together with their tuning parameters.
modelLookup(model = NULL)
Data Preparation
# Split iris 70/30 into training and test sets. createDataPartition samples
# within each Species level, so both sets keep the original class balance.
set.seed(42)
train_index <- createDataPartition(
  y = iris$Species,
  p = 0.7,
  list = FALSE
)
train_data <- iris[train_index, ]   # rows used for model fitting
test_data <- iris[-train_index, ]   # remaining rows held out for evaluation
# Cross-validation settings shared by the iris classifiers below.
# iris$Species has three classes, so twoClassSummary (ROC, sensitivity and
# specificity for exactly two classes) cannot be used here. defaultSummary
# reports Accuracy and Kappa, which matches the metric = "Accuracy" used when
# training and comparing the models.
train_control <- trainControl(
  method = "cv",
  number = 10,
  classProbs = TRUE,  # keep class probabilities available for predict(type = "prob")
  summaryFunction = defaultSummary
)
# Preprocessing: estimate centering/scaling parameters from the training
# predictors only, then apply that same transformation to both sets so no
# information from the test rows leaks into preprocessing.
# (Alternatively, pass preProcess = to train(), as the SVM and KNN models do.)
predictor_cols <- setdiff(names(train_data), "Species")
preprocess_params <- preProcess(
  train_data[, predictor_cols],
  method = c("center", "scale")
)
train_scaled <- predict(preprocess_params, train_data[, predictor_cols])
test_scaled <- predict(preprocess_params, test_data[, predictor_cols])
Classification
Logistic Regression
# caret's "glm" method only handles two-class outcomes, and Species has three
# levels, so fit a multinomial logistic regression ("multinom", from the nnet
# package) instead. Training on train_data keeps test_data a true hold-out.
# Resetting the seed before each train() call gives every model the same CV
# folds, so resamples() later compares them on identical splits.
set.seed(42)
model_glm <- train(
  Species ~ .,
  data = train_data,
  method = "multinom",
  trControl = train_control,
  metric = "Accuracy",
  trace = FALSE  # forwarded to nnet::multinom(); silences its iteration log
)
print(model_glm)
Random Forest
# Random forest tuned over mtry (predictors sampled at each split). iris has
# four predictors, so every possible value 1-4 is tried. Same seed as the
# other classifiers so all models share identical CV folds.
set.seed(42)
model_rf <- train(
  Species ~ .,
  data = train_data,
  method = "rf",
  trControl = train_control,
  metric = "Accuracy",
  tuneGrid = expand.grid(mtry = c(1, 2, 3, 4))
)
print(model_rf)
plot(model_rf)  # CV accuracy as a function of mtry
Support Vector Machine
# Radial-kernel SVM. preProcess = c("center", "scale") lets caret estimate the
# scaling within each resample rather than on the whole training set. Same
# seed as the other classifiers so all models share identical CV folds.
set.seed(42)
model_svm <- train(
  Species ~ .,
  data = train_data,
  method = "svmRadial",
  trControl = train_control,
  metric = "Accuracy",
  preProcess = c("center", "scale")
)
print(model_svm)
K-Nearest Neighbors
# K-nearest neighbours over odd k = 1, 3, ..., 21. KNN is distance based, so
# predictors are centered and scaled within each resample. Same seed as the
# other classifiers so all models share identical CV folds.
set.seed(42)
model_knn <- train(
  Species ~ .,
  data = train_data,
  method = "knn",
  trControl = train_control,
  metric = "Accuracy",
  preProcess = c("center", "scale"),
  tuneGrid = expand.grid(k = seq(1, 21, by = 2))
)
print(model_knn)
plot(model_knn)  # CV accuracy as a function of k
Regression
# Random forest for regression on mtcars. Hold out ~20% of the cars so the
# final error estimate comes from rows the model never saw. Predicting on the
# training rows (resubstitution) reports an optimistically low RMSE.
set.seed(42)
reg_index <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
mtcars_train <- mtcars[reg_index, ]
mtcars_test <- mtcars[-reg_index, ]

model_rf_reg <- train(
  mpg ~ .,
  data = mtcars_train,
  method = "rf",
  trControl = trainControl(method = "cv", number = 10),
  tuneGrid = expand.grid(mtry = c(2, 4, 6, 8))
)
print(model_rf_reg)
# Predictions on the held-out rows, scored with RMSE / R-squared / MAE
predictions <- predict(model_rf_reg, newdata = mtcars_test)
postResample(pred = predictions, obs = mtcars_test$mpg)
Model Comparison
# Gather the cross-validation results of the four already-fitted classifiers
# so they can be compared side by side (no model is retrained here).
models <- resamples(x = list(
  GLM = model_glm,
  RF  = model_rf,
  SVM = model_svm,
  KNN = model_knn
))
# Distribution of each resampled metric, per model
summary(object = models)
# Box-and-whisker view of fold-level accuracy
bwplot(x = models, metric = "Accuracy")
# Dot-plot view of accuracy per model
dotplot(x = models, metric = "Accuracy")
Feature Selection
# Recursive Feature Elimination using random-forest helper functions
# (rfFuncs), scored by 10-fold CV. Run it on the training rows only, so the
# held-out test_data plays no part in choosing features. A fixed seed makes
# the CV folds, and hence the selected subset, reproducible.
rfe_control <- rfeControl(functions = rfFuncs, method = "cv", number = 10)
predictor_cols <- setdiff(names(train_data), "Species")
set.seed(42)
rfe_result <- rfe(
  x = train_data[, predictor_cols],
  y = train_data$Species,
  sizes = c(1, 2, 3, 4),
  rfeControl = rfe_control
)
print(rfe_result)
plot(rfe_result, type = c("g", "o"))  # grid + points/lines of performance vs. subset size
# Variable importance from the tuned random forest
varImp(model_rf)
plot(varImp(model_rf))
Practical Examples
Example 1: Customer Churn Prediction
library(caret)
# Simulated data: month-to-month customers paying more than 60 per month churn
# with probability 0.4; everyone else churns with probability 0.1.
set.seed(42)
n <- 1000
data <- data.frame(
  tenure = sample(1:72, n, replace = TRUE),
  monthly_charges = runif(n, 20, 100),
  contract = factor(sample(
    c("Month-to-month", "One year", "Two year"), n, replace = TRUE
  ))
)
data$churn <- factor(ifelse(
  data$contract == "Month-to-month" & data$monthly_charges > 60,
  sample(c("Yes", "No"), n, replace = TRUE, prob = c(0.4, 0.6)),
  sample(c("Yes", "No"), n, replace = TRUE, prob = c(0.1, 0.9))
))
# Split (stratified on churn)
train_index <- createDataPartition(data$churn, p = 0.7, list = FALSE)
train <- data[train_index, ]
test <- data[-train_index, ]
# Train model. The data frame named `train` does not mask caret::train(): when
# evaluating a call, R skips non-function objects while looking up the name.
model <- train(
  churn ~ .,
  data = train,
  method = "rf",
  trControl = trainControl(method = "cv", number = 5),
  tuneGrid = expand.grid(mtry = c(1, 2, 3))
)
# Evaluate. factor() orders levels alphabetically ("No", "Yes"), so
# confusionMatrix() would treat "No" as the positive class by default. Name
# "Yes" explicitly so sensitivity and precision describe the churners.
predictions <- predict(model, newdata = test)
confusionMatrix(predictions, test$churn, positive = "Yes")
Practice Exercises
Exercise 1: Model Comparison
Compare logistic regression, random forest, and SVM on the Sonar dataset.
Solution
library(caret)
library(mlbench)
data(Sonar)
set.seed(42)
train_index <- createDataPartition(Sonar$Class, p = 0.7, list = FALSE)
train <- Sonar[train_index, ]
test <- Sonar[-train_index, ]
# Build the 5 CV folds once and hand the same training-row indices to every
# model. Otherwise each train() call draws its own folds, and resamples() ends
# up comparing models evaluated on different data splits.
folds <- createFolds(train$Class, k = 5, returnTrain = TRUE)
ctrl <- trainControl(method = "cv", number = 5, index = folds)
model_glm <- train(Class ~ ., data = train, method = "glm", trControl = ctrl)
model_rf <- train(Class ~ ., data = train, method = "rf", trControl = ctrl)
model_svm <- train(Class ~ ., data = train, method = "svmRadial", trControl = ctrl)
# Compare the three models on the shared folds
resamps <- resamples(list(GLM = model_glm, RF = model_rf, SVM = model_svm))
summary(resamps)
bwplot(resamps)
Key Takeaways
- caret provides a unified interface for hundreds of algorithms
- trainControl() defines resampling methods (CV, bootstrap)
- train() fits models with automatic hyperparameter tuning
- confusionMatrix() evaluates classification performance
- postResample() evaluates regression performance
- resamples() compares multiple models
- varImp() shows feature importance
- Preprocess as each algorithm requires: center and scale for distance- and kernel-based models such as KNN and SVM, and handle missing values
Next: Learn about R Shiny Web Apps — interactive web applications.