Explainable AI
Attention Visualization
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
class AttentionVisualizer:
    """Inspect and plot the self-attention weights of a transformer model.

    Assumes a Hugging Face-style ``model``/``tokenizer`` pair: the tokenizer is
    callable with ``return_tensors="pt"`` and provides ``convert_ids_to_tokens``,
    and the model accepts ``output_attentions=True`` and returns an object with
    an ``attentions`` sequence (one tensor per layer). TODO confirm for the
    models this is used with.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def _attention_and_tokens(self, text: str) -> Tuple[np.ndarray, List[str]]:
        """Run the model once; return (head-averaged last-layer attention, tokens).

        The tokens are decoded from the exact ``input_ids`` fed to the model, so
        they include any special tokens the tokenizer inserts and line up
        one-to-one with the rows/columns of the attention matrix.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        # Last layer, first batch item, mean over dim 0 of what remains.
        # Assumes the HF layout (batch, heads, seq, seq), giving a (seq, seq)
        # matrix — TODO confirm for non-HF models.
        attention = outputs.attentions[-1][0].mean(dim=0)
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
        # .cpu() so the conversion also works when the model runs on a GPU.
        return attention.cpu().numpy(), tokens

    def get_attention_weights(self, text: str) -> np.ndarray:
        """Return the head-averaged last-layer attention matrix for ``text``."""
        return self._attention_and_tokens(text)[0]

    def visualize_attention(self, text: str, save_path: Optional[str] = None):
        """Plot the attention matrix for ``text`` as a labelled heatmap.

        Args:
            text: Input text to encode and run through the model.
            save_path: If truthy, the figure is also saved to this path.
        """
        attention, tokens = self._attention_and_tokens(text)
        fig, ax = plt.subplots(figsize=(10, 10))
        im = ax.imshow(attention, cmap="viridis")
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
        ax.set_yticklabels(tokens)
        fig.colorbar(im, ax=ax)
        ax.set_title("Self-Attention Weights")
        if save_path:
            fig.savefig(save_path, bbox_inches="tight")
        plt.show()

    def find_important_tokens(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
        """Return the ``k`` tokens receiving the most attention, highest first.

        A token's score is the mean of its column in the attention matrix,
        i.e. the average attention paid to it by all positions.

        Args:
            text: Input text.
            k: Number of tokens to return; ``k <= 0`` returns an empty list.

        Returns:
            A list of ``(token, score)`` pairs.
        """
        if k <= 0:
            # Guard explicitly: argsort()[-0:] would select every token.
            return []
        attention, tokens = self._attention_and_tokens(text)
        importance = attention.mean(axis=0)
        top_k_idx = importance.argsort()[-k:][::-1]
        return [(tokens[i], float(importance[i])) for i in top_k_idx]
# Example: rank the three tokens that receive the most attention.
visualizer = AttentionVisualizer(model=model, tokenizer=tokenizer)
important = visualizer.find_important_tokens(text="The quick brown fox jumps", k=3)
Gradient-based Explanations
class GradientExplainer:
    """Gradient-based input attributions: vanilla saliency and integrated gradients.

    Assumes ``model(x)`` returns a 2-D ``(batch, num_classes)`` tensor — TODO
    confirm for the models this is used with.
    """

    def __init__(self, model):
        self.model = model

    def _class_gradient(self, inputs, target_class):
        """Return d(output[:, target_class]) / d(inputs) at ``inputs``.

        Gradients are taken w.r.t. a fresh detached leaf copy. This keeps the
        caller's tensor untouched and prevents accumulation into a stale
        ``.grad`` from an earlier call. It also works when ``inputs`` is itself
        the result of a computation (a non-leaf tensor).
        """
        leaf = inputs.detach().clone().requires_grad_(True)
        output = self.model(leaf)
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        # Select the target class for every sample in the batch, not just row 0.
        one_hot[:, target_class] = 1
        output.backward(gradient=one_hot)
        return leaf.grad

    def vanilla_gradients(self, input_tensor, target_class):
        """Return |d output[:, target_class] / d input| (a saliency map).

        The result has the same shape as ``input_tensor``.
        """
        return self._class_gradient(input_tensor, target_class).abs()

    def integrated_gradients(self, input_tensor, target_class, n_steps=50, baseline=None):
        """Approximate integrated gradients along the straight path baseline -> input.

        Gradients are evaluated at ``n_steps + 1`` evenly spaced points
        (including both endpoints) and averaged, then scaled by
        ``input - baseline``.

        Args:
            input_tensor: Input to explain.
            target_class: Output column whose attribution is computed.
            n_steps: Number of path intervals; must be >= 1.
            baseline: Reference input, broadcastable to ``input_tensor``;
                defaults to all zeros.

        Raises:
            ValueError: If ``n_steps < 1``.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")
        x = input_tensor.detach()
        if baseline is None:
            baseline = torch.zeros_like(x)
        else:
            baseline = torch.as_tensor(baseline, dtype=x.dtype, device=x.device).detach()
        delta = x - baseline
        gradients = [
            self._class_gradient(baseline + (i / n_steps) * delta, target_class)
            for i in range(n_steps + 1)
        ]
        avg_gradients = torch.stack(gradients).mean(dim=0)
        return delta * avg_gradients
# Example: saliency map and integrated gradients for class 1 of the same input.
explainer = GradientExplainer(model=model)
gradients = explainer.vanilla_gradients(input_tensor=input_tensor, target_class=1)
ig_grads = explainer.integrated_gradients(input_tensor=input_tensor, target_class=1)
LIME Implementation
import numpy as np
from sklearn.linear_model import LinearRegression
class LIMEExplainer:
    """LIME-style local linear surrogate explanations for a single feature vector."""

    def __init__(self, model, n_samples=1000):
        # NOTE: ``model`` is stored but not used by explain_instance; predictions
        # come from the ``predict_fn`` passed there.
        self.model = model
        self.n_samples = n_samples

    def explain_instance(self, instance, predict_fn, num_features=10,
                         label=None, kernel_width=1.0, random_state=None):
        """Explain one instance by fitting a weighted linear model to masked copies.

        Each of ``self.n_samples`` perturbations keeps every feature with
        probability 0.5 and zeroes the rest. Samples are weighted with a
        Gaussian kernel on the number of zeroed features, and a
        ``LinearRegression`` is fit from the binary masks to the predictions.

        Args:
            instance: 1-D feature vector.
            predict_fn: Called with one ``(1, n_features)`` row per perturbation.
                It may return a scalar, a ``(1,)`` array, or a ``(1, n_outputs)``
                array.
            num_features: How many features to report (largest |coef| first).
            label: Output column to explain when ``predict_fn`` returns more
                than one value per row (e.g. class probabilities).
            kernel_width: Gaussian kernel width; the default 1.0 reproduces
                ``exp(-d**2 / 2)``.
            random_state: Seed for ``np.random.default_rng``. ``None`` uses the
                global ``np.random`` state.

        Returns:
            Dict with ``feature_indices``, ``feature_importance`` (|coef| of those
            features) and ``intercept``.

        Raises:
            ValueError: If ``predict_fn`` yields several outputs and ``label`` is None.
        """
        instance = np.asarray(instance)
        n_features = instance.shape[0]
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        perturbations = rng.binomial(1, 0.5, size=(self.n_samples, n_features))

        # Flatten each prediction so predictions is (n_samples, n_outputs)
        # regardless of whether predict_fn returns a scalar, (1,) or (1, C).
        predictions = np.array([
            np.ravel(predict_fn((instance * pert).reshape(1, -1)))
            for pert in perturbations
        ])
        if label is not None:
            predictions = predictions[:, label]
        elif predictions.shape[1] == 1:
            predictions = predictions[:, 0]
        else:
            raise ValueError(
                f"predict_fn returned {predictions.shape[1]} outputs per sample; "
                "pass `label` to choose which one to explain"
            )

        # Euclidean distance to the all-ones mask, i.e. sqrt(#zeroed features).
        distances = np.sqrt(np.sum((perturbations - 1) ** 2, axis=1))
        weights = np.exp(-distances ** 2 / (2 * kernel_width ** 2))

        local_model = LinearRegression()
        local_model.fit(perturbations, predictions, sample_weight=weights)

        # 1-D target => coef_ is 1-D, so argsort ranks features (not rows).
        feature_importance = np.abs(local_model.coef_)
        top_features = feature_importance.argsort()[::-1][:num_features]
        return {
            "feature_indices": top_features,
            "feature_importance": feature_importance[top_features],
            "intercept": local_model.intercept_
        }
# Example: explain the first sample. (Fixed: the method call was missing its '.')
# NOTE: the name `lime` shadows the third-party `lime` package if it is imported later.
lime = LIMEExplainer(model)
explanation = lime.explain_instance(
    instance=sample_data[0],
    predict_fn=model.predict,
)
SHAP Values
import shap
class SHAPExplainer:
    """Thin wrapper around ``shap.DeepExplainer`` with plotting helpers."""

    def __init__(self, model, background_data):
        # background_data is passed straight to shap.DeepExplainer as its
        # background dataset.
        self.explainer = shap.DeepExplainer(model, background_data)

    def explain(self, input_data):
        """Return ``DeepExplainer.shap_values(input_data)`` unchanged."""
        return self.explainer.shap_values(input_data)

    def visualize(self, shap_values, features):
        """Draw a SHAP summary plot for ``shap_values`` over ``features``."""
        shap.summary_plot(shap_values, features)

    def force_plot(self, shap_value, feature_names, base_value=None):
        """Build a SHAP force plot for one explanation and return it.

        ``shap_value`` may be a ``shap.Explanation`` (its ``base_values`` and
        ``values`` are used) or a raw array as returned by :meth:`explain`.
        In the raw-array case the base value falls back to the explainer's
        ``expected_value``. For multi-output models, pass the matching
        per-output ``base_value`` explicitly.

        Returns the object produced by ``shap.force_plot``, so notebooks can
        display it.
        """
        if base_value is None:
            # Explanation objects expose ``base_values``; also accept the
            # singular spelling the original code used.
            for attr in ("base_values", "base_value"):
                base_value = getattr(shap_value, attr, None)
                if base_value is not None:
                    break
        if base_value is None:
            base_value = self.explainer.expected_value
        values = getattr(shap_value, "values", shap_value)
        return shap.force_plot(
            base_value,
            values,
            feature_names=feature_names
        )
# Example: fit the DeepExplainer on background data, then explain the test set.
shap_explainer = SHAPExplainer(model=model, background_data=background_data)
shap_values = shap_explainer.explain(input_data=test_data)
Best Practices
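- Combine multiple explanation methods and compare their results; strong disagreement between methods is a signal to investigate before trusting any single one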
- Validate explanations with domain experts
- Use explanations for debugging and improvement
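- Consider fairness: check whether explanations reveal reliance on protected attributes or their proxies (e.g., zip code standing in for ethnicity)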
- Document limitations of explanations
- Use explanations for regulatory compliance