The Interview Question
ℹ️
Question: You're analyzing customer churn for a subscription service:
- Dataset: 100K customers with signup date, churn date, and features
- Requirements: Predict time to churn, identify churn factors, estimate survival curves
- Challenge: Censored data (some customers haven't churned yet)
Walk through your survival analysis approach:
- How do you handle censored data properly?
- How do you estimate survival curves using Kaplan-Meier?
- How do you build a Cox proportional hazards model?
- How do you validate and interpret your model?
Detailed Answer
1. Survival Analysis Fundamentals
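Survival analysis models time-to-event data while handling censoring. For some subjects, the event has not been observed by the end of the observation window (e.g., customers who are still subscribed). For them we only know that their time to event is longer than the time observed so far. Dropping these subjects, or treating their observed time as the event time, biases estimates toward shorter survival.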
import pandas as pd
import numpy as np
from lifelines import (
KaplanMeierFitter,
CoxPHFitter,
WeibullFitter,
ExponentialFitter,
LogNormalFitter
)
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines.plotting import plot_lifetimes
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings
# Silences every warning process-wide (not only those raised in this file).
# This also hides any warnings the libraries imported above emit, e.g. while fitting models.
# NOTE(review): consider narrowing this to a specific category or wrapping
# noisy calls in a warnings.catch_warnings() block instead.
warnings.filterwarnings('ignore')
class SurvivalAnalysisFramework:
"""Framework for survival analysis"""
def __init__(self, data, duration_col, event_col):
self.data = data.copy()
self.duration_col = duration_col
self.event_col = event_col
self.results = {}
def prepare_survival_data(self, start_date_col, end_date_col, event_indicator_col):
"""Prepare data for survival analysis"""
# Calculate duration
self.data['duration'] = (
pd.to_datetime(self.data[end_date_col]) -
pd.to_datetime(self.data[start_date_col])
).dt.days
# Ensure non-negative duration
self.data['duration'] = self.data['duration'].clip(lower=0)
# Event indicator (1 = event occurred, 0 = censored)
self.data['event'] = self.data[event_indicator_col].astype(int)
# Summary
n_events = self.data['event'].sum()
n_censored = len(self.data) - n_events
print(f"Survival Data Summary:")
print(f" Total subjects: {len(self.data)}")
print(f" Events (churned): {n_events} ({n_events/len(self.data)*100:.1f}%)")
print(f" Censored (active): {n_censored} ({n_censored/len(self.data)*100:.1f}%)")
print(f" Median duration: {self.data['duration'].median():.0f} days")
return self.data
def describe_censoring(self):
"""Describe censoring patterns"""
censoring_info = {
'right_censored': (self.data[self.event_col] == 0).sum(),
'events': (self.data[self.event_col] == 1).sum(),
'censoring_rate': (self.data[self.event_col] == 0).mean()
}
# Check for informative censoring
# (If censoring is related to outcomes)
print(f"\nCensoring Information:")
print(f" Right-censored: {censoring_info['right_censored']}")
print(f" Events observed: {censoring_info['events']}")
print(f" Censoring rate: {censoring_info['censoring_rate']:.2%}")
return censoring_info
# Example usage
# framework = SurvivalAnalysisFramework(data, 'duration', 'churned')
# framework.prepare_survival_data('signup_date', 'churn_date', 'has_churned')
2. Kaplan-Meier Estimation
class KaplanMeierAnalyzer:
    """Kaplan-Meier survival estimation, overall or per group.

    Parameters
    ----------
    data : pandas.DataFrame
        Subject-level data (not copied; this class never writes to it).
    duration_col : str
        Column with the observed time (event or censoring time).
    event_col : str
        Column with the event indicator (1 = event, 0 = censored).
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.fitter = KaplanMeierFitter()   # overall curve
        self.group_fitters = {}             # {group value: fitted KaplanMeierFitter}
        self.fitted = False

    def _observed_groups(self, group_col):
        """Return the distinct non-missing values of ``group_col``."""
        # NaN != NaN, so a NaN "group" would select zero rows and make the
        # fitter / log-rank test fail. Missing labels are skipped instead.
        return self.data[group_col].dropna().unique()

    def fit(self, group_col=None):
        """Fit the Kaplan-Meier estimator.

        Parameters
        ----------
        group_col : str, optional
            When given, fit one curve per non-missing value of this column
            (stored in ``self.group_fitters``). Otherwise fit one overall
            curve on ``self.fitter``.

        Returns
        -------
        dict
            Overall fit: median survival, survival function and confidence
            intervals. Grouped fit: ``{group: {'median_survival',
            'survival_function'}}``.
        """
        if group_col is None:
            self.fitter.fit(
                self.data[self.duration_col],
                self.data[self.event_col],
                label='Overall'
            )
            self.fitted = True
            results = {
                'median_survival': self.fitter.median_survival_time_,
                'survival_function': self.fitter.survival_function_,
                'confidence_intervals': self.fitter.confidence_interval_survival_function_
            }
        else:
            self.group_fitters = {}
            for group in self._observed_groups(group_col):
                mask = self.data[group_col] == group
                fitter = KaplanMeierFitter()
                fitter.fit(
                    self.data.loc[mask, self.duration_col],
                    self.data.loc[mask, self.event_col],
                    label=str(group)
                )
                self.group_fitters[group] = fitter
            self.fitted = True
            results = {
                group: {
                    'median_survival': fitter.median_survival_time_,
                    'survival_function': fitter.survival_function_
                }
                for group, fitter in self.group_fitters.items()
            }
        self.results = results
        return results

    def compare_groups(self, group_col, alpha=0.05):
        """Test whether survival differs between the levels of ``group_col``.

        Exactly two groups use the two-sample log-rank test. Any other number
        uses the multivariate log-rank test. Rows with a missing group label
        are ignored.

        Returns
        -------
        dict
            ``test`` name, ``statistic``, ``p_value`` and ``significant``
            (``p_value < alpha``).
        """
        groups = self._observed_groups(group_col)
        labelled = self.data[self.data[group_col].notna()]
        if len(groups) == 2:
            group1 = labelled[labelled[group_col] == groups[0]]
            group2 = labelled[labelled[group_col] == groups[1]]
            result = logrank_test(
                group1[self.duration_col], group2[self.duration_col],
                event_observed_A=group1[self.event_col],
                event_observed_B=group2[self.event_col]
            )
            test_name = 'Log-rank'
        else:
            # lifelines' signature is (event_durations, groups, event_observed).
            # Passing the group labels first would treat them as durations.
            result = multivariate_logrank_test(
                event_durations=labelled[self.duration_col],
                groups=labelled[group_col],
                event_observed=labelled[self.event_col]
            )
            test_name = 'Multivariate log-rank'
        return {
            'test': test_name,
            'statistic': result.test_statistic,
            'p_value': result.p_value,
            'significant': result.p_value < alpha
        }

    def calculate_survival_probabilities(self, time_points):
        """Return S(t) and the risk 1 - S(t) from the overall curve at each time point.

        Returns an empty dict when nothing has been fitted yet.
        NOTE(review): ``self.fitted`` is also set by a grouped fit, in which
        case the overall ``self.fitter`` may still be unfitted.
        """
        survival_probs = {}
        if not self.fitted:
            return survival_probs
        for t in time_points:
            surv_prob = self.fitter.predict(t)
            survival_probs[t] = {
                'survival_probability': surv_prob,
                'risk': 1 - surv_prob
            }
        return survival_probs

    def median_survival_time(self):
        """Median survival time of the overall curve (may be inf if S(t) stays above 0.5)."""
        return self.fitter.median_survival_time_

    def visualize(self, group_col=None, figsize=(10, 6)):
        """Plot the overall curve or every group curve, then save and show the figure."""
        fig, ax = plt.subplots(figsize=figsize)
        if group_col is None:
            self.fitter.plot_survival_function(ax=ax, ci_show=True)
        else:
            for fitter in self.group_fitters.values():
                fitter.plot_survival_function(ax=ax, ci_show=True)
        ax.set_xlabel('Time (days)')
        ax.set_ylabel('Survival Probability')
        ax.set_title('Kaplan-Meier Survival Curves')
        ax.grid(True, alpha=0.3)
        ax.legend()
        plt.tight_layout()
        plt.savefig('kaplan_meier_curves.png', dpi=150, bbox_inches='tight')
        plt.show()

    def survival_table(self):
        """Overall survival probability every 30 days, up to (not including) the max duration."""
        return self.fitter.survival_function_at_times(
            np.arange(0, self.data[self.duration_col].max(), 30)
        )
# Example usage
# km_analyzer = KaplanMeierAnalyzer(data, 'duration', 'event')
# km_analyzer.fit()
# km_analyzer.fit(group_col='plan_type')
# comparison = km_analyzer.compare_groups('plan_type')
# survival_probs = km_analyzer.calculate_survival_probabilities([30, 90, 180, 365])
# km_analyzer.visualize(group_col='plan_type')
3. Cox Proportional Hazards Model
class CoxPHAnalyzer:
    """Cox proportional hazards regression with interpretation helpers.

    Parameters
    ----------
    data : pandas.DataFrame
        Subject-level data containing the outcome columns and covariates.
    duration_col : str
        Column with the observed time.
    event_col : str
        Column with the event indicator (1 = event, 0 = censored).
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.model = CoxPHFitter()
        self.fitted = False
        self.covariates = []
        # Exact NaN-free frame the model was trained on; reused for diagnostics.
        self.fit_data = None

    def _is_fitted(self):
        """Return True when fitted; otherwise print a notice and return False."""
        if not self.fitted:
            print("Model not fitted yet")
        return self.fitted

    def fit(self, covariates, penalizer=0.01):
        """Fit the Cox PH model.

        Rows with a missing value in any used column are dropped. If the
        duration or event column appears in ``covariates`` it is removed:
        it is the outcome, and selecting it twice would duplicate DataFrame
        columns.

        Parameters
        ----------
        covariates : iterable of str
            Covariate columns (must be numeric; encode categoricals first).
        penalizer : float
            Regularisation strength given to ``CoxPHFitter``.

        Returns
        -------
        dict
            Coefficients, hazard ratios, p-values, confidence intervals,
            concordance index, log-likelihood and partial AIC.
        """
        outcome_cols = [self.duration_col, self.event_col]
        requested = list(dict.fromkeys(covariates))  # de-duplicate, keep order
        excluded = [c for c in requested if c in outcome_cols]
        if excluded:
            print(f"Excluding outcome column(s) from covariates: {excluded}")
        covariates = [c for c in requested if c not in outcome_cols]

        fit_data = self.data[outcome_cols + covariates].dropna()

        # penalizer is a CoxPHFitter constructor argument, not a fit() keyword.
        # With the default l1_ratio of 0 it is a pure L2 penalty.
        self.model = CoxPHFitter(penalizer=penalizer)
        self.model.fit(
            fit_data,
            duration_col=self.duration_col,
            event_col=self.event_col
        )
        self.fitted = True
        self.covariates = covariates
        self.fit_data = fit_data

        return {
            'coefficients': self.model.params_,
            'hazard_ratios': self.model.hazard_ratios_,
            'p_values': self.model.summary['p'],
            'confidence_intervals': self.model.confidence_intervals_,
            'concordance_index': self.model.concordance_index_,
            'log_likelihood': self.model.log_likelihood_,
            'AIC': self.model.AIC_partial_
        }

    def summarize(self):
        """Print and return the model summary table (None if not fitted)."""
        if not self._is_fitted():
            return
        summary = self.model.summary
        print("Cox Proportional Hazards Model Summary")
        print("=" * 70)
        print(f"Concordance index: {self.model.concordance_index_:.4f}")
        print(f"Log-likelihood ratio test p-value: {self.model.log_likelihood_ratio_test().p_value:.6f}")
        print(f"AIC: {self.model.AIC_partial_:.2f}")
        print("\nCoefficients:")
        print(summary)
        return summary

    def hazard_ratios(self):
        """Map each covariate to its hazard ratio and a plain-language reading.

        Returns None if not fitted. The "one unit increase" wording assumes
        numeric covariates.
        """
        if not self._is_fitted():
            return
        interpretation = {}
        for covariate, ratio in self.model.hazard_ratios_.items():
            if ratio > 1:
                text = (f'One unit increase in {covariate} increases '
                        f'hazard by {(ratio - 1) * 100:.1f}%')
                direction = 'increases risk'
            elif ratio < 1:
                text = (f'One unit increase in {covariate} decreases '
                        f'hazard by {(1 - ratio) * 100:.1f}%')
                direction = 'decreases risk'
            else:
                text = f'{covariate} has no effect on hazard'
                direction = 'no effect'
            interpretation[covariate] = {
                'hazard_ratio': ratio,
                'interpretation': text,
                'direction': direction
            }
        return interpretation

    def check_proportional_hazards(self, p_value_threshold=0.05):
        """Test the proportional-hazards assumption (Schoenfeld residuals).

        Runs on the same NaN-free frame the model was fitted on. Returns
        whatever ``CoxPHFitter.check_assumptions`` returns, or None if the
        model is not fitted.
        """
        if not self._is_fitted():
            return
        return self.model.check_assumptions(
            self.fit_data,
            p_value_threshold=p_value_threshold,
            show_plots=False
        )

    def predict_survival_function(self, new_data, times):
        """Predicted survival probabilities of ``new_data`` at ``times`` (None if not fitted)."""
        if not self._is_fitted():
            return
        return self.model.predict_survival_function(new_data, times=times)

    def predict_partial_hazard(self, new_data):
        """Partial hazard (relative risk score) of ``new_data`` (None if not fitted)."""
        if not self._is_fitted():
            return
        return self.model.predict_partial_hazard(new_data)

    def visualize_coefficients(self, figsize=(10, 6)):
        """Plot coefficients with 95% CIs beside hazard ratios; save and show the figure."""
        if not self._is_fitted():
            return
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        # Coefficients with asymmetric error bars from the 95% CI
        summary = self.model.summary
        coefficients = summary['coef']
        lower = summary['coef lower 95%'].values
        upper = summary['coef upper 95%'].values
        y_pos = range(len(coefficients))
        axes[0].barh(y_pos, coefficients.values, xerr=[
            coefficients.values - lower,
            upper - coefficients.values
        ], capsize=5)
        axes[0].axvline(x=0, color='gray', linestyle='--')
        axes[0].set_yticks(y_pos)
        axes[0].set_yticklabels(coefficients.index)
        axes[0].set_xlabel('Coefficient')
        axes[0].set_title('Cox Model Coefficients')
        axes[0].grid(True, alpha=0.3)

        # Hazard ratios; the reference line at 1 means "no effect"
        hr = self.model.hazard_ratios_
        axes[1].barh(y_pos, hr.values)
        axes[1].axvline(x=1, color='gray', linestyle='--')
        axes[1].set_yticks(y_pos)
        axes[1].set_yticklabels(hr.index)
        axes[1].set_xlabel('Hazard Ratio')
        axes[1].set_title('Hazard Ratios (exp(coef))')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('cox_model_coefficients.png', dpi=150, bbox_inches='tight')
        plt.show()
# Example usage
# cox_analyzer = CoxPHAnalyzer(data, 'duration', 'event')
# covariates = ['age', 'plan_type', 'usage_minutes', 'support_tickets']
# results = cox_analyzer.fit(covariates)
# cox_analyzer.summarize()
# hr_interpretation = cox_analyzer.hazard_ratios()
# cox_analyzer.check_proportional_hazards()
# cox_analyzer.visualize_coefficients()
4. Parametric Survival Models
class ParametricSurvivalModels:
    """Fit and compare univariate parametric survival distributions.

    Parameters
    ----------
    data : pandas.DataFrame
        Subject-level data.
    duration_col : str
        Column with the observed time.
    event_col : str
        Column with the event indicator (1 = event, 0 = censored).
    """

    def __init__(self, data, duration_col, event_col):
        self.data = data
        self.duration_col = duration_col
        self.event_col = event_col
        self.models = {}  # {name: fitted lifelines univariate fitter}

    def _fit(self, name, fitter_cls):
        """Fit ``fitter_cls`` on the stored data, register it as ``name``, return it."""
        # NOTE(review): lifelines' parametric fitters reject non-positive
        # durations. Confirm that zero-day tenures are handled upstream.
        model = fitter_cls()
        model.fit(self.data[self.duration_col], self.data[self.event_col])
        self.models[name] = model
        return model

    def fit_exponential(self):
        """Fit an exponential (constant-hazard) model."""
        return self._fit('exponential', ExponentialFitter)

    def fit_weibull(self):
        """Fit a Weibull model."""
        return self._fit('weibull', WeibullFitter)

    def fit_log_normal(self):
        """Fit a log-normal model."""
        return self._fit('log_normal', LogNormalFitter)

    def compare_models(self):
        """Tabulate AIC/BIC/log-likelihood of the fitted models, sorted by AIC.

        Returns an empty table (with the same columns) if nothing is fitted.
        """
        comparison = [
            {
                'model': name,
                'AIC': model.AIC_,
                'BIC': model.BIC_,
                'log_likelihood': model.log_likelihood_
            }
            for name, model in self.models.items()
        ]
        # Explicit columns keep sort_values('AIC') valid when no model is fitted.
        comparison_df = pd.DataFrame(
            comparison, columns=['model', 'AIC', 'BIC', 'log_likelihood']
        ).sort_values('AIC')
        print("Model Comparison (lower AIC is better):")
        print("=" * 50)
        print(comparison_df)
        return comparison_df

    def _get_model(self, model_name):
        """Return the fitted model, or print a notice and return None."""
        if model_name not in self.models:
            print(f"Model {model_name} not fitted")
            return None
        return self.models[model_name]

    def predict_survival(self, model_name, times):
        """Survival probabilities S(t) of ``model_name`` at ``times`` (None if not fitted)."""
        model = self._get_model(model_name)
        if model is None:
            return None
        # Univariate lifelines fitters expose survival_function_at_times();
        # predict_survival_function() belongs to the regression fitters.
        return model.survival_function_at_times(times)

    def predict_hazard(self, model_name, times):
        """Hazard h(t) of ``model_name`` at ``times`` (None if not fitted)."""
        model = self._get_model(model_name)
        if model is None:
            return None
        return model.hazard_at_times(times)

    def visualize_model_comparison(self, figsize=(10, 6)):
        """Plot survival and hazard curves of every fitted model; save and show the figure."""
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        times = np.arange(0, self.data[self.duration_col].max(), 1)

        # Survival functions
        for name, model in self.models.items():
            axes[0].plot(times, model.survival_function_at_times(times).values, label=name)
        axes[0].set_xlabel('Time (days)')
        axes[0].set_ylabel('Survival Probability')
        axes[0].set_title('Survival Functions Comparison')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Hazard functions (non-finite values at t=0 are simply not drawn)
        for name, model in self.models.items():
            axes[1].plot(times, model.hazard_at_times(times).values, label=name)
        axes[1].set_xlabel('Time (days)')
        axes[1].set_ylabel('Hazard Rate')
        axes[1].set_title('Hazard Functions Comparison')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('parametric_models_comparison.png', dpi=150, bbox_inches='tight')
        plt.show()
# Example usage
# param_models = ParametricSurvivalModels(data, 'duration', 'event')
# param_models.fit_exponential()
# param_models.fit_weibull()
# param_models.fit_log_normal()
# comparison = param_models.compare_models()
# param_models.visualize_model_comparison()
💡
Pro Tip: Cox PH is semi-parametric and doesn't require specifying the baseline hazard distribution. Use parametric models when you need to extrapolate survival curves beyond observed data.
5. Real-World Application: Churn Prediction
class ChurnPredictor:
    """End-to-end churn analysis: Kaplan-Meier curves, Cox model, risk segments.

    Parameters
    ----------
    data : pandas.DataFrame
        Customer-level data. It is NOT copied: ``analyze_churn`` adds
        ``risk_score`` and ``risk_segment`` columns to this frame in place.
    """

    RISK_LABELS = ['Low Risk', 'Medium Risk', 'High Risk', 'Very High Risk']

    def __init__(self, data):
        self.data = data
        self.km_analyzer = None
        self.cox_analyzer = None
        # Outcome columns used by the last analyze_churn() call. The defaults
        # match that method's defaults.
        self.duration_col = 'tenure_days'
        self.event_col = 'churned'

    def analyze_churn(self, customer_features, duration_col='tenure_days', event_col='churned'):
        """Run the churn pipeline and attach risk columns.

        Steps:
        1. Kaplan-Meier curves, overall and by ``plan_type`` when that column exists.
        2. A Cox PH model fitted on ``customer_features``.
        3. Hazard-ratio interpretation.
        4. Quartile-based risk segmentation.

        Returns
        -------
        pandas.DataFrame
            ``self.data`` with ``risk_score`` and ``risk_segment`` added.
        """
        self.duration_col = duration_col
        self.event_col = event_col

        print("Step 1: Kaplan-Meier Analysis")
        self.km_analyzer = KaplanMeierAnalyzer(self.data, duration_col, event_col)
        self.km_analyzer.fit()  # overall survival curve
        if 'plan_type' in self.data.columns:
            self.km_analyzer.fit(group_col='plan_type')
            comparison = self.km_analyzer.compare_groups('plan_type')
            print(f" Plan type comparison p-value: {comparison['p_value']:.4f}")

        print("\nStep 2: Cox Proportional Hazards Model")
        self.cox_analyzer = CoxPHAnalyzer(self.data, duration_col, event_col)
        cox_results = self.cox_analyzer.fit(customer_features)
        print(f" Concordance index: {cox_results['concordance_index']:.4f}")

        print("\nStep 3: Hazard Ratio Interpretation")
        for feature, info in self.cox_analyzer.hazard_ratios().items():
            print(f" {feature}: {info['interpretation']}")

        print("\nStep 4: Risk Segmentation")
        self.data['risk_score'] = self.cox_analyzer.predict_partial_hazard(
            self.data[customer_features]
        )
        # NOTE(review): pd.qcut raises ValueError when quartile edges coincide
        # (heavily tied scores, e.g. only binary features).
        self.data['risk_segment'] = pd.qcut(
            self.data['risk_score'],
            q=4,
            labels=self.RISK_LABELS
        )
        print("\nRisk Segment Distribution:")
        print(self.data['risk_segment'].value_counts())
        return self.data

    def generate_insights(self):
        """Summarise risk segments and churn-increasing covariates.

        Requires a prior ``analyze_churn`` call. Uses the duration/event
        column names from that call. ``monthly_revenue`` is summed only when
        the column exists.

        Returns
        -------
        dict
            ``key_findings`` and ``recommendations`` lists of strings.
        """
        insights = {
            'key_findings': [],
            'recommendations': []
        }

        agg_spec = {
            'risk_score': 'mean',
            self.event_col: 'mean',
            self.duration_col: 'mean'
        }
        if 'monthly_revenue' in self.data.columns:
            agg_spec['monthly_revenue'] = 'sum'
        risk_analysis = self.data.groupby('risk_segment').agg(agg_spec)
        insights['key_findings'].append(
            f"Very High Risk customers have {risk_analysis.loc['Very High Risk', self.event_col]:.1%} churn rate"
        )

        # Cox model insights: only covariates that raise the hazard
        for feature, info in self.cox_analyzer.hazard_ratios().items():
            if info['direction'] == 'increases risk':
                insights['key_findings'].append(
                    f"{feature} increases churn risk: {info['interpretation']}"
                )

        insights['recommendations'].extend([
            "Focus retention efforts on Very High Risk segment",
            "Investigate factors driving churn in Cox model",
            "Implement early warning system using risk scores",
            "A/B test retention interventions by risk segment"
        ])
        return insights

    def predict_individual_churn(self, customer_data, time_horizon=30):
        """Churn probability within ``time_horizon`` days for the first customer in ``customer_data``.

        Returns
        -------
        dict
            ``churn_probability_{time_horizon}d`` (``churn_probability_30d``
            for the default horizon), ``risk_score`` and ``risk_segment``.
        """
        survival_probs = self.cox_analyzer.predict_survival_function(
            customer_data,
            times=[time_horizon]
        )
        # Rows are time points and columns are customers: take the first customer.
        churn_probability = 1 - survival_probs.iloc[0].values[0]
        risk_score = self.cox_analyzer.predict_partial_hazard(customer_data).values[0]
        return {
            f'churn_probability_{time_horizon}d': churn_probability,
            'risk_score': risk_score,
            'risk_segment': self._get_risk_segment(risk_score)
        }

    def _get_risk_segment(self, risk_score):
        """Map a score to a quartile label using the quartiles of ``self.data['risk_score']``."""
        q25, q50, q75 = self.data['risk_score'].quantile([0.25, 0.5, 0.75])
        # Upper bounds are inclusive, matching pd.qcut's right-closed bins.
        for upper, label in zip((q25, q50, q75), self.RISK_LABELS):
            if risk_score <= upper:
                return label
        return self.RISK_LABELS[-1]
# Example usage
# churn_predictor = ChurnPredictor(data)
# features = ['age', 'monthly_revenue', 'support_tickets', 'plan_type']
# (do not list 'tenure_days': it is the duration column; encode plan_type numerically before fitting)
# analyzed_data = churn_predictor.analyze_churn(features)
# insights = churn_predictor.generate_insights()
# individual_pred = churn_predictor.predict_individual_churn(new_customer_data)
6. Common Follow-Up Questions
Follow-up 1: How do you handle time-varying covariates?
def time_varying_cox_model(data, id_col, start_col, end_col, event_col,
                           time_varying_covariates=None):
    """Fit a Cox model with time-varying covariates.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format data: each row is one time interval
        ``[start_col, end_col]`` for one subject, carrying the covariate
        values in force during that interval.
    id_col, start_col, end_col, event_col : str
        Subject id, interval start, interval stop and event-indicator columns.
    time_varying_covariates : list of str, optional
        Covariates to include in the model. ``None`` (the default) uses every
        remaining column of ``data``.

    Returns
    -------
    lifelines.CoxTimeVaryingFitter
        The fitted model (its summary is also printed).
    """
    from lifelines import CoxTimeVaryingFitter

    if time_varying_covariates is not None:
        structural_cols = [id_col, start_col, end_col, event_col]
        extra = [c for c in time_varying_covariates if c not in structural_cols]
        data = data[structural_cols + extra]

    ctv = CoxTimeVaryingFitter()
    ctv.fit(
        data,
        id_col=id_col,
        start_col=start_col,
        stop_col=end_col,
        event_col=event_col
    )
    print(ctv.summary)
    return ctv
# Example: Model where usage changes over time
# time_varying_data = create_long_format(data, covariates=['usage_minutes', 'support_tickets'])
# tv_model = time_varying_cox_model(time_varying_data, 'customer_id', 'start', 'end', 'churned')
Follow-up 2: How do you validate survival models?
def validate_survival_model(model, data, duration_col, event_col, covariates,
                            test_size=0.2, random_state=42):
    """Hold-out validation of a survival regression model by concordance index.

    The model is fitted and scored on exactly the same columns:
    ``[duration_col, event_col] + covariates``. Rows with missing values in
    those columns are dropped first. Only discrimination is measured here;
    calibration is not assessed.

    Parameters
    ----------
    model : lifelines regression fitter (e.g. ``CoxPHFitter()``)
        Unfitted model with ``fit`` and ``predict_partial_hazard`` methods.
    data : pandas.DataFrame
        Full dataset.
    duration_col, event_col : str
        Outcome columns.
    covariates : iterable of str
        Covariate columns to train on.
    test_size : float
        Fraction of rows held out for testing.
    random_state : int
        Seed for the train/test split.

    Returns
    -------
    dict
        ``concordance_index`` and a short text ``interpretation``.
    """
    from lifelines.utils import concordance_index
    from sklearn.model_selection import train_test_split

    covariates = list(covariates)
    model_data = data[[duration_col, event_col] + covariates].dropna()
    train_data, test_data = train_test_split(
        model_data, test_size=test_size, random_state=random_state
    )

    model.fit(train_data, duration_col=duration_col, event_col=event_col)
    risk_scores = model.predict_partial_hazard(test_data[covariates])

    c_index = concordance_index(
        test_data[duration_col],
        -risk_scores,  # Negate because higher risk = lower survival
        test_data[event_col]
    )

    quality = 'good' if c_index > 0.7 else 'fair' if c_index > 0.6 else 'poor'
    return {
        'concordance_index': c_index,
        'interpretation': f'C-index of {c_index:.3f} indicates {quality} discrimination'
    }
# Example
# validation_results = validate_survival_model(
# CoxPHFitter(), data, 'duration', 'event', ['age', 'usage', 'tickets']
# )
Company-Specific Tips
ℹ️
Google Tips:
- Google values survival analysis for user behavior modeling
- Know how to handle large-scale survival data
- Understand how to extend survival analysis to multiple events
- Be comfortable with competing risks models
Amazon Tips:
- Amazon uses survival analysis for subscription churn
- Know how to build real-time churn prediction systems
- Understand how to design retention interventions
- Be familiar with customer lifetime value estimation
Quiz Section
Related Topics
- Time Series Analysis — Temporal patterns in events
- Causal Inference — Estimating treatment effects on survival
- Classification — Alternative approach to churn prediction
- Bayesian Survival — Bayesian survival models