Knowledge Graphs with LLMs
Knowledge Graph Fundamentals
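Knowledge graphs represent information as entities (nodes) and relationships (edges), enabling structured reasoning and graph-based queries that complement LLM capabilities: the graph supplies explicit, verifiable facts, while the LLM interprets questions and composes fluent answers from those facts.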
Building a Knowledge Graph
from typing import Dict, List, Tuple
from dataclasses import dataclass
import networkx as nx
@dataclass
class Entity:
    """A node in the knowledge graph.

    ``properties`` holds free-form extra attributes; ``KnowledgeGraph.add_entity``
    spreads them onto the graph node next to ``name`` and ``type``.
    """

    id: str  # unique key of the node inside the graph
    name: str  # human-readable label, e.g. "Python"
    type: str  # entity category, e.g. "Person"
    properties: Dict  # arbitrary extra attributes, e.g. {"paradigm": "multi"}
@dataclass
class Relation:
    """A directed, typed edge between two entity ids.

    ``properties`` holds free-form extra attributes; ``KnowledgeGraph.add_relation``
    spreads them onto the graph edge next to ``relation_type``.
    """

    source: str  # id of the entity the edge starts from
    target: str  # id of the entity the edge points to
    relation_type: str  # edge label, e.g. "CREATED"
    properties: Dict  # arbitrary extra attributes, e.g. {"year": 1991}
class KnowledgeGraph:
    """Directed knowledge graph backed by a ``networkx.DiGraph``.

    Entities become nodes keyed by ``Entity.id``; relations become directed
    ``source -> target`` edges. ``entity_index`` keeps the original ``Entity``
    objects so callers can recover them by id.
    """

    def __init__(self):
        self.graph = nx.DiGraph()
        self.entity_index: Dict[str, Entity] = {}

    def add_entity(self, entity: Entity):
        """Add (or update) the node for *entity* and index it by id.

        ``name``, ``type`` and every key of ``entity.properties`` are stored as
        node attributes.
        """
        self.entity_index[entity.id] = entity
        self.graph.add_node(
            entity.id,
            name=entity.name,
            type=entity.type,
            **entity.properties,
        )

    def add_relation(self, relation: Relation):
        """Add a directed ``source -> target`` edge labelled with the relation type.

        Endpoints that are not yet nodes are created implicitly by networkx; such
        nodes have no ``name``/``type`` attributes and no ``entity_index`` entry.
        NOTE: on a DiGraph a second relation between the same ordered pair
        updates the existing edge's attributes instead of adding a parallel edge.
        """
        self.graph.add_edge(
            relation.source,
            relation.target,
            relation_type=relation.relation_type,
            **relation.properties,
        )

    def get_entity_context(
        self, entity_id: str, depth: int = 2, undirected: bool = True
    ) -> Dict:
        """Return the neighbourhood of *entity_id* within *depth* hops.

        Args:
            entity_id: Node at the centre of the neighbourhood.
            depth: Maximum hop distance from *entity_id*.
            undirected: If True (default), follow edges in both directions when
                collecting neighbours, so incoming relations such as
                "X CREATED <entity>" are included. Edge direction is preserved
                in the returned relations either way.

        Returns:
            ``{}`` if *entity_id* is not a node. Otherwise a dict with
            ``"entity"`` (the indexed ``Entity``, or ``None`` for nodes created
            only through ``add_relation``), ``"neighbors"`` (ids of every node in
            the neighbourhood, including *entity_id* itself) and ``"relations"``
            (``source``/``target``/``type`` dicts for edges among those nodes).
        """
        if entity_id not in self.graph:
            return {}
        # ego_graph's hop limit is named ``radius`` (there is no ``depth``
        # keyword); on a directed graph it only follows out-edges unless
        # ``undirected=True``.
        subgraph = nx.ego_graph(
            self.graph, entity_id, radius=depth, undirected=undirected
        )
        return {
            "entity": self.entity_index.get(entity_id),
            "neighbors": list(subgraph.nodes()),
            "relations": [
                {
                    "source": u,
                    "target": v,
                    "type": data.get("relation_type"),
                }
                for u, v, data in subgraph.edges(data=True)
            ],
        }

    def find_path(self, source: str, target: str) -> List[str]:
        """Return the node ids along a shortest directed path from *source* to *target*.

        Returns ``[]`` when no directed path exists and also when either endpoint
        is not in the graph (networkx raises ``NodeNotFound`` for that case).
        """
        try:
            return nx.shortest_path(self.graph, source, target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []
# Hand-built sample graph: "Guido van Rossum" --CREATED--> "Python".
kg = KnowledgeGraph()
kg.add_entity(
    Entity(
        id="e1",
        name="Python",
        type="ProgrammingLanguage",
        properties={"paradigm": "multi"},
    )
)
kg.add_entity(
    Entity(
        id="e2",
        name="Guido van Rossum",
        type="Person",
        properties={"nationality": "Dutch"},
    )
)
kg.add_relation(
    Relation(
        source="e2",
        target="e1",
        relation_type="CREATED",
        properties={"year": 1991},
    )
)
Entity Extraction with LLMs
import openai
import json
from typing import List, Dict
class EntityExtractor:
    """Build a ``KnowledgeGraph`` from free text using an OpenAI chat model."""

    def __init__(self, api_key: str, model: str = "gpt-4o"):
        """Create the OpenAI client used for extraction.

        Args:
            api_key: Key passed straight to ``openai.OpenAI``.
            model: Chat model used for extraction. It must support JSON mode
                (``response_format={"type": "json_object"}``); the base
                ``gpt-4`` model does not, so it cannot be used here.
        """
        self.client = openai.OpenAI(api_key=api_key)
        self.model = model

    def extract_entities(self, text: str) -> Dict:
        """Ask the model for the entities and relations mentioned in *text*.

        Returns:
            The model reply parsed with ``json.loads``. The prompt requests
            ``"entities"`` and ``"relations"`` arrays, but the model is not forced
            to comply, so callers should treat every key as optional.

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": """Extract entities and relationships from text.
Return JSON with "entities" and "relations" arrays.
Entities: {"id": str, "name": str, "type": str}
Relations: {"source": str, "target": str, "type": str}"""},
                {"role": "user", "content": text},
            ],
            temperature=0,
            # JSON mode: asks the API to constrain the reply to a JSON object.
            response_format={"type": "json_object"},
        )
        return json.loads(response.choices[0].message.content)

    def extract_from_documents(self, documents: List[str]) -> KnowledgeGraph:
        """Extract every document and merge the results into a single graph.

        The model assigns ids independently on each call, so one id can name
        different entities in different documents (and one entity can get
        different ids). Entities are therefore merged across documents by
        case-insensitive name and given graph ids ``"e1"``, ``"e2"``, ... in
        first-seen order, and each document's relations are rewritten to those
        ids. Malformed entries, and relations whose endpoints were not extracted
        from the same document, are skipped.
        """
        kg = KnowledgeGraph()
        ids_by_name: Dict[str, str] = {}  # casefolded name -> graph id
        for doc in documents:
            result = self.extract_entities(doc)
            local_to_global = self._add_entities(
                kg, result.get("entities", []), ids_by_name
            )
            self._add_relations(kg, result.get("relations", []), local_to_global)
        return kg

    @staticmethod
    def _add_entities(
        kg: KnowledgeGraph, entities: List[Dict], ids_by_name: Dict[str, str]
    ) -> Dict[str, str]:
        """Add one document's entities to *kg*; return its local-id -> graph-id map."""
        local_to_global: Dict[str, str] = {}
        for entity_data in entities:
            if not isinstance(entity_data, dict) or not all(
                key in entity_data for key in ("id", "name", "type")
            ):
                continue  # malformed model output: skip instead of raising KeyError
            name_key = str(entity_data["name"]).strip().casefold()
            graph_id = ids_by_name.get(name_key)
            if graph_id is None:
                graph_id = f"e{len(ids_by_name) + 1}"
                ids_by_name[name_key] = graph_id
                kg.add_entity(
                    Entity(
                        id=graph_id,
                        name=entity_data["name"],
                        type=entity_data["type"],
                        properties={},
                    )
                )
            local_to_global[str(entity_data["id"])] = graph_id
        return local_to_global

    @staticmethod
    def _add_relations(
        kg: KnowledgeGraph, relations: List[Dict], local_to_global: Dict[str, str]
    ) -> None:
        """Add one document's relations to *kg*, translating ids via *local_to_global*."""
        for rel_data in relations:
            if not isinstance(rel_data, dict):
                continue
            source_local = rel_data.get("source")
            target_local = rel_data.get("target")
            rel_type = rel_data.get("type")
            if source_local is None or target_local is None or rel_type is None:
                continue  # malformed relation
            source = local_to_global.get(str(source_local))
            target = local_to_global.get(str(target_local))
            if source is None or target is None:
                # Dangling reference: adding it would create a nameless node.
                continue
            kg.add_relation(
                Relation(
                    source=source,
                    target=target,
                    relation_type=rel_type,
                    properties={},
                )
            )
import os

# Read the key from the environment rather than hard-coding a secret in source.
# (openai.OpenAI itself falls back to OPENAI_API_KEY when api_key is None.)
extractor = EntityExtractor(api_key=os.environ.get("OPENAI_API_KEY"))
kg = extractor.extract_from_documents([
    "Python was created by Guido van Rossum in 1991.",
    "Python is used for machine learning with libraries like TensorFlow.",
])
Graph RAG Implementation
class GraphRAG:
    """Answer questions with an LLM grounded in a knowledge-graph neighbourhood."""

    def __init__(self, knowledge_graph: KnowledgeGraph, llm_client):
        """Store the graph and a chat client.

        Args:
            knowledge_graph: Graph to retrieve context from.
            llm_client: Client exposing ``chat.completions.create`` (e.g.
                ``openai.OpenAI``).
        """
        self.kg = knowledge_graph
        self.client = llm_client

    def retrieve_context(
        self, query: str, max_hops: int = 2, max_relations: int = 5
    ) -> str:
        """Render the graph neighbourhood of every entity mentioned in *query*.

        Each matched entity contributes an ``Entity: <name>`` line followed by up
        to *max_relations* relation lines. Relation endpoints are rendered by
        entity name (not raw id) so the answering model can read them.
        """
        context_parts = []
        for entity_id in self._extract_query_entities(query):
            entity_context = self.kg.get_entity_context(entity_id, depth=max_hops)
            if not entity_context:
                continue  # entity not present as a node in the graph
            context_parts.append(f"Entity: {self._label(entity_id)}")
            for rel in entity_context["relations"][:max_relations]:
                context_parts.append(
                    f" - {self._label(rel['source'])} --[{rel['type']}]--> "
                    f"{self._label(rel['target'])}"
                )
        return "\n".join(context_parts)

    def _label(self, entity_id: str) -> str:
        """Return the indexed entity's name, or the raw id for unindexed nodes."""
        entity = self.kg.entity_index.get(entity_id)
        return entity.name if entity is not None else entity_id

    def _extract_query_entities(self, query: str) -> List[str]:
        """Ask the model which entity names *query* mentions and map them to graph ids.

        The model is asked for a comma-separated list. Names are compared after
        stripping whitespace and case-folding, so "python" or " Python" still
        match an entity named "Python".
        """
        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "Extract entity names from the query. "
                    "Reply with only the names, separated by commas.",
                },
                {"role": "user", "content": query},
            ],
            temperature=0,
        )
        content = response.choices[0].message.content or ""
        # Accept newline-separated replies too.
        wanted = {
            part.strip().casefold() for part in content.replace("\n", ",").split(",")
        }
        wanted.discard("")
        return [
            eid
            for eid, entity in self.kg.entity_index.items()
            if str(entity.name).strip().casefold() in wanted
        ]

    def answer_with_graph(self, query: str) -> str:
        """Answer *query* with the model, passing the retrieved graph context in the prompt."""
        graph_context = self.retrieve_context(query)
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": """Answer questions using the provided knowledge graph context.
Cite specific relationships from the graph."""},
                {"role": "user", "content": f"""Knowledge Graph Context:
{graph_context}
Question: {query}"""},
            ],
            temperature=0.3,
        )
        return response.choices[0].message.content
# openai.OpenAI() reads the OPENAI_API_KEY environment variable, so no secret
# needs to be written into the source.
rag = GraphRAG(kg, openai.OpenAI())
answer = rag.answer_with_graph("Who created Python and when?")
Graph Embeddings
import numpy as np
from typing import Dict
class GraphEmbeddings:
    """Per-node vector embeddings for a ``KnowledgeGraph``, learned with node2vec."""

    def __init__(self, knowledge_graph: KnowledgeGraph, embedding_dim: int = 128):
        self.kg = knowledge_graph
        self.dim = embedding_dim
        self.embeddings: Dict[str, np.ndarray] = {}  # node id -> vector

    def train_node2vec(
        self, walks_per_node: int = 10, walk_length: int = 20, window: int = 10
    ):
        """Train node2vec on the graph and store one vector per node.

        Requires the third-party ``node2vec`` package, imported lazily so the
        rest of the module works without it. Any vectors from a previous call
        are discarded, since they live in a different embedding space.
        """
        from node2vec import Node2Vec

        node2vec = Node2Vec(
            self.kg.graph,
            dimensions=self.dim,
            walk_length=walk_length,
            num_walks=walks_per_node,
        )
        model = node2vec.fit(window=window, min_count=1)
        self.embeddings.clear()
        for node in self.kg.graph.nodes():
            # node2vec stores walk tokens as strings, so look nodes up by str().
            self.embeddings[node] = model.wv[str(node)]

    def get_entity_embedding(self, entity_id: str) -> np.ndarray:
        """Return the trained vector for *entity_id*.

        Raises:
            ValueError: If no embedding exists for *entity_id*.
        """
        if entity_id not in self.embeddings:
            raise ValueError(f"Entity {entity_id} not found")
        return self.embeddings[entity_id]

    def find_similar_entities(self, entity_id: str, k: int = 5) -> List[Tuple[str, float]]:
        """Return up to *k* other entities ranked by cosine similarity, best first.

        Returns ``[]`` when *entity_id* has no embedding or ``k <= 0``. A pair
        involving a zero-norm vector scores 0.0 instead of NaN (NaN would make
        the ranking meaningless).
        """
        if k <= 0 or entity_id not in self.embeddings:
            return []
        target_embedding = self.embeddings[entity_id]
        target_norm = np.linalg.norm(target_embedding)
        similarities: List[Tuple[str, float]] = []
        for eid, embedding in self.embeddings.items():
            if eid == entity_id:
                continue
            denom = target_norm * np.linalg.norm(embedding)
            sim = float(np.dot(target_embedding, embedding) / denom) if denom > 0 else 0.0
            similarities.append((eid, sim))
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:k]
# Learn node2vec vectors for `kg`, then fetch the 3 entities closest to "e1".
embeddings = GraphEmbeddings(knowledge_graph=kg)
embeddings.train_node2vec()
similar = embeddings.find_similar_entities(entity_id="e1", k=3)
Best Practices
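- Design the ontology (the allowed entity types and relation types) carefully before building the graph
- Use identifiers that are unique and stable across sources, so the same real-world entity extracted from different documents maps to a single node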
- Implement incremental updates for dynamic knowledge
- Combine structured queries with LLM reasoning
- Validate newly extracted facts against what the graph already contains before inserting them
- Monitor graph quality and completeness