Feature Store Integration
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Why Feature Stores?
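Feature stores eliminate training-serving skew (models see the same feature definitions in training and in production), enable feature reuse across teams and models, and simplify real-time feature computation.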
ℹ️ Note:
Uber's Michelangelo feature store serves 10,000+ features across 1,000+ models in production.
Feast Feature Store
# feast_setup.py
from feast import FeatureStore, Entity, Feature, ValueType
from feast import FileSource, RequestSource
from feast.feature_store import FeatureStore
from feast.infra.offline_stores.file_source import FileSource
from feast.data_format import AvroFormat
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional
class FeatureStoreManager:
    """Convenience wrapper around a Feast ``FeatureStore``.

    Exposes the entity/feature definitions used by this example and the two
    retrieval paths: historical features for training and online features
    for serving.
    """

    # Feature references ("<feature_view>:<feature>") the model consumes.
    # Training and serving both default to this single list so a feature can
    # never be present at training time yet silently missing at inference time
    # (training-serving skew).
    MODEL_FEATURES = (
        "customer_features:lifetime_value",
        "customer_features:total_purchases",
        "customer_features:avg_order_value",
        "customer_features:days_since_last_purchase",
        "product_features:price",
        "product_features:avg_rating",
    )

    def __init__(self, feature_store_path: str = "./feature_repo"):
        """Open the Feast feature repository located at ``feature_store_path``."""
        self.store = FeatureStore(repo_path=feature_store_path)

    def define_entities(self):
        """Return the ``customer_id`` and ``product_id`` entities (INT64 keys)."""
        customer = Entity(
            name="customer_id",
            value_type=ValueType.INT64,
            description="Customer identifier",
        )
        product = Entity(
            name="product_id",
            value_type=ValueType.INT64,
            description="Product identifier",
        )
        return [customer, product]

    def define_features(self):
        """Return ``(customer_features, product_features)`` lists of ``Feature``.

        NOTE(review): ``Feature``/``ValueType`` is the older Feast schema API;
        confirm it exists in the pinned Feast version (newer releases use
        ``Field``).
        """
        customer_features = [
            Feature(name="lifetime_value", value_type=ValueType.FLOAT),
            Feature(name="total_purchases", value_type=ValueType.INT32),
            Feature(name="avg_order_value", value_type=ValueType.FLOAT),
            Feature(name="days_since_last_purchase", value_type=ValueType.INT32),
            Feature(name="customer_segment", value_type=ValueType.STRING),
        ]
        product_features = [
            Feature(name="price", value_type=ValueType.FLOAT),
            Feature(name="category", value_type=ValueType.STRING),
            Feature(name="avg_rating", value_type=ValueType.FLOAT),
            Feature(name="total_sales", value_type=ValueType.INT64),
        ]
        return customer_features, product_features

    def get_historical_features(
        self,
        entity_df: pd.DataFrame,
        features: Optional[Sequence[str]] = None,
    ) -> pd.DataFrame:
        """Build a training frame from the offline store.

        Args:
            entity_df: entity keys to look features up for (Feast generally
                expects an event timestamp column too -- confirm for the
                pinned version).
            features: feature references to fetch; defaults to
                ``MODEL_FEATURES`` so training matches serving.

        Returns:
            The result of Feast's ``get_historical_features(...).to_df()``.
        """
        feature_refs = list(self.MODEL_FEATURES if features is None else features)
        return self.store.get_historical_features(
            entity_df=entity_df,
            features=feature_refs,
        ).to_df()

    def get_online_features(
        self,
        entity_rows: List[Dict],
        features: Optional[Sequence[str]] = None,
    ) -> Dict:
        """Fetch serving-time features from the online store.

        Args:
            entity_rows: one dict of entity keys per row, e.g.
                ``{"customer_id": 1, "product_id": 7}``.
            features: feature references to fetch; defaults to
                ``MODEL_FEATURES`` -- the same list used for training.

        Returns:
            The Feast online response converted with ``to_dict()``.
        """
        feature_refs = list(self.MODEL_FEATURES if features is None else features)
        response = self.store.get_online_features(
            features=feature_refs,
            entity_rows=entity_rows,
        )
        return response.to_dict()
# Feature definitions
feature_store_path = "./feature_repo"
# features.yaml content: an illustrative description of the feature repo.
# Every entity a feature view lists under ``entities`` is declared in
# ``entity_definitions``, and each view's schema lists the same fields that
# FeatureStoreManager.define_features builds for it.
features_yaml = """
project: ml_platform
registry: data/registry.db
provider: local
online_store:
  type: sqlite
  path: data/online_store.db
entity_definitions:
  - name: customer_id
    value_type: INT64
    description: Customer identifier
  - name: product_id
    value_type: INT64
    description: Product identifier
feature_views:
  - name: customer_features
    entities:
      - customer_id
    ttl: 30d
    schema:
      - name: lifetime_value
        value_type: FLOAT
      - name: total_purchases
        value_type: INT32
      - name: avg_order_value
        value_type: FLOAT
      - name: days_since_last_purchase
        value_type: INT32
      - name: customer_segment
        value_type: STRING
    source:
      path: data/customer_features.parquet
  - name: product_features
    entities:
      - product_id
    ttl: 7d
    schema:
      - name: price
        value_type: FLOAT
      - name: category
        value_type: STRING
      - name: avg_rating
        value_type: FLOAT
      - name: total_sales
        value_type: INT64
    source:
      path: data/product_features.parquet
"""
Real-Time Feature Serving
# realtime_features.py
import redis
import json
import time
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import hashlib
@dataclass
class FeatureConfig:
    """Connection and expiry settings for the Redis-backed feature store.

    Attributes:
        redis_host: Redis server hostname.
        redis_port: Redis server TCP port (1-65535).
        redis_db: Redis logical database index (non-negative).
        default_ttl: expiry in seconds used when a write supplies no TTL;
            must be positive.

    Raises:
        ValueError: if any value is out of range.
    """

    redis_host: str
    redis_port: int
    redis_db: int
    default_ttl: int

    def __post_init__(self):
        # Fail at construction instead of on the first write: Redis SETEX
        # rejects non-positive expiry times.
        if self.default_ttl <= 0:
            raise ValueError(
                f"default_ttl must be a positive number of seconds, got {self.default_ttl!r}"
            )
        if not 0 < self.redis_port < 65536:
            raise ValueError(f"redis_port must be in 1..65535, got {self.redis_port!r}")
        if self.redis_db < 0:
            raise ValueError(f"redis_db must be non-negative, got {self.redis_db!r}")
class RealTimeFeatureStore:
    """Low-latency online feature cache backed by Redis.

    Key layout:
      * ``feature:<entity_id>:<feature_name>`` -- string key written with SETEX,
        holding the JSON envelope ``{"value", "timestamp", "ttl"}``.
      * ``events:<entity_id>:<event_type>`` -- sorted set of raw events scored
        by Unix time, used for sliding-window counts.
    """

    # Events older than this are trimmed on every write, and the whole event
    # key expires after this much inactivity. It is therefore also the longest
    # window compute_aggregations can answer.
    EVENT_RETENTION_SECONDS = 86400

    def __init__(self, config: "FeatureConfig"):
        """Connect to Redis using the host/port/db/default_ttl in ``config``."""
        self.redis = redis.Redis(
            host=config.redis_host,
            port=config.redis_port,
            db=config.redis_db,
            decode_responses=True,  # return str values rather than bytes
        )
        self.default_ttl = config.default_ttl

    @staticmethod
    def _feature_key(entity_id: str, feature_name: str) -> str:
        """Redis key holding one entity's value for one feature."""
        return f"feature:{entity_id}:{feature_name}"

    @staticmethod
    def _events_key(entity_id: str, event_type: str) -> str:
        """Redis sorted-set key holding one entity's events of one type."""
        return f"events:{entity_id}:{event_type}"

    def _effective_ttl(self, ttl: Optional[int]) -> int:
        """Return ``ttl``, or the store default when it is None or 0.

        0 is folded into the default as well because SETEX rejects
        non-positive expiry times.
        """
        return ttl or self.default_ttl

    @staticmethod
    def _encode(value: Any, ttl: int) -> str:
        """Serialize a feature value with its write time and TTL as JSON."""
        return json.dumps({
            "value": value,
            "timestamp": datetime.now().isoformat(),
            "ttl": ttl,
        })

    def set_feature(self, entity_id: str, feature_name: str, value: Any, ttl: Optional[int] = None):
        """Store one JSON-serializable feature value with an expiry of ``ttl`` seconds."""
        ttl = self._effective_ttl(ttl)
        self.redis.setex(self._feature_key(entity_id, feature_name), ttl, self._encode(value, ttl))

    def get_feature(self, entity_id: str, feature_name: str) -> Optional[Any]:
        """Return a stored feature value, or None if it is missing or expired."""
        data = self.redis.get(self._feature_key(entity_id, feature_name))
        return json.loads(data)["value"] if data else None

    def set_features_batch(self, entity_id: str, features: Dict[str, Any], ttl: Optional[int] = None):
        """Store several features for one entity in one pipelined round trip."""
        ttl = self._effective_ttl(ttl)
        pipe = self.redis.pipeline()
        for name, value in features.items():
            pipe.setex(self._feature_key(entity_id, name), ttl, self._encode(value, ttl))
        pipe.execute()

    def get_features_batch(self, entity_id: str, feature_names: List[str]) -> Dict[str, Any]:
        """Fetch several features for one entity; missing/expired ones are omitted."""
        feature_names = list(feature_names)  # iterated twice below
        pipe = self.redis.pipeline()
        for name in feature_names:
            pipe.get(self._feature_key(entity_id, name))
        results = pipe.execute()
        return {
            name: json.loads(data)["value"]
            for name, data in zip(feature_names, results)
            if data
        }

    def get_online_features(self, entity_ids: List[str], feature_names: List[str]) -> List[Dict]:
        """Return one ``{"entity_id": ..., <feature>: value-or-None}`` dict per entity.

        Every requested feature appears in every vector; missing or expired
        values are None. All lookups go out in a single pipeline.
        """
        entity_ids = list(entity_ids)  # iterated twice below
        feature_names = list(feature_names)
        pipe = self.redis.pipeline()
        for entity_id in entity_ids:
            for feature_name in feature_names:
                pipe.get(self._feature_key(entity_id, feature_name))
        # Results come back in request order: entity-major, feature-minor.
        raw = iter(pipe.execute())
        feature_vectors = []
        for entity_id in entity_ids:
            vector = {"entity_id": entity_id}
            for feature_name in feature_names:
                data = next(raw)
                vector[feature_name] = json.loads(data)["value"] if data else None
            feature_vectors.append(vector)
        return feature_vectors

    def compute_aggregations(self, entity_id: str, event_type: str, window_seconds: int = 3600):
        """Count events of ``event_type`` recorded for the entity in the last ``window_seconds``.

        Windows longer than EVENT_RETENTION_SECONDS are effectively capped,
        since record_event trims older events.
        """
        cutoff = time.time() - window_seconds
        # ZCOUNT counts server-side instead of shipping every member back.
        return self.redis.zcount(self._events_key(entity_id, event_type), cutoff, "+inf")

    def record_event(self, entity_id: str, event_type: str, event_data: Dict):
        """Append one event (scored by the current Unix time) for windowed counts.

        The add, the trim of events older than EVENT_RETENTION_SECONDS and the
        key expiry run in one pipeline.
        """
        import uuid  # local: only needed to make sorted-set members unique

        now = time.time()
        key = self._events_key(entity_id, event_type)
        # Sorted-set members are unique, so two identical payloads would
        # collapse into one event; a random id keeps every event distinct.
        member = json.dumps({"id": uuid.uuid4().hex, "data": event_data})
        pipe = self.redis.pipeline()
        pipe.zadd(key, {member: now})
        pipe.zremrangebyscore(key, "-inf", now - self.EVENT_RETENTION_SECONDS)
        pipe.expire(key, self.EVENT_RETENTION_SECONDS)
        pipe.execute()
# Usage: point the store at a local Redis, write one customer's features,
# then read every one of them back.
config = FeatureConfig(
    redis_host="localhost",
    redis_port=6379,
    redis_db=0,
    default_ttl=3600,
)
feature_store = RealTimeFeatureStore(config)

_customer_123_features = {
    "lifetime_value": 1250.50,
    "total_purchases": 15,
    "avg_order_value": 83.37,
    "segment": "premium",
}
feature_store.set_features_batch("customer_123", _customer_123_features)

# Iterating a dict yields its keys in insertion order, i.e. the names written above.
features = feature_store.get_features_batch("customer_123", list(_customer_123_features))
print(features)
Feature Computation Pipeline
# feature_pipeline.py
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.io.parquetio import WriteToParquet
import pandas as pd
import numpy as np
from typing import Dict, List
from datetime import datetime, timedelta
class FeatureComputingTransform(beam.DoFn):
    """Beam DoFn that turns one entity's grouped events into a feature row.

    Args:
        compute_functions: mapping of output feature name to a function that
            receives the entity's events as a list and returns the value.
    """

    def __init__(self, compute_functions: Dict[str, callable]):
        super().__init__()
        self.compute_functions = compute_functions

    def process(self, element):
        """Yield ``{"entity_id": ..., <feature>: <value>, ...}`` for one group.

        ``element`` is an ``(entity_id, events)`` pair, as produced by a
        GroupBy upstream.
        """
        entity_id, events = element
        # Grouped values from Beam are an iterable, not necessarily a list: it
        # may not support len(), and its truthiness says nothing about
        # emptiness. Materialize once so every compute function sees the same
        # sized, re-iterable list.
        events = list(events)
        features = {name: compute(events) for name, compute in self.compute_functions.items()}
        yield {"entity_id": entity_id, **features}
class FeaturePipeline:
    """Batch Beam job that recomputes per-entity purchase features.

    Reads the last 30 days of events from BigQuery, groups them by
    ``entity_id``, computes one feature row per entity with
    FeatureComputingTransform, and writes the rows to Parquet.
    """

    def __init__(self, pipeline_options: PipelineOptions):
        self.options = pipeline_options

    @staticmethod
    def _amount(event: Dict) -> float:
        """Return the event's ``amount``, treating a missing key or NULL as 0."""
        # .get("amount", 0) alone still returns None when the column exists but
        # is NULL, which then breaks sum() and np.mean().
        return event.get("amount") or 0

    @staticmethod
    def _parse_timestamp(value) -> datetime:
        """Coerce an event ``timestamp`` value to a ``datetime``.

        NOTE(review): depending on the BigQuery read method, TIMESTAMP values
        may arrive as ``datetime`` objects or as strings such as
        ``"2024-01-01 12:00:00 UTC"``; both (and plain ISO-8601) are accepted.
        """
        if isinstance(value, datetime):
            return value
        text = str(value)
        if text.endswith(" UTC"):
            text = text[: -len(" UTC")] + "+00:00"
        elif text.endswith("Z"):
            text = text[:-1] + "+00:00"
        return datetime.fromisoformat(text)

    def create_feature_pipeline(self, input_path: str, output_path: str):
        """Build and run the feature pipeline; blocks until the run finishes.

        Args:
            input_path: BigQuery table with the raw events, e.g.
                ``"project.dataset.events"``. It is interpolated into the SQL
                query, so pass trusted configuration only.
            output_path: file path prefix for the Parquet output shards.
        """
        amount = self._amount
        parse_timestamp = self._parse_timestamp

        def days_since_last_purchase(events):
            if not events:
                return 999  # sentinel: no purchase to measure from
            latest = max(parse_timestamp(e["timestamp"]) for e in events)
            # Take "now" in the data's own timezone: subtracting an aware
            # timestamp from a naive local clock raises TypeError.
            return (datetime.now(latest.tzinfo) - latest).days

        compute_functions = {
            "total_purchases": lambda events: len(events),
            "avg_order_value": lambda events: np.mean([amount(e) for e in events]) if events else 0,
            "days_since_last_purchase": days_since_last_purchase,
            "lifetime_value": lambda events: sum(amount(e) for e in events),
        }

        # TIMESTAMP_SUB / CURRENT_TIMESTAMP are standard-SQL functions, and
        # ReadFromBigQuery runs legacy SQL unless use_standard_sql=True.
        query = (
            f"SELECT * FROM `{input_path}` "
            "WHERE timestamp > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 30 DAY)"
        )

        with beam.Pipeline(options=self.options) as pipeline:
            (
                pipeline
                | "Read Events" >> beam.io.ReadFromBigQuery(query=query, use_standard_sql=True)
                | "Group by Entity" >> beam.GroupBy(lambda x: x["entity_id"])
                | "Compute Features" >> beam.ParDo(FeatureComputingTransform(compute_functions))
                | "Write Features" >> WriteToParquet(
                    file_path_prefix=output_path,
                    # NOTE(review): WriteToParquet expects a pyarrow.Schema, not
                    # a BigQuery-style string; replace this with
                    # pyarrow.schema([...]) once pyarrow is imported here.
                    schema="entity_id:STRING, total_purchases:INTEGER, avg_order_value:FLOAT, days_since_last_purchase:INTEGER, lifetime_value:FLOAT"
                )
            )
# Feature validation
class FeatureValidator:
    """Holds per-feature predicate rules and reports which ones a row fails."""

    def __init__(self):
        # feature name -> list of {"rule": predicate, "description": message}
        self.validation_rules = {}

    def add_rule(self, feature_name: str, rule: callable, description: str):
        """Register ``rule`` (value -> truthy when valid) for ``feature_name``.

        Rules for the same feature are evaluated in the order they were added.
        """
        self.validation_rules.setdefault(feature_name, []).append({
            "rule": rule,
            "description": description,
        })

    def validate(self, features: Dict) -> Dict[str, List[str]]:
        """Return ``{feature_name: [descriptions of failed rules]}``.

        Only features present in ``features`` that have registered rules are
        checked. A rule that raises TypeError/ValueError/ArithmeticError (e.g.
        ``None >= 0`` for a missing online value) counts as a failure instead
        of aborting validation; its description gets the exception type
        appended. An empty dict means every checked rule passed.
        """
        errors: Dict[str, List[str]] = {}
        for feature_name, value in features.items():
            for rule_info in self.validation_rules.get(feature_name, ()):
                try:
                    passed = rule_info["rule"](value)
                except (TypeError, ValueError, ArithmeticError) as exc:
                    message = f"{rule_info['description']} (rule raised {type(exc).__name__})"
                else:
                    if passed:
                        continue
                    message = rule_info["description"]
                errors.setdefault(feature_name, []).append(message)
        return errors
# Example: one rule per feature, then validate a row whose lifetime_value is negative.
validator = FeatureValidator()
for _feature, _check, _message in (
    ("lifetime_value", lambda x: x >= 0, "Must be non-negative"),
    ("total_purchases", lambda x: x > 0, "Must be positive"),
    ("avg_order_value", lambda x: 0 < x < 100000, "Must be between 0 and 100k"),
):
    validator.add_rule(_feature, _check, _message)
del _feature, _check, _message

test_features = dict(
    lifetime_value=-10,
    total_purchases=5,
    avg_order_value=50.0,
)
validation_errors = validator.validate(test_features)
print("Validation errors:", validation_errors)
Follow-Up Questions
- How do you handle feature drift detection in a feature store?
- What are the trade-offs between batch and real-time feature computation?
- How would you implement feature versioning and rollback?
- What security considerations apply to feature stores containing PII?