Graph Store Example

Create a graph store and use it to answer questions with graph RAG chains.

Prerequisites

  • An active DataStax Astra DB database

  • Python 3.11 (to use the Union and Self type hints)

  • OpenAI API key

Environment

  1. Install dependencies:

    pip install "ragstack-ai-langchain[knowledge-store]" beautifulsoup4 markdownify python-dotenv
  2. Create a .env file with the following environment variables:

    OPENAI_API_KEY="<your key here>"
    LANGCHAIN_TRACING_V2=true
    LANGCHAIN_API_KEY="<your key here>"
    ASTRA_DB_DATABASE_ID="<your DB ID here>"
    ASTRA_DB_APPLICATION_TOKEN="<your key here>"
    ASTRA_DB_KEYSPACE="<your keyspace here>"

    If you’re running the notebook in Colab, run the cell using getpass to set the necessary environment variables.
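
    For example, a minimal Colab cell to collect the same values interactively might look like this (the variable names match the .env file above):

    import os
    from getpass import getpass

    # Prompt for each secret so it is never stored in the notebook.
    for var in (
        "OPENAI_API_KEY",
        "LANGCHAIN_API_KEY",
        "ASTRA_DB_DATABASE_ID",
        "ASTRA_DB_APPLICATION_TOKEN",
        "ASTRA_DB_KEYSPACE",
    ):
        os.environ[var] = getpass(f"{var}: ")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"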

Create an application to scrape and load content

  1. Create an application that scrapes sitemaps, loads content, and creates a graph store with the content.

  2. Import dependencies:

    import asyncio
    
    import requests
    from bs4 import BeautifulSoup
    from dotenv import load_dotenv
    from markdownify import MarkdownConverter
    
    import cassio
    from langchain_community.document_loaders import AsyncHtmlLoader
    from langchain_core.documents import Document
    from langchain_openai import OpenAIEmbeddings
    from ragstack_knowledge_store.graph_store import CONTENT_ID
    from ragstack_langchain.graph_store import CassandraGraphStore
    from ragstack_langchain.graph_store.extractors import HtmlLinkEdgeExtractor
    from typing import AsyncIterator, Iterable

Scrape the URLs from sitemaps and process content

  1. Declare constant values for the sitemaps and extra URLs to load. This example only loads one sitemap from the documentation to limit token usage.

  2. Use the BeautifulSoup library to parse the XML content of each sitemap and get a list of URLs.

    SITEMAPS = [
        "https://docs.datastax.com/en/sitemap-astra-db-vector.xml",
    ]
    EXTRA_URLS = ["https://github.com/jbellis/jvector"]
    SITE_PREFIX = "astra"
    
    def load_pages(sitemap_url):
        r = requests.get(
            sitemap_url,
            headers={
                "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
            },
        )
        xml = r.text
        soup = BeautifulSoup(xml, features="xml")
        url_tags = soup.find_all("url")
        for url in url_tags:
            yield (url.find("loc").text)
    
    URLS = [url for sitemap_url in SITEMAPS for url in load_pages(sitemap_url)] + EXTRA_URLS
    
    markdown_converter = MarkdownConverter(heading_style="ATX")
    html_link_extractor = HtmlLinkEdgeExtractor()
    
    def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:
        if url.startswith("https://docs.datastax.com/en/"):
            return soup.select_one("article.doc")
        elif url.startswith("https://github.com"):
            return soup.select_one("article.entry-content")
        else:
            return soup
  3. The load_and_process_pages function fetches the pages in the URL list, selects the relevant content from each, and converts that content to Markdown. It also extracts links (<a href="...">) from the content to create edges between the documents; a simplified sketch of what this extraction does follows the code below.

    async def load_and_process_pages(urls: Iterable[str]) -> AsyncIterator[Document]:
        loader = AsyncHtmlLoader(
            urls,
            requests_per_second=4,
            header_template={"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0"},
        )
        async for html in loader.alazy_load():
            url = html.metadata["source"]
            html.metadata[CONTENT_ID] = url
            soup = BeautifulSoup(html.page_content, "html.parser")
            content = select_content(soup, url)
            html_link_extractor.extract_one(html, content)
            html.page_content = markdown_converter.convert_soup(content)
            yield html
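
    The HtmlLinkEdgeExtractor does the link-to-edge work for you. Conceptually, the extraction resembles this simplified sketch (an illustration using BeautifulSoup and urljoin, not the library's actual implementation):

    from urllib.parse import urljoin

    from bs4 import BeautifulSoup

    def extract_links(content: BeautifulSoup, base_url: str) -> set[str]:
        # Collect the absolute URL of every <a href="..."> in the selected
        # content; the extractor records these as edges to other documents.
        return {urljoin(base_url, a["href"]) for a in content.find_all("a", href=True)}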

Initialize environment and graph store

  1. Initialize the Cassio library, which handles the connection to Cassandra / Astra DB, and create the graph store. (An alternative initialization with explicit credentials is shown after these steps.)

    load_dotenv()
    cassio.init(auto=True)
    embeddings = OpenAIEmbeddings()
    graph_store = CassandraGraphStore(
        embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
    )
  2. Fetch pages asynchronously and write them to the graph store in batches of 50.

    async def process_documents():
        not_found, found = 0, 0
        docs = []
        async for doc in load_and_process_pages(URLS):
            if doc.page_content.startswith("\n# Page Not Found"):
                not_found += 1
                continue
    
            docs.append(doc)
            found += 1
    
            if len(docs) >= 50:
                graph_store.add_documents(docs)
                docs.clear()
    
        if docs:
            graph_store.add_documents(docs)
    
        print(f"{not_found} (of {not_found + found}) URLs were not found")
    
    if __name__ == "__main__":
        asyncio.run(process_documents())

    You will see output like this until all pages are fetched and edges are created:

    Fetching pages: 100%|##########| 1368/1368 [04:23<00:00,  5.19it/s]
    ...
    Added 120 edges
    96 (of 1368) URLs were not found
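
    In step 1, cassio.init(auto=True) picks up the ASTRA_DB_* variables loaded from the .env file. If you prefer to pass credentials explicitly, cassio also accepts them as arguments; a minimal sketch, assuming the same values are present in the environment:

    import os

    import cassio

    # Equivalent to cassio.init(auto=True) when the same values
    # are available as environment variables.
    cassio.init(
        database_id=os.environ["ASTRA_DB_DATABASE_ID"],
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        keyspace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )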

Create an application to execute RAG chains

  1. Create a new application in the same directory as the previous application.

  2. Import dependencies:

    import cassio
    
    from dotenv import load_dotenv
    
    from langchain_openai import OpenAIEmbeddings, ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnablePassthrough
    from langchain_core.output_parsers import StrOutputParser
    from ragstack_langchain.graph_store import CassandraGraphStore
  3. Load environment variables and declare constants. This example uses the following QUESTION because answering it well requires specific details about how Astra's vector indexing is actually implemented.

    SITE_PREFIX = "astra"
    QUESTION = "What vector indexing algorithms does Astra use?"
  4. Initialize a session with the embeddings and graph store.

    load_dotenv()
    cassio.init(auto=True)
    embeddings = OpenAIEmbeddings()
    graph_store = CassandraGraphStore(
        embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
    )
  5. Define the LLM and prompt template.

    llm = ChatOpenAI(model="gpt-3.5-turbo")
    template = """You are a helpful technical support bot. You should provide complete answers explaining the options the user has available to address their problem. Answer the question based only on the following context:
    {context}
    
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)
  6. Create a function to format the documents. This function can also limit the number of documents and the length of the content to limit token usage.

    def format_docs(docs, max_length=200, max_docs=5):
        # Limit the number of documents to keep token usage down
        docs = docs[:max_docs]
    
        formatted = "\n\n".join(
            f"From {doc.metadata['content_id']}: {doc.page_content[:max_length]}..."
            if len(doc.page_content) > max_length else
            f"From {doc.metadata['content_id']}: {doc.page_content}"
            for doc in docs
        )
        return formatted
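
    For example, here is what format_docs produces for a hypothetical document (the URL and text are illustrative only):

    from langchain_core.documents import Document

    sample_docs = [
        Document(
            page_content="JVector is an embedded vector search engine written in Java.",
            metadata={"content_id": "https://github.com/jbellis/jvector"},
        ),
    ]
    print(format_docs(sample_docs))
    # From https://github.com/jbellis/jvector: JVector is an embedded vector search engine written in Java.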

Create and execute the RAG chains

Create a chain for each retrieval method.

  1. The notebook version of this example uses the IPython library to display results as rendered Markdown. This example uses print instead, with a short description added so you can see which retrieval method is being used.

    def run_and_render(chain, question, description):
        print(f"\nRunning chain: {description}")
        result = chain.invoke(question)
        print("Output:")
        print(result)
  2. Create a vector retriever chain that only uses vector similarity.

    # Depth 0 doesn't traverse edges and is equivalent to vector similarity only.
    vector_retriever = graph_store.as_retriever(search_kwargs={"depth": 0})
    
    vector_rag_chain = (
        {"context": vector_retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    run_and_render(vector_rag_chain, QUESTION, "Vector-Only Retrieval")
  3. Create a graph traversal retriever chain that uses vector similarity and traverses one level of edges.

    # Depth 1 does vector similarity and then traverses 1 level of edges.
    graph_retriever = graph_store.as_retriever(search_kwargs={"depth": 1})
    
    graph_rag_chain = (
        {"context": graph_retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    run_and_render(graph_rag_chain, QUESTION, "Depth 1 Retrieval")
  4. Create an MMR graph traversal retriever chain that uses vector similarity and traverses two levels of edges.

    mmr_graph_retriever = graph_store.as_retriever(
        search_type="mmr_traversal",
        search_kwargs={
            "k": 4,
            "fetch_k": 10,
            "depth": 2,
            # "score_threshold": 0.2,
        },
    )
    
    mmr_graph_rag_chain = (
        {"context": mmr_graph_retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    run_and_render(mmr_graph_rag_chain, QUESTION, "MMR Based Retrieval")
  5. Finally, invoke each retriever directly and print the documents it returns, so you can compare what each method retrieved.

    print("\nDocument retrieval results:")
    for i, doc in enumerate(vector_retriever.invoke(QUESTION)):
        print(f"Vector [{i}]:    {doc.metadata['content_id']}")
    
    for i, doc in enumerate(graph_retriever.invoke(QUESTION)):
        print(f"Graph [{i}]:     {doc.metadata['content_id']}")
    
    for i, doc in enumerate(mmr_graph_retriever.invoke(QUESTION)):
        print(f"MMR Graph [{i}]: {doc.metadata['content_id']}")

    You will see output like this:

    Running chain: Vector-Only Retrieval
    Output:
    Astra DB Serverless uses the Vector Search feature, which allows for vector indexing algorithms to be utilized for similarity searches within the database. The specific vector indexing algorithms used by Astra DB Serverless are not explicitly mentioned in the provided context. However, the Vector Search feature enables data to be compared by similarity within the database, even if it is not explicitly defined by a connection. This feature is particularly useful for machine learning models and AI applications that require similarity searches based on vectors.
    
    Running chain: Depth 1 Retrieval
    Output:
    Astra DB Serverless uses the following vector indexing algorithms:
    
    1. Locality Sensitive Hashing (LSH)
    2. Product Quantization (PQ)
    3. Hierarchical Navigable Small World Graphs (HNSW)
    
    Running chain: MMR Based Retrieval
    Output:
    Astra DB Serverless offers both Serverless (Vector) and Serverless (Non-Vector) databases. The vector databases in Astra use vector indexing algorithms for efficient search operations. The specific vector indexing algorithms used by Astra are not explicitly mentioned in the provided context. However, vector databases typically utilize approximate nearest neighbor search algorithms for efficient searching in high-dimensional data spaces. These algorithms are designed to overcome the limitations of exact nearest neighbor search in higher dimensions. For more specific information on the vector indexing algorithms used by Astra, you may refer to the official Astra documentation or contact DataStax support for further assistance.
    
    Document retrieval results:
    Vector [0]:    https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
    Vector [1]:    https://docs.datastax.com/en/cql/astra/getting-started/vector-search-quickstart.html
    Vector [2]:    https://docs.datastax.com/en/astra-db-serverless/databases/database-overview.html
    Vector [3]:    https://docs.datastax.com/en/astra-db-serverless/get-started/astra-db-introduction.html
    Graph [0]:     https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
    Graph [1]:     https://docs.datastax.com/en/cql/astra/getting-started/vector-search-quickstart.html
    Graph [2]:     https://docs.datastax.com/en/cql/astra/developing/indexing/indexing-concepts.html
    Graph [3]:     https://docs.datastax.com/en/astra-db-serverless/databases/database-overview.html
    Graph [4]:     https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html
    Graph [5]:     https://docs.datastax.com/en/astra-db-serverless/integrations/semantic-kernel.html
    Graph [6]:     https://docs.datastax.com/en/astra-db-serverless/tutorials/chatbot.html
    Graph [7]:     https://docs.datastax.com/en/astra-db-serverless/tutorials/recommendations.html
    Graph [8]:     https://docs.datastax.com/en/cql/astra/developing/indexing/sai/sai-overview.html
    Graph [9]:     https://docs.datastax.com/en/glossary/index.html
    Graph [10]:     https://github.com/jbellis/jvector
    Graph [11]:     https://docs.datastax.com/en/astra-db-serverless/administration/maintenance-schedule.html
    Graph [12]:     https://docs.datastax.com/en/astra-db-serverless/administration/support.html
    Graph [13]:     https://docs.datastax.com/en/astra-db-serverless/databases/backup-restore.html
    Graph [14]:     https://docs.datastax.com/en/astra-db-serverless/databases/database-limits.html
    MMR Graph [0]: https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
    MMR Graph [1]: https://docs.datastax.com/en/astra-db-serverless/cli-reference/astra-cli.html
    MMR Graph [2]: https://github.com/jbellis/jvector
    MMR Graph [3]: https://docs.datastax.com/en/cql/astra/developing/indexing/indexing-concepts.html

Conclusion

With vector-only retrieval, you retrieved chunks from the Astra documentation explaining that it uses JVector. Because the retriever didn't follow the link to the JVector repository on GitHub, it couldn't actually answer the question.

The graph retrieval started from the same set of chunks, but it followed the edge to the documents loaded from GitHub. This let the LLM read in more depth about how JVector is implemented, so it answered the question more clearly and in more detail.

The MMR graph retrieval went even further, following two levels of edges. This allowed the LLM to read even more about JVector and provide an even more detailed answer.

Complete code examples

Load
import asyncio

import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from markdownify import MarkdownConverter

import cassio
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from ragstack_knowledge_store.graph_store import CONTENT_ID
from ragstack_langchain.graph_store import CassandraGraphStore
from ragstack_langchain.graph_store.extractors import HtmlLinkEdgeExtractor
from typing import AsyncIterator, Iterable

SITEMAPS = [
    "https://docs.datastax.com/en/sitemap-astra-db-vector.xml",
]
EXTRA_URLS = ["https://github.com/jbellis/jvector"]
SITE_PREFIX = "astra"

def load_pages(sitemap_url):
    r = requests.get(
        sitemap_url,
        headers={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
        },
    )
    xml = r.text
    soup = BeautifulSoup(xml, features="xml")
    url_tags = soup.find_all("url")
    for url in url_tags:
        yield (url.find("loc").text)

URLS = [url for sitemap_url in SITEMAPS for url in load_pages(sitemap_url)] + EXTRA_URLS

markdown_converter = MarkdownConverter(heading_style="ATX")
html_link_extractor = HtmlLinkEdgeExtractor()

def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:
    if url.startswith("https://docs.datastax.com/en/"):
        return soup.select_one("article.doc")
    elif url.startswith("https://github.com"):
        return soup.select_one("article.entry-content")
    else:
        return soup

async def load_and_process_pages(urls: Iterable[str]) -> AsyncIterator[Document]:
    loader = AsyncHtmlLoader(
        urls,
        requests_per_second=4,
        header_template={"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0"},
    )
    async for html in loader.alazy_load():
        url = html.metadata["source"]
        html.metadata[CONTENT_ID] = url
        soup = BeautifulSoup(html.page_content, "html.parser")
        content = select_content(soup, url)
        html_link_extractor.extract_one(html, content)
        html.page_content = markdown_converter.convert_soup(content)
        yield html

# Setup environment and database
load_dotenv()
cassio.init(auto=True)
embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)

async def process_documents():
    not_found, found = 0, 0
    docs = []
    async for doc in load_and_process_pages(URLS):
        if doc.page_content.startswith("\n# Page Not Found"):
            not_found += 1
            continue

        docs.append(doc)
        found += 1

        if len(docs) >= 50:
            graph_store.add_documents(docs)
            docs.clear()

    if docs:
        graph_store.add_documents(docs)

    print(f"{not_found} (of {not_found + found}) URLs were not found")

if __name__ == "__main__":
    asyncio.run(process_documents())
Retrieve
import cassio

from dotenv import load_dotenv

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from ragstack_langchain.graph_store import CassandraGraphStore

load_dotenv()

SITE_PREFIX = "astra"
QUESTION = "What vector indexing algorithms does Astra use?"

# Initialize embeddings and graph store
cassio.init(auto=True)
embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)

llm = ChatOpenAI(model="gpt-3.5-turbo")
template = """You are a helpful technical support bot. You should provide complete answers explaining the options the user has available to address their problem. Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

def format_docs(docs, max_length=200, max_docs=5):
    # Limit the number of documents
    docs = docs[:max_docs]

    formatted = "\n\n".join(
        f"From {doc.metadata['content_id']}: {doc.page_content[:max_length]}..."
        if len(doc.page_content) > max_length else
        f"From {doc.metadata['content_id']}: {doc.page_content}"
        for doc in docs
    )
    return formatted

def run_and_render(chain, question, description):
    print(f"\nRunning chain: {description}")
    result = chain.invoke(question)
    print("Output:")
    print(result)

# Vector-Only Retrieval
vector_retriever = graph_store.as_retriever(search_kwargs={"depth": 0})
vector_rag_chain = (
    {"context": vector_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(vector_rag_chain, QUESTION, "Vector-Only Retrieval")

# Depth 1 and MMR retrieval
graph_retriever = graph_store.as_retriever(search_kwargs={"depth": 1})
mmr_graph_retriever = graph_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 4,
        "fetch_k": 10,
        "depth": 2
    },
)

graph_rag_chain = (
    {"context": graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(graph_rag_chain, QUESTION, "Depth 1 Retrieval")

mmr_graph_rag_chain = (
    {"context": mmr_graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(mmr_graph_rag_chain, QUESTION, "MMR Based Retrieval")

print("\nDocument retrieval results:")
for i, doc in enumerate(vector_retriever.invoke(QUESTION)):
    print(f"Vector [{i}]:    {doc.metadata['content_id']}")

for i, doc in enumerate(graph_retriever.invoke(QUESTION)):
    print(f"Graph [{i}]:     {doc.metadata['content_id']}")

for i, doc in enumerate(mmr_graph_retriever.invoke(QUESTION)):
    print(f"MMR Graph [{i}]: {doc.metadata['content_id']}")
