Build a text-to-SQL generator

Text-to-SQL is an application of Large Language Models (LLMs) that converts natural language queries into SQL. You can use text-to-SQL to query your databases even if you don’t know SQL.

This tutorial demonstrates how to generate SQL with LLMs and implement dynamic few-shot prompting with Astra DB Serverless. Specifically, you will do the following:

  1. Download a dataset containing questions and their corresponding answering SQL statements.

  2. Populate a relational (SQL) database with sample data to run the queries.

  3. Learn how to query an LLM so that it generates SQL when given a natural-language question.

  4. Use dynamic few-shot prompting and a vector database to provide targeted context for a given query.

  5. Quantify the effectiveness of the LLM by comparing the SQL generated for new questions against the curated example data.

For more information about the prompting techniques and other concepts demonstrated in this tutorial, see this tutorial’s Colab notebook, which is also available as a static web page.

About dynamic few-shot prompting

When prompting LLMs to generate an answer, you can improve the quality of the response by providing example question-answer pairs together with the actual question. This technique is known as few-shot prompting.

Dynamic few-shot prompting is a variation of few-shot prompting where the chosen question-answer examples vary based on the question being asked. This technique compares a new query with a set of pre-defined question-answer examples, selects examples that are most similar to the query, and then uses those examples as additional context for the generated response.

This increases the effectiveness of few-shot prompting by providing the LLM with examples that are most relevant to the query.

Vector search is critical to finding similar examples as dynamically and quickly as possible. To use vector search with dynamic few-shot prompting, you generate vector embeddings for the example data, and then store that data in a vector database. This database is referred to as the prompt store. Your application can then run a vector search on your prompt store to find the most similar examples for any new query.
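
For intuition, the selection step can be sketched without any database at all. The following minimal sketch uses an embed() function and an in-memory examples list as placeholders; the rest of this tutorial replaces them with an OpenAI embedding model and an Astra DB collection.

    from math import sqrt

    # Minimal, database-free sketch of dynamic example selection.
    # `embed` is any function that maps text to an embedding vector (list of
    # floats); `examples` is a list of dicts with "question", "query", and a
    # precomputed "embedding".
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (sqrt(sum(x * x for x in a)) * sqrt(sum(y * y for y in b)))


    def select_examples(query_text, examples, embed, k=2):
        query_vector = embed(query_text)
        ranked = sorted(
            examples,
            key=lambda example: cosine(query_vector, example["embedding"]),
            reverse=True,
        )
        # The k most similar example pairs become the few-shot context.
        return ranked[:k]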

Using an Astra DB Serverless database as a dynamic few-shot prompt store is a robust and low-effort way to improve your text-to-SQL applications in production.

Prerequisites

To complete this tutorial, you need the following:

  • An Astra account.

  • An active Serverless (Vector) database.

  • An application token with the Database Administrator role and your database’s API endpoint. For more information, see Manage application tokens.

  • A paid OpenAI account and API key.

  • Python version 3.9 or later, with the following packages installed:

    pip install \
      "openai>=1.0,<2.0" \
      "astrapy==2.0.0-rc1" \
      "datasets==3.*" \
      "tenacity==9.*"

    This tutorial is written for the latest version of the Data API Python client, astrapy 2.0.0-rc1, which is a pre-release version. For more information, see Data API client upgrade guide.

Initialize the clients

Prepare your environment, initialize the OpenAI and Astra DB clients, and prepare the utility to run arbitrary SQL on a specified database.

  1. Import the required dependencies:

    import json
    import os
    import re
    import sqlite3
    from functools import lru_cache, partial
    from getpass import getpass
    from typing import Any, Callable, Dict, List, Tuple
    
    import openai
    import pandas as pd
    from astrapy import DataAPIClient
    from astrapy.api_options import APIOptions, SerdesOptions
    from astrapy.info import CollectionDefinition
    from datasets import load_dataset
    from tenacity import retry, wait_exponential
    from tqdm.auto import tqdm
  2. Define the model, prompt store, and sample database settings:

    # LLM and embedding model settings.
    # The chosen LLM: intentionally not the latest generation, for illustrative purposes.
    LLM_MODEL_NAME = "gpt-3.5-turbo-0125"
    
    EMBEDDING_MODEL = "text-embedding-3-small"
    EMBEDDING_DIMENSION = 1536
    
    # Collection for storing the examples for few-shot prompting
    ASTRA_DB_COLLECTION_NAME = "text2sql_examples"
    
    # This is needed for the SQLite database
    SQLITE_FILE_NAME = "sample_database.db"

    This tutorial uses the following settings:

    • LLM_MODEL_NAME: The LLM is an OpenAI GPT-3.5 model, chosen because it is less expensive and less powerful than more recent models. With a less powerful model, this tutorial can better demonstrate the accuracy gains from dynamic few-shot prompting over other techniques. For examples of zero-shot and fixed few-shot prompting, see this tutorial’s Colab notebook, which is also available as a static web page.

    • EMBEDDING_MODEL and EMBEDDING_DIMENSION: The embedding model, which is also from OpenAI, is used to calculate vector embeddings for the example questions.

    • ASTRA_DB_COLLECTION_NAME: A collection in an Astra DB database that serves as the prompt store.

    • SQLITE_FILE_NAME: For the relational (SQL) database, this tutorial uses a local SQLite file for simplicity.

      If desired, you can use different models and SQL databases with little impact on the implementation of the SQL generation logic. For an illustration of swapping the SQL backend, see the sketch after the execute_sql utility at the end of this section.

  3. Initialize the OpenAI client with your API key:

    if "OPENAI_API_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key: ")
    
    client = openai.OpenAI()
    
    # We equip the LLM and embedding calls with exponential backoff and retry
    # in case the service is throttled
    
    @retry(wait=wait_exponential(multiplier=1.2, min=4, max=30))
    def llm_completion(prompt: str) -> str:
        response = client.chat.completions.create(
            messages=[{
                "role": "user",
                "content": prompt,
            }],
            model=LLM_MODEL_NAME,
        ).choices[0].message.content
        return response
    
    
    @retry(wait=wait_exponential(multiplier=1.2, min=4, max=30))
    def compute_embedding_list(texts: List[str]) -> List[List[float]]:
        return [
            data.embedding
            for data in client.embeddings.create(
                input=texts,
                model=EMBEDDING_MODEL,
                timeout=10,
            ).data
        ]
    
    
    def compute_embedding(text: str) -> List[float]:
        return compute_embedding_list([text])[0]
    Response
    OpenAI API Key:  ········
  4. Initialize the Data API client with your Astra DB API endpoint and application token:

    ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
    if not ASTRA_DB_API_ENDPOINT:
        ASTRA_DB_API_ENDPOINT = input("Astra DB API Endpoint: ")
    ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
    if not ASTRA_DB_APPLICATION_TOKEN:
        ASTRA_DB_APPLICATION_TOKEN = getpass("Astra DB Token: ")
    
    data_api_client = DataAPIClient()
    astra_db = data_api_client.get_database(
        ASTRA_DB_API_ENDPOINT,
        token=ASTRA_DB_APPLICATION_TOKEN,
    )
    Response
    Astra DB API Endpoint:  https://...apps.astra.datastax.com
    Astra DB Token:  ········
  5. Initialize the interface to the SQL database:

    def execute_sql(
        *statements: str, raise_on_error: bool = True
    ) -> List[Tuple[Any, ...]]:
        """
        Utility to execute DB SQL statements and return the result
        of the final query.
        """
        with sqlite3.connect(SQLITE_FILE_NAME) as connection:
            cursor = connection.cursor()
            try:
                for statement in statements:
                    cursor.execute(statement)
            except sqlite3.OperationalError as e:
                # syntax errors or similar in running the query
                if raise_on_error:
                    raise
                return []
            try:
                result = cursor.fetchall()
            except sqlite3.OperationalError as e:
                # No result set (final statement was not a SELECT)
                result = []
    
            return result
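
    The rest of this tutorial interacts with the relational database only through this utility, so switching to a different SQL database mostly means replacing this one function. As a minimal sketch, an equivalent utility backed by PostgreSQL through the psycopg driver might look like the following (the DATABASE_URL environment variable is a hypothetical placeholder and is not required for this tutorial):

    import os
    from typing import Any, List, Tuple

    import psycopg  # assumed to be installed separately; not used elsewhere in this tutorial


    def execute_sql_postgres(
        *statements: str, raise_on_error: bool = True
    ) -> List[Tuple[Any, ...]]:
        """Run SQL statements on PostgreSQL and return the final result set."""
        with psycopg.connect(os.environ["DATABASE_URL"]) as connection:
            with connection.cursor() as cursor:
                try:
                    for statement in statements:
                        cursor.execute(statement)
                except psycopg.Error:
                    # Syntax errors or similar in running the query
                    if raise_on_error:
                        raise
                    return []
                try:
                    return cursor.fetchall()
                except psycopg.ProgrammingError:
                    # The final statement did not produce a result set.
                    return []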

Load the Spider dataset

This tutorial uses the Spider dataset, an established benchmark for evaluating generated SQL. The dataset consists of question-query pairs that map each natural-language question to the ideal SQL query it should produce. This tutorial refers to these pairs as example pairs.

In a production application, you can collect additional sample data by storing your application’s generated SQL in a live environment, evaluating the quality of the generated queries, and then adding approved examples to your prompt store.

For example, you can provide a simple feedback interface for users to vote on the quality of the generated queries with fixed responses like yes/no, helpful/unhelpful, or thumbs-up/thumbs-down. Alternatively, you can have internal human evaluators grade the generated queries.

Approved queries can be added to your prompt store as example pairs for dynamic few-shot prompting. As the prompt store grows, it provides more context to the LLM, which can increase the efficacy of your SQL-generating application.
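
As a minimal sketch of that workflow, adding a single approved example pair to the prompt store could look like the following. This assumes the astra_db_collection prompt store created later in this tutorial and the compute_embedding helper defined above; the values passed in are hypothetical.

    # Hypothetical helper: store one approved question-query pair so that
    # future vector searches can retrieve it as few-shot context.
    def add_approved_example(question: str, sql: str, db_id: str) -> None:
        astra_db_collection.insert_one({
            "db_id": db_id,
            "question": question,
            "query": sql,
            "$vector": compute_embedding(question),
        })

    # Example usage, after a reviewer approves a generated query:
    # add_approved_example(
    #     "How many stadiums are there?",
    #     "SELECT count(*) FROM stadium",
    #     "concert_singer",
    # )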

  1. Use the datasets package to load the Spider dataset and schema:

    spider = load_dataset("spider", split="validation")
    spider_df = spider.to_pandas()
    spider_schema = load_dataset("richardr1126/spider-schema", split="train")
    spider_schema_df = spider_schema.to_pandas()
  2. Show the first few rows of the dataset:

    spider_df.head(3)
    Response

    For readability, some columns were removed from the following example response.

    |   | db_id          | query                                                 | question                                                 | query_toks                                       |
    |---|----------------|-------------------------------------------------------|----------------------------------------------------------|--------------------------------------------------|
    | 0 | concert_singer | SELECT count(*) FROM singer                           | How many singers do we have?                             | [SELECT, count, (, *, ), FROM, singer]           |
    | 1 | concert_singer | SELECT count(*) FROM singer                           | What is the total number of singers?                     | [SELECT, count, (, *, ), FROM, singer]           |
    | 2 | concert_singer | SELECT name , country , age FROM singer ORDER BY name | Show name, country, age for all singers ordered by name | [SELECT, name, ,, country, ,, age, FROM, singe… |

  3. Set aside a few questions for testing:

    # picking some random questions for testing.
    test_indices = [25, 35, 45, 55, 65, 75]
    
    test_df = spider_df.loc[test_indices]
    test_queries = set(test_df["query"])
    
    # Perform comparisons up to spacing and upper/lowercase differences:
    norm_test_queries = {qr.replace(" ", "").lower() for qr in test_queries}
    
    def _is_test_query(qr: str) -> bool:
        return qr.replace(" ", "").lower() in norm_test_queries
    
    idx_to_remove = spider_df[spider_df["query"].apply(_is_test_query)].index
    
    examples_df = spider_df.drop(idx_to_remove)
    
    print(
        f"Total {len(spider_df)} rows loaded.\n"
        f"Rows kept aside for testing: {len(test_queries)}\n"
        f"Rows to use as examples: {len(examples_df)}.\n"
        f"  ({len(idx_to_remove) - len(test_queries)} further rows pruned from examples)"
    )

    These questions come from two schemas in the Spider dataset. One schema contains tables about singers and where they’ve performed, and the other schema contains tables about pets and their owners.

    Some rows are set aside for testing, and the remaining rows serve as candidate examples for dynamic few-shot prompting later in the tutorial.

    To avoid data leakage and run a fair test, the code prunes additional rows with the same queries as the designated test rows. It is a best practice to ensure that test queries are neither duplicated across the test rows, nor used in later sample rows for dynamic few-shot prompting.

    Response
    Total 1034 rows loaded.
    Rows kept aside for testing: 6
    Rows to use as examples: 1023.
      (5 further rows pruned from examples)

Set up a SQL database

Store structured data matching the Spider dataset in a SQL-compatible database.

  1. Set up the concert_singer database by creating tables matching the Spider schema, and then inserting sample data for each test question:

    # Set up concert_singer DB
    CREATE_TABLES_SQL = """
    -- Creating the stadium table
    CREATE TABLE stadium (
        Stadium_ID INT PRIMARY KEY,
        Location TEXT,
        Name TEXT,
        Capacity INT,
        Highest INT,
        Lowest INT,
        Average INT
    );
    
    -- Creating the singer table
    CREATE TABLE singer (
        Singer_ID INT PRIMARY KEY,
        Name TEXT,
        Country TEXT,
        Song_Name TEXT,
        Song_release_year TEXT,
        Age INT,
        Is_male BOOLEAN
    );
    
    -- Creating the concert table
    CREATE TABLE concert (
        concert_ID INT PRIMARY KEY,
        concert_Name TEXT,
        Theme TEXT,
        Stadium_ID INT,
        Year TEXT,
        FOREIGN KEY (Stadium_ID) REFERENCES stadium(Stadium_ID)
    );
    
    -- Creating the singer_in_concert table
    CREATE TABLE singer_in_concert (
        concert_ID INT,
        Singer_ID INT,
        PRIMARY KEY (concert_ID, Singer_ID),
        FOREIGN KEY (concert_ID) REFERENCES concert(concert_ID),
        FOREIGN KEY (Singer_ID) REFERENCES singer(Singer_ID)
    );
    """
    
    POPULATE_DATA_SQL = """
    -- Populating the stadium table
    INSERT INTO stadium (
        Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average
    ) VALUES
    (1, 'New York, USA', 'Liberty Stadium', 50000, 1000, 500, 750),
    (2, 'London, UK', 'Royal Arena', 60000, 1500, 600, 900),
    (3, 'Tokyo, Japan', 'Sunshine Dome', 55000, 1200, 550, 800),
    (4, 'Sydney, Australia', 'Ocean Field', 40000, 900, 400, 650),
    (5, 'Berlin, Germany', 'Eagle Grounds', 45000, 1100, 450, 700);
    
    -- Populating the singer table
    INSERT INTO singer (
        Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male
    ) VALUES
    (1, 'John Doe', 'USA', 'Freedom Song', '2018', 28, TRUE),
    (2, 'Emma Stone', 'UK', 'Rolling Hills', '2019', 25, FALSE),
    (3, 'Haruki Tanaka', 'Japan', 'Tokyo Lights', '2020', 30, TRUE),
    (4, 'Alice Johnson', 'Australia', 'Ocean Waves', '2021', 27, FALSE),
    (5, 'Max Müller', 'Germany', 'Berlin Nights', '2017', 32, TRUE);
    
    -- Populating the concert table
    INSERT INTO concert (concert_ID, concert_Name, Theme, Stadium_ID, Year) VALUES
    (1, 'Freedom Fest', 'Pop', 1, '2021'),
    (2, 'Rock Mania', 'Rock', 2, '2022'),
    (3, 'Electronic Waves', 'Electronic', 3, '2020'),
    (4, 'Jazz Evenings', 'Jazz', 3, '2019'),
    (5, 'Classical Mornings', 'Classical', 5, '2023');
    
    -- Populating the singer_in_concert table
    INSERT INTO singer_in_concert (concert_ID, Singer_ID) VALUES
    (1, 1),
    (1, 2),
    (2, 3),
    (3, 4),
    (4, 5),
    (5, 1),
    (2, 2),
    (3, 3),
    (4, 4),
    (5, 5);
    """
    
    statements = [
        statement.strip() for statement in (
            CREATE_TABLES_SQL + POPULATE_DATA_SQL
        ).split(";")
        if len(statement.strip()) > 0
    ]
    execute_sql(*statements)
  2. Set up the pets_1 database:

    # Set up pets_1 DB
    CREATE_TABLES_SQL = """
    -- Creating the Student table
    CREATE TABLE Student (
        StuID INT PRIMARY KEY,
        LName VARCHAR(255),
        Fname VARCHAR(255),
        Age INT,
        Sex VARCHAR(10),
        Major INT,
        Advisor INT,
        city_code VARCHAR(50)
    );
    
    -- Creating the Pets table
    CREATE TABLE Pets (
        PetID INT PRIMARY KEY,
        PetType VARCHAR(255),
        pet_age INT,
        weight INT
    );
    
    -- Creating the Has_Pet table
    CREATE TABLE Has_Pet (
        StuID INT,
        PetID INT,
        FOREIGN KEY (StuID) REFERENCES Student(StuID),
        FOREIGN KEY (PetID) REFERENCES Pets(PetID)
    );
    """
    
    POPULATE_DATA_SQL = """
    -- Populating the Student table
    INSERT INTO Student (
        StuID, LName, Fname, Age, Sex, Major, Advisor, city_code
    ) VALUES
    (101, 'Smith', 'John', 20, 'M', 501, 301, 'NYC'),
    (102, 'Johnson', 'Emma', 22, 'F', 502, 302, 'LAX'),
    (103, 'Williams', 'Michael', 21, 'M', 503, 303, 'CHI'),
    (104, 'Brown', 'Sarah', 23, 'F', 504, 304, 'HOU'),
    (105, 'Jones', 'David', 19, 'M', 505, 305, 'PHI');
    
    -- Populating the Pets table
    INSERT INTO Pets (PetID, PetType, pet_age, weight) VALUES
    (201, 'dog', 3, 20.5),
    (202, 'cat', 5, 10.2),
    (203, 'dog', 2, 8.1),
    (204, 'parrot', 4, 0.5),
    (205, 'hamster', 1, 0.7);
    
    -- Populating the Has_Pet table
    INSERT INTO Has_Pet (StuID, PetID) VALUES
    (101, 201),
    (101, 202),
    (105, 203),
    (103, 204),
    (104, 205),
    (105, 201);
    """
    
    statements = [
        statement.strip() for statement in (
            CREATE_TABLES_SQL + POPULATE_DATA_SQL
        ).split(";")
        if len(statement.strip()) > 0
    ]
    execute_sql(*statements)
  3. Define an evaluation function to assess the generated queries' correctness.

    The following function runs the Example SQL counterpart alongside each of the generated queries, and then compares the results to determine whether they return the same data. To account for natural variations in ordering, if the only difference between results is their ordering, then the evaluator treats those results as the same.

    The function returns a short report in the form of a table. Ideally, you want to maximize the accuracy, which is the percentage of generated queries that return the same data as the corresponding Example SQL.

    In a production application, you must measure the accuracy over a much larger number of test questions than shown in this tutorial.

    def eval_generated_queries(generated_queries: Dict[int, str]) -> pd.DataFrame:
        """
        Evaluate the given queries against the test set,
        and return a report of the performance on each row.
        """
        report = []
    
        for tq_index, tq_generated_sql in sorted(generated_queries.items()):
            query_row = test_df.loc[tq_index]
            example_sql = query_row["query"]
            example_results = execute_sql(example_sql)
    
            try:
                gen_results = execute_sql(tq_generated_sql)
                error = None
    
                # Figure out correctness, not super straightforward
                if len(example_results) == len(gen_results):
                    # (note: here and in the next sorting,
                    #  we ignore "sorting collisions" for simplicity...)
                    in_sorted_gen = [sorted(tpl, key=str) for tpl in gen_results]
                    in_sorted_example = [sorted(tpl, key=str) for tpl in example_results]
                    if "ORDER BY" in example_sql.upper():
                        # ordering within tuples must match
                        correct = in_sorted_gen == in_sorted_example
                    else:
                        # sort sequence of tuples before comparison
                        correct = sorted(in_sorted_gen) == sorted(in_sorted_example)
                else:
                    correct = False
    
            except Exception as e:
                gen_results = None
                error = e
                correct = False
    
            report.append({
                "DB ID": query_row["db_id"],
                "Question": query_row["question"],
                "Correct": correct,
                "Error": error,
            })
    
        return pd.DataFrame(report)

Generate SQL queries

Zero-shot prompting and fixed few-shot prompting are good starting points for a text-to-SQL implementation. However, you can boost the accuracy of your application’s generated SQL by using dynamic few-shot prompting.

For a detailed comparison of dynamic few-shot prompting with zero-shot and fixed few-shot prompting, see this tutorial’s Colab notebook, which is also available as a static web page.

Dynamic few-shot prompting compares a new query with a set of pre-defined question-answer examples, selects examples that are most similar to the query, and then inserts those examples as additional context for the SQL generation prompt.

This tutorial uses vector search to find examples as dynamically and quickly as possible. To use vector search with dynamic few-shot prompting, you must generate vector embeddings for your example pairs, and then store the example pairs and their embeddings in a vector database, also referred to as the prompt store. Your application can then run a vector search on your prompt store to find the most similar examples for any new query.

  1. Create a collection in a Serverless (Vector) database to store the Spider dataset and its embeddings:

    astra_db_collection = astra_db.create_collection(
        ASTRA_DB_COLLECTION_NAME,
        definition=(
            CollectionDefinition.builder()
            .set_vector_dimension(EMBEDDING_DIMENSION)
            .build()
        ),
        # this allows documents containing `ndarray` objects to be inserted:
        spawn_api_options=APIOptions(
            serdes_options=SerdesOptions(
                unroll_iterables_to_lists=True,
            ),
        ),
    )
  2. Generate embeddings for the questions in the example dataset:

    # Batching these requests speeds the calculation up by a factor of about 13
    EMBEDDING_BATCH_SIZE = 40
    
    tqdm.pandas()
    
    examples_embeddings = (
        examples_df
        .groupby(examples_df.index // EMBEDDING_BATCH_SIZE)["question"]
        .progress_transform(lambda questions: compute_embedding_list(questions.tolist()))
    )
    examples_df["question_embedding"] = examples_embeddings
    Response
    100%|██████████| 26/26 [00:29<00:00,  1.24it/s]
  3. Store the example dataset and the associated embeddings in your Serverless (Vector) database:

    # The documents' _id in the collection will be the indices in the DataFrame:
    examples_df["_id"] = examples_df.index
    # The embeddings must be placed in the special "$vector" key:
    examples_documents = examples_df.rename(
        columns={"question_embedding": "$vector"}
    ).to_dict(orient="records")
    
    result = astra_db_collection.insert_many(
        examples_documents,
        # these are ~1000 big documents, so let's be overzealous with timeouts:
        timeout_ms=45000,          # for the whole insertion (which is chunked) ...
        request_timeout_ms=20000,  # and this is for each HTTP request (one chunk).
    )
    print(result)
    Response
    CollectionInsertManyResult(inserted_ids=[0, 1, 2, 3, 4 ... (1023 total)], raw_results=...)
  4. Define a few useful prompt templates:

    SQL_PROMPT_TEMPLATE = """Convert text to SQL:
    
    [Schema : (values)]: {schema_str}
    
    [Column names (type)]: {column_str}
    
    [Primary Keys]: {pk_str}
    
    [Foreign Keys]: {fk_str}
    
    [Q]: {question}
    
    [SQL]: """
    
    QUESTION_PREFIX_STR = (
        "Given the following schema information, generate valid SQL "
        "to answer the provided query. Enclose the query in markdown "
        "code-block syntax.\n"
    )
    
    EXAMPLE_PREFIX_STR = "Here is an example: "
  5. Prepare formatting utilities used by the few-shot prompting function:

    @lru_cache
    def _get_spider_schema_by_db_id(db_id: str) -> pd.Series:
        return spider_schema_df[spider_schema_df["db_id"] == db_id].iloc[0]
    
    
    def _format_schema(
        db_id: str, spider_schema_row: str
    ) -> Tuple[str, str, str, str]:
        """
        Converts the existing schema format of spider_schema_df to
        (schema, columns, primary keys, foreign keys) as used by SQL-PaLM.
        """
        schema_str = f"| {db_id} "
        column_str = ""
    
        spider_schema_vt_str = spider_schema_row["Schema (values (type))"]
    
        for table_name in re.findall(r"([^,: ]*) :", spider_schema_vt_str):
            schema_str += f"| {table_name} : "
    
            start_ndx = spider_schema_vt_str.find(table_name + " :")
            end_ndx = spider_schema_vt_str.find(
                ":", start_ndx + len(table_name) + 4
            )
            if end_ndx == -1:
                end_ndx = len(spider_schema_vt_str)
            current_substr = spider_schema_vt_str[start_ndx:end_ndx]
            for col_name, col_type in re.findall(
                r" ([^:,]*) \(([^,]*)\)", current_substr
            ):
                schema_str += f"{col_name} , "
                column_str += f"{table_name} : {col_name} ({col_type}) | "
    
            schema_str = schema_str[:-2]
        column_str = column_str[:-2]
    
        return (
            schema_str + ";",
            column_str + ";",
            spider_schema_row["Primary Keys"],
            spider_schema_row["Foreign Keys"],
        )
    
    
    def _format_sql_prompt(db_id: str, question: str) -> str:
        """
        Returns a formatted section of the prompt describing the DB Schema.
        This core logic is factored for later re-use in the few-shot approach.
        """
        spider_schema_row = _get_spider_schema_by_db_id(db_id)
        schema_str, column_str, pk_str, fk_str = _format_schema(
            db_id, spider_schema_row
        )
        sql_prompt_str = SQL_PROMPT_TEMPLATE.format(
            schema_str=schema_str,
            column_str=column_str,
            pk_str=pk_str,
            fk_str=fk_str,
            question=question,
        )
        return sql_prompt_str
    
    
    def _format_question_prompt(db_id: str, question: str) -> str:
        core_prompt = _format_sql_prompt(db_id, question)
        return QUESTION_PREFIX_STR + core_prompt
    
    
    def _format_example_prompt(db_id: str, question: str, example_sql: str) -> str:
        core_prompt = _format_sql_prompt(db_id, question)
        return EXAMPLE_PREFIX_STR + core_prompt + example_sql + "\n\n"
    
    
    def _clean_sql_response(response: str) -> str:
        _response2 = response.replace("```sql", "```")
        _start = _response2.find("```")
        if _start < 0:
            # assume the whole response is the SQL, no markdown stuff
            # (gpt-3.5 tends to do this)
            return "\n".join(l.strip() for l in _response2.split("\n") if l.strip())
        rest = _response2[_start+3:]
        _stop = rest.find("```")
        if _stop < 0:
            raise ValueError("Invalid answer from LLM.")
        _sql = rest[:_stop].strip()
        return _sql
    
    
    def generate_sql(
        question: str,
        db_id: str,
        prompt_fn: Callable[[str, str], str],
        debug_prompt: bool = False,
    ) -> str:
        """Use the LLM to generate a SQL response for a given Spider question."""
        prompt = prompt_fn(db_id=db_id, question=question)
    
        if debug_prompt:
            print(f"LLM prompt for SQL generation:\n======\n{prompt}\n======")
    
        response = llm_completion(prompt)
        return _clean_sql_response(response)
  6. Construct a generic few-shot prompt function:

    def few_shot_prompt(
        db_id: str, question: str, example_indices: List[int]
    ) -> str:
        """
        Create a few-shot prompt using the given indices to select the examples.
        """
        prompt = ""
        for e_index in example_indices:
            e_row = examples_df.loc[e_index]
            prompt += _format_example_prompt(
                e_row["db_id"], e_row["question"], e_row["query"]
            )
        return prompt + _format_question_prompt(db_id, question)
  7. Write a dynamic few-shot prompt using vector search on your Astra DB database:

    def dynamic_few_shot_prompt_fn(db_id: str, question: str) -> str:
        embedding = compute_embedding(question)
        documents = astra_db_collection.find(sort={"$vector": embedding}, limit=2)
        closest_q_ids = [
            doc["_id"] for doc in documents
        ]
        return few_shot_prompt(db_id, question, example_indices=closest_q_ids)
  8. Generate SQL from the test questions using the dynamic few-shot prompt:

    generated_queries_dynamic = {}
    
    for test_i, test_row in test_df.iterrows():
        if test_i not in generated_queries_dynamic:
            generated_queries_dynamic[test_i] = generate_sql(
                question=test_row["question"],
                db_id=test_row["db_id"],
                prompt_fn=dynamic_few_shot_prompt_fn,
            )
            print(
                f"\n====\nAdded for Q[{test_i}]="
                f"'{test_row['question']}' -> SQL:\n----"
            )
            print(generated_queries_dynamic[test_i])

    If you want to inspect the prompts used for the SQL-generating LLM call, you can add debug_prompt=True to the generate_sql function call. The prompts use different examples depending on the question asked.
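
    For example, the following call prints the prompt built for the first test question:

    # Print the dynamically built prompt for one test question.
    row = test_df.iloc[0]
    generate_sql(
        question=row["question"],
        db_id=row["db_id"],
        prompt_fn=dynamic_few_shot_prompt_fn,
        debug_prompt=True,
    )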

    Example debug info for one question
    LLM prompt for SQL generation:
    ======
    Here is an example: Convert text to SQL:
    
    [Schema : (values)]: | pets_1 | Student : StuID , LName , Fname , Age , Sex , Major , Advisor , city_code | Has_Pet : StuID , PetID | Pets : PetID , PetType , pet_age , weight ;
    
    [Column names (type)]: Student : StuID (number) | Student : LName (text) | Student : Fname (text) | Student : Age (number) | Student : Sex (text) | Student : Major (number) | Student : Advisor (number) | Student : city_code (text) | Has_Pet : StuID (number) | Has_Pet : PetID (number) | Pets : PetID (number) | Pets : PetType (text) | Pets : pet_age (number) | Pets : weight (number) ;
    
    [Primary Keys]: Student : StuID | Pets : PetID
    
    [Foreign Keys]: Has_Pet : StuID equals Student : StuID | Has_Pet : PetID equals Pets : PetID
    
    [Q]: Find the first name of students who have cat or dog pet.
    
    [SQL]: SELECT DISTINCT T1.Fname FROM student AS T1 JOIN has_pet AS T2 ON T1.stuid  =  T2.stuid JOIN pets AS T3 ON T3.petid  =  T2.petid WHERE T3.pettype  =  'cat' OR T3.pettype  =  'dog'
    
    Here is an example: Convert text to SQL:
    
    [Schema : (values)]: | pets_1 | Student : StuID , LName , Fname , Age , Sex , Major , Advisor , city_code | Has_Pet : StuID , PetID | Pets : PetID , PetType , pet_age , weight ;
    
    [Column names (type)]: Student : StuID (number) | Student : LName (text) | Student : Fname (text) | Student : Age (number) | Student : Sex (text) | Student : Major (number) | Student : Advisor (number) | Student : city_code (text) | Has_Pet : StuID (number) | Has_Pet : PetID (number) | Pets : PetID (number) | Pets : PetType (text) | Pets : pet_age (number) | Pets : weight (number) ;
    
    [Primary Keys]: Student : StuID | Pets : PetID
    
    [Foreign Keys]: Has_Pet : StuID equals Student : StuID | Has_Pet : PetID equals Pets : PetID
    
    [Q]: Find the first name of students who have both cat and dog pets .
    
    [SQL]: select t1.fname from student as t1 join has_pet as t2 on t1.stuid  =  t2.stuid join pets as t3 on t3.petid  =  t2.petid where t3.pettype  =  'cat' intersect select t1.fname from student as t1 join has_pet as t2 on t1.stuid  =  t2.stuid join pets as t3 on t3.petid  =  t2.petid where t3.pettype  =  'dog'
    
    Given the following schema information, generate valid SQL to answer the provided query. Enclose the query in markdown code-block syntax.
    Convert text to SQL:
    
    [Schema : (values)]: | pets_1 | Student : StuID , LName , Fname , Age , Sex , Major , Advisor , city_code | Has_Pet : StuID , PetID | Pets : PetID , PetType , pet_age , weight ;
    
    [Column names (type)]: Student : StuID (number) | Student : LName (text) | Student : Fname (text) | Student : Age (number) | Student : Sex (text) | Student : Major (number) | Student : Advisor (number) | Student : city_code (text) | Has_Pet : StuID (number) | Has_Pet : PetID (number) | Pets : PetID (number) | Pets : PetType (text) | Pets : pet_age (number) | Pets : weight (number) ;
    
    [Primary Keys]: Student : StuID | Pets : PetID
    
    [Foreign Keys]: Has_Pet : StuID equals Student : StuID | Has_Pet : PetID equals Pets : PetID
    
    [Q]: Find the first name and age of students who have a dog but do not have a cat as a pet.
    
    [SQL]:
    ======
    
    ===
    Generated SQL:
    SELECT T1.Fname, T1.Age
    FROM Student AS T1
    JOIN Has_Pet AS T2 ON T1.StuID = T2.StuID
    JOIN Pets AS T3 ON T2.PetID = T3.PetID
    WHERE T3.PetType = 'dog'
    AND T1.StuID NOT IN (
    SELECT T1.StuID
    FROM Student AS T1
    JOIN Has_Pet AS T2 ON T1.StuID = T2.StuID
    JOIN Pets AS T3 ON T2.PetID = T3.PetID
    WHERE T3.PetType = 'cat'
    )
  9. Evaluate the queries to assess the LLM’s performance:

    report = eval_generated_queries(generated_queries_dynamic)
    display(report)

    The response is a report in the form of a table. In the following example report, the model correctly answered all test questions:

    Example query evaluation report
    |   | DB ID          | Question                                          | Correct | Error |
    |---|----------------|---------------------------------------------------|---------|-------|
    | 0 | concert_singer | What is the name and capacity of the stadium w…  | True    | None  |
    | 1 | concert_singer | List singer names and number of concerts for e…  | True    | None  |
    | 2 | pets_1         | Find the number of pets whose weight is heavie…  | True    | None  |
    | 3 | pets_1         | Find the number of distinct type of pets.         | True    | None  |
    | 4 | pets_1         | Find the first name and age of students who ha…  | True    | None  |
    | 5 | pets_1         | Find the first name and age of students who ha…  | True    | None  |

    If you run this tutorial’s Colab notebook, you can verify that the accuracy was lower before enriching the model with dynamic few-shot prompting.

    However, be aware that LLMs are non-deterministic, and results can differ between runs. You might occasionally encounter incorrectly generated SQL.

Conclusion

Text-to-SQL is a critical LLM capability that enables a variety of GenAI use cases and automation opportunities.

You can often deploy AI-based solutions at various levels of sophistication. In this tutorial, adopting a dynamic approach to few-shot prompting improved SQL generation to nearly 100% accuracy.

Astra DB Serverless is a powerful tool for implementing dynamic few-shot prompting. Serverless (Vector) databases can serve as prompt stores for your question-query examples and their embeddings. Your applications can then use Astra DB’s vector search capabilities for on-demand retrieval of relevant question-query examples.
