🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

Text-to-SQL

🟢 Free Lesson

Advertisement

Text-to-SQL

Pipeline diagram: NL Question (e.g. "Show top 5 customers by revenue this year") → Schema Link (match entities, detect columns) → LLM Generate (schema-aware code LLM) → SQL Query (abbreviated: `SELECT name, SUM(amount) FROM orders JOIN customers … GROUP BY name … LIMIT 5`) → Validate (syntax check, safety check) → Execute (run query, fetch results) → Results (e.g. Alice | $50,000; Bob | $45,000).

Pipeline: NL Question -> Schema Link -> Generate SQL -> Validate -> Execute -> Return Results

Text-to-SQL systems convert natural language questions into SQL queries, enabling non-technical users to query databases using plain language.

Schema Extractor

import re
import sqlite3
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class ColumnInfo:
    """Describe a single table column: its name, declared type and a note."""

    name: str               # column name
    type: str               # declared column type string, e.g. "INTEGER"
    description: str = ""   # optional free-text note; empty unless supplied

@dataclass
class TableInfo:
    """Describe a single table: its name, its columns and a note."""

    name: str                   # table name
    columns: List[ColumnInfo]   # column metadata, in the order supplied
    description: str = ""       # optional free-text note; empty unless supplied

class SchemaExtractor:
    """Read table and column metadata from a SQLite database.

    The schema is returned as ``TableInfo``/``ColumnInfo`` objects and can be
    rendered as plain text for an LLM prompt. The instance owns its
    connection: call :meth:`close` (or use it as a context manager) when done.
    """

    def __init__(self, db_path: str):
        """Open a connection to the SQLite database at *db_path*."""
        self.conn = sqlite3.connect(db_path)

    def close(self) -> None:
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self) -> "SchemaExtractor":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def get_schema(self) -> List[TableInfo]:
        """Return one ``TableInfo`` per user table, in ``sqlite_master`` order.

        SQLite's internal tables (names starting with ``sqlite_``, such as
        ``sqlite_sequence``) are skipped. They are not part of the user's data
        model and would only mislead the LLM. ``!`` escapes the ``_``
        wildcard so that a user table like ``sqlitedata`` is still included.
        """
        cursor = self.conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite!_%' ESCAPE '!'"
        )
        # fetchall() first so the PRAGMA queries below don't interleave with
        # this cursor.
        return [
            TableInfo(name=table_name, columns=self._get_columns(table_name))
            for (table_name,) in cursor.fetchall()
        ]

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    def _get_columns(self, table_name: str) -> List[ColumnInfo]:
        """Return the columns of *table_name* via ``PRAGMA table_info``."""
        # PRAGMA takes the table as an identifier, not a bindable value, so the
        # name is quoted. Unquoted names containing spaces or reserved words
        # (e.g. "order") would otherwise be a syntax error.
        cursor = self.conn.execute(
            f"PRAGMA table_info({self._quote_identifier(table_name)})"
        )
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return [ColumnInfo(name=row[1], type=row[2]) for row in cursor.fetchall()]

    def schema_to_text(self, tables: List[TableInfo]) -> str:
        """Render *tables* as ``Table: name`` headers with indented column lines."""
        lines = []
        for table in tables:
            lines.append(f"Table: {table.name}")
            for col in table.columns:
                lines.append(f"  - {col.name} ({col.type})")
            lines.append("")  # blank line between tables
        return "\n".join(lines)

    def schema_to_prompt(self, tables: List[TableInfo]) -> str:
        """Wrap :meth:`schema_to_text` output in a SQL-generation instruction."""
        return f"""Database Schema:
{self.schema_to_text(tables)}
Generate SQL for the given question based on this schema."""

# Usage: read the schema of database.db and render it as prompt text.
# NOTE(review): this runs at import time and opens database.db relative to the
# current working directory.
extractor = SchemaExtractor("database.db")
schema = extractor.get_schema()  # List[TableInfo]; also used by the validator usage below
schema_text = extractor.schema_to_text(schema)

Text-to-SQL Generator

from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate

class TextToSQLGenerator:
    """Turn natural-language questions into SQL with a LangChain LLM.

    Each public method renders a prompt, sends it through ``prompt | self.llm``
    and returns the model's reply with any Markdown code fence removed. The
    result can be handed straight to a database driver.
    """

    # An optional ``` or ```sql fence around the query. The closing fence may
    # be absent when the model's output was cut off.
    _FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE)

    def __init__(self, llm):
        """Store *llm*, which must be composable as ``prompt | llm``."""
        self.llm = llm
        self.schema_cache = {}  # not read by any method of this class

    @classmethod
    def _extract_sql(cls, text: str) -> str:
        """Return *text* stripped, with a surrounding Markdown code fence removed.

        Text before the opening fence (e.g. "Here is the query:") and after the
        closing fence is discarded. Text without any fence is returned as-is.
        """
        text = text.strip()
        match = cls._FENCE_RE.search(text)
        return match.group(1).strip() if match else text

    def _invoke(self, template: str, **variables: str) -> str:
        """Render *template* with *variables*, query the LLM and return clean SQL."""
        chain = PromptTemplate.from_template(template) | self.llm
        result = chain.invoke(variables)
        # Chat models return a message object exposing ``.content``;
        # completion-style LLMs return the text itself.
        return self._extract_sql(getattr(result, "content", result))

    def generate_sql(self, question: str, schema_text: str) -> str:
        """Generate a SQL query that answers *question* against *schema_text*."""
        return self._invoke(
            """You are a SQL expert. Generate a SQL query to answer the question.
            Use the following schema:
            {schema}

            Question: {question}
            SQL:""",
            schema=schema_text,
            question=question,
        )

    def generate_with_examples(self, question: str, schema_text: str,
                               examples: list) -> str:
        """Generate SQL for *question* using few-shot *examples*.

        Args:
            examples: dicts with ``"question"`` and ``"sql"`` keys, rendered
                into the prompt as ``Q: ...`` / ``SQL: ...`` pairs.
        """
        examples_text = "\n".join(
            f"Q: {ex['question']}\nSQL: {ex['sql']}" for ex in examples
        )
        return self._invoke(
            """You are a SQL expert. Generate SQL for the question.
            Schema: {schema}

            Examples:
            {examples}

            Question: {question}
            SQL:""",
            schema=schema_text,
            examples=examples_text,
            question=question,
        )

    def generate_with_correction(self, question: str, schema_text: str,
                                 error: str, failed_sql: str) -> str:
        """Ask the LLM to repair *failed_sql*, given the *error* it produced."""
        return self._invoke(
            """The previous SQL query failed. Fix the error.
            Schema: {schema}
            Question: {question}
            Failed SQL: {failed_sql}
            Error: {error}

            Correct SQL:""",
            schema=schema_text,
            question=question,
            failed_sql=failed_sql,
            error=error,
        )

# Usage
# Create the model used by the generator (and by the pipeline usage further
# down). Any LangChain chat model works; ChatOpenAI expects OPENAI_API_KEY to
# be set. temperature=0 keeps the generated SQL as repeatable as the model allows.
llm = ChatOpenAI(temperature=0)
generator = TextToSQLGenerator(llm)
sql = generator.generate_sql("Show top 5 customers by revenue", schema_text)

Query Validator

import re

class QueryValidator:
    """Run heuristic pre-execution checks on LLM-generated SQL.

    The checks are regex-based, not a SQL parser. Keywords are matched as whole
    words, so identifiers such as ``updated_at`` or ``is_deleted`` no longer
    look like ``UPDATE``/``DELETE``. Text inside string literals and comments
    is still scanned. Treat a pass as "nothing obviously wrong", not as proof
    that the query is safe.
    """

    # FROM/JOIN targets, optionally quoted or schema-qualified (``main.users``).
    _TABLE_REF_RE = re.compile(r'\b(?:FROM|JOIN)\s+["`\[]?(?:\w+\.)?(\w+)', re.IGNORECASE)
    # Names introduced by a CTE: ``WITH name AS (`` / ``, name AS (``.
    _CTE_NAME_RE = re.compile(r"\b(\w+)\s+AS\s*\(", re.IGNORECASE)

    def __init__(self):
        # Keywords of statements that modify data or schema. ATTACH, DETACH and
        # PRAGMA reach beyond a plain query in SQLite. REPLACE is deliberately
        # absent: it is also a scalar function usable in SELECT.
        self.dangerous_keywords = [
            "DROP", "DELETE", "TRUNCATE", "ALTER", "INSERT", "UPDATE",
            "CREATE", "ATTACH", "DETACH", "PRAGMA",
        ]
        self.read_only_tables = []  # not consulted by any check in this class

    @staticmethod
    def _has_keyword(sql: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *sql* as a whole word (any case)."""
        return re.search(rf"\b{re.escape(keyword)}\b", sql, re.IGNORECASE) is not None

    @staticmethod
    def _parens_balanced(sql: str) -> bool:
        """Return True if every ')' closes an earlier '(' and none stay open."""
        depth = 0
        for ch in sql:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def validate_syntax(self, sql: str) -> dict:
        """Check for SELECT ... FROM, balanced parentheses and write keywords.

        Returns a dict with ``valid_syntax``, ``is_read_only`` and ``has_where``.
        """
        has_select = self._has_keyword(sql, "SELECT")
        has_from = self._has_keyword(sql, "FROM")
        balanced = self._parens_balanced(sql)
        return {
            "valid_syntax": has_select and has_from and balanced,
            "is_read_only": not any(
                self._has_keyword(sql, kw) for kw in self.dangerous_keywords
            ),
            "has_where": self._has_keyword(sql, "WHERE"),
        }

    def check_safety(self, sql: str) -> dict:
        """Report dangerous keywords and multiple statements.

        Returns ``{"safe": bool, "issues": [str, ...]}``.
        """
        issues = [
            f"Dangerous keyword: {kw}"
            for kw in self.dangerous_keywords
            if self._has_keyword(sql, kw)
        ]
        # A single trailing ';' is harmless. Any other ';' separates a second
        # statement (e.g. "SELECT ...; DROP ..." has only one semicolon).
        if ";" in sql.strip().rstrip(";"):
            issues.append("Multiple statements detected")
        return {"safe": not issues, "issues": issues}

    def validate_against_schema(self, sql: str, schema: list) -> dict:
        """Check that every FROM/JOIN table exists in *schema* or is a CTE.

        *schema* items need a ``name`` attribute (e.g. ``TableInfo``). Names
        are compared case-insensitively; each unknown table is reported once,
        upper-cased.
        """
        known = {t.name.upper() for t in schema}
        known.update(name.upper() for name in self._CTE_NAME_RE.findall(sql))
        issues = []
        for ref in self._TABLE_REF_RE.findall(sql):
            issue = f"Unknown table: {ref.upper()}"
            if ref.upper() not in known and issue not in issues:
                issues.append(issue)
        return {"valid": not issues, "issues": issues}

    def full_validate(self, sql: str, schema: list) -> dict:
        """Run all checks; ``overall_valid`` is True only if every check passes."""
        syntax = self.validate_syntax(sql)
        safety = self.check_safety(sql)
        schema_valid = self.validate_against_schema(sql, schema)
        return {
            "syntax": syntax,
            "safety": safety,
            "schema": schema_valid,
            "overall_valid": syntax["valid_syntax"] and safety["safe"] and schema_valid["valid"]
        }

# Usage: run every validation check on one query. `schema` is the TableInfo
# list produced by the schema-extractor usage earlier in this file.
validator = QueryValidator()
result = validator.full_validate("SELECT * FROM users", schema)  # dict; see result["overall_valid"]

Text-to-SQL Pipeline

class TextToSQLPipeline:
    """Answer natural-language questions against one SQLite database.

    For each question the pipeline reads the schema, asks the LLM for SQL,
    validates the SQL and then executes it. The pipeline holds two connections
    (its own and the schema extractor's); call :meth:`close` or use it as a
    context manager when done.

    NOTE(review): queries run on a read-write connection and the validator is
    heuristic. Consider a read-only connection
    (``sqlite3.connect("file:<path>?mode=ro", uri=True)``) for untrusted users.
    """

    def __init__(self, db_path: str, llm):
        self.extractor = SchemaExtractor(db_path)
        self.generator = TextToSQLGenerator(llm)
        self.validator = QueryValidator()
        self.conn = sqlite3.connect(db_path)

    def close(self) -> None:
        """Close the pipeline's connection and the schema extractor's."""
        self.conn.close()
        self.extractor.conn.close()

    def __enter__(self) -> "TextToSQLPipeline":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def _load_schema(self):
        """Return ``(schema, schema_text)`` read fresh from the database."""
        schema = self.extractor.get_schema()
        return schema, self.extractor.schema_to_text(schema)

    def _execute(self, sql: str, schema) -> dict:
        """Validate *sql* and, if it passes, run it and collect the rows.

        Returns ``{"sql", "columns", "results", "row_count"}`` on success.
        Otherwise returns a dict with an ``"error"`` key and ``"sql"``, plus
        ``"validation"`` when the query was rejected before execution.
        """
        validation = self.validator.full_validate(sql, schema)
        if not validation["overall_valid"]:
            return {"error": "Invalid query", "validation": validation, "sql": sql}
        try:
            cursor = self.conn.execute(sql)
            results = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as e:
            # sqlite3.Warning is not an Error subclass. Older Pythons raise it
            # for strings containing more than one statement.
            return {"error": str(e), "sql": sql}
        # description is None for statements that produce no result set.
        columns = [desc[0] for desc in cursor.description or ()]
        return {
            "sql": sql,
            "columns": columns,
            "results": results,
            "row_count": len(results)
        }

    @staticmethod
    def _failure_message(result: dict) -> str:
        """Summarize a failed attempt as the error text for the correction prompt."""
        validation = result.get("validation")
        if validation is None:
            return result["error"]
        issues = list(validation["safety"]["issues"]) + list(validation["schema"]["issues"])
        if not validation["syntax"]["valid_syntax"]:
            issues.insert(0, "Query failed the basic syntax check")
        return "; ".join(issues) or result["error"]

    def run(self, question: str) -> dict:
        """Answer *question* with a single generate, validate and execute pass.

        On an execution error the result also carries ``"corrected_sql"``. This
        is a fix proposed by the LLM; it is not executed.
        """
        schema, schema_text = self._load_schema()
        sql = self.generator.generate_sql(question, schema_text)
        result = self._execute(sql, schema)
        if "error" in result and "validation" not in result:
            result["corrected_sql"] = self.generator.generate_with_correction(
                question, schema_text, result["error"], sql
            )
        return result

    def run_with_retry(self, question: str, max_retries: int = 3) -> dict:
        """Answer *question*, feeding each failure back to the LLM for a fix.

        The first attempt runs freshly generated SQL. Each later attempt runs
        the correction the LLM proposed for the previous failure (validation
        issues or the database error). At most ``max(1, max_retries)`` queries
        are executed. Returns the first successful result; otherwise returns
        the last failure, shaped like :meth:`run`'s.
        """
        schema, schema_text = self._load_schema()
        sql = self.generator.generate_sql(question, schema_text)
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            result = self._execute(sql, schema)
            if "error" not in result:
                return result
            if attempt < attempts - 1:
                # Execute the LLM's fix for this failure on the next attempt.
                sql = self.generator.generate_with_correction(
                    question, schema_text, self._failure_message(result), sql
                )
        # All attempts failed. Like run(), attach a proposed fix for an
        # execution error; rejected queries get none.
        if "validation" not in result:
            result["corrected_sql"] = self.generator.generate_with_correction(
                question, schema_text, result["error"], sql
            )
        return result

# Usage: answer one question end to end. Requires `llm` (a LangChain chat
# model) to be bound earlier in the module.
pipeline = TextToSQLPipeline("database.db", llm)
result = pipeline.run("Show me the top 5 customers by total orders")  # single pass; run_with_retry retries

Key Takeaways

  • Schema linking maps natural language to database entities
  • Few-shot examples dramatically improve SQL generation quality
  • Validation screens generated queries for obvious safety and syntax problems before execution (heuristic checks, not a full SQL parser)
  • Error correction iteratively fixes failed queries
  • Context-aware generation considers database schema and relationships
⭐

Premium Content

Text-to-SQL

Unlock this lesson and 900+ advanced tutorials with a Premium plan.

🎯 End-to-end Projects
💼 Interview Prep
📜 Certificates
🤝 Community Access

Already a member? Log in

Need Expert Generative AI Help?

Get personalized tutoring, project support, or professional consulting.

Advertisement