Graph Store Example
Create a graph store and use it to answer questions with graph RAG chains.
Prerequisites

- An active DataStax Astra DB database
- Python 3.11 (to use `Union` and `Self` type hints)
- An OpenAI API key
Environment

- Install dependencies:

```bash
pip install "ragstack-ai-langchain[knowledge-store]" beautifulsoup4 markdownify python-dotenv
```

- Create a `.env` file with the following environment variables:

```env
OPENAI_API_KEY="<your key here>"
LANGCHAIN_TRACING_V2=true
LANGCHAIN_API_KEY="<your key here>"
ASTRA_DB_DATABASE_ID="<your DB ID here>"
ASTRA_DB_APPLICATION_TOKEN="<your key here>"
ASTRA_DB_KEYSPACE="<your keyspace here>"
```
If you're running the notebook in Colab, run the cell that uses `getpass` to set the necessary environment variables.
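For reference, a minimal sketch of such a cell, assuming the same variable names as the `.env` file above (the exact list in the notebook may differ):

```python
import getpass
import os

# Prompt for each secret at runtime instead of reading a .env file.
# Hypothetical setup cell; adjust the variable list to your needs.
for var in (
    "OPENAI_API_KEY",
    "LANGCHAIN_API_KEY",
    "ASTRA_DB_DATABASE_ID",
    "ASTRA_DB_APPLICATION_TOKEN",
    "ASTRA_DB_KEYSPACE",
):
    os.environ[var] = getpass.getpass(f"{var}: ")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
```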
Create an application to scrape and load content

- Create an application that scrapes sitemaps, loads content, and creates a graph store with the content.

- Import dependencies:

```python
import asyncio
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from markdownify import MarkdownConverter
import cassio
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from ragstack_knowledge_store.graph_store import CONTENT_ID
from ragstack_langchain.graph_store import CassandraGraphStore
from ragstack_langchain.graph_store.extractors import HtmlLinkEdgeExtractor
from typing import AsyncIterator, Iterable
```
Scrape the URLs from sitemaps and process content

- Declare constants for the sitemaps and extra URLs to load. This example loads only one sitemap from the documentation to limit token usage.

- Use the BeautifulSoup library to parse the XML content of each sitemap and collect the list of URLs, then define a `select_content` function that picks the main content element of each page:

```python
SITEMAPS = [
    "https://docs.datastax.com/en/sitemap-astra-db-vector.xml",
]
EXTRA_URLS = ["https://github.com/jbellis/jvector"]

SITE_PREFIX = "astra"

def load_pages(sitemap_url):
    r = requests.get(
        sitemap_url,
        headers={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
        },
    )
    xml = r.text
    soup = BeautifulSoup(xml, features="xml")
    url_tags = soup.find_all("url")
    for url in url_tags:
        yield url.find("loc").text

URLS = [url for sitemap_url in SITEMAPS for url in load_pages(sitemap_url)] + EXTRA_URLS

markdown_converter = MarkdownConverter(heading_style="ATX")
html_link_extractor = HtmlLinkEdgeExtractor()

def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:
    # Each site keeps its main content in a different element.
    if url.startswith("https://docs.datastax.com/en/"):
        return soup.select_one("article.doc")
    elif url.startswith("https://github.com"):
        return soup.select_one("article.entry-content")
    else:
        return soup
```
- The `load_and_process_pages` function fetches web pages from the URL list, extracts the relevant content from each page, and converts it to Markdown. It also extracts links (`<a href="…">`) from the content to create edges between the documents:

```python
async def load_and_process_pages(urls: Iterable[str]) -> AsyncIterator[Document]:
    loader = AsyncHtmlLoader(
        urls,
        requests_per_second=4,
        header_template={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
        },
    )
    async for html in loader.alazy_load():
        url = html.metadata["source"]
        html.metadata[CONTENT_ID] = url
        soup = BeautifulSoup(html.page_content, "html.parser")
        content = select_content(soup, url)
        html_link_extractor.extract_one(html, content)
        html.page_content = markdown_converter.convert_soup(content)
        yield html
```
Initialize environment and graph store

- Initialize the Cassio library for talking to Cassandra / Astra DB and create the `GraphStore`. With `auto=True`, Cassio picks up the `ASTRA_DB_*` variables loaded from the `.env` file:

```python
load_dotenv()
cassio.init(auto=True)

embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)
```
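If you prefer not to rely on auto-detection, a minimal sketch of explicit initialization, assuming the same variable names as the `.env` file above (see the Cassio documentation for the full set of options):

```python
import os
import cassio

# Pass Astra DB credentials explicitly instead of cassio.init(auto=True).
cassio.init(
    database_id=os.environ["ASTRA_DB_DATABASE_ID"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    keyspace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
```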
- Fetch the pages and asynchronously write them to the graph store in batches of 50:

```python
async def process_documents():
    not_found, found = 0, 0
    docs = []
    async for doc in load_and_process_pages(URLS):
        # Skip pages whose Markdown indicates a 404.
        if doc.page_content.startswith("\n# Page Not Found"):
            not_found += 1
            continue
        docs.append(doc)
        found += 1
        # Flush to the graph store every 50 documents.
        if len(docs) >= 50:
            graph_store.add_documents(docs)
            docs.clear()
    if docs:
        graph_store.add_documents(docs)
    print(f"{not_found} (of {not_found + found}) URLs were not found")

if __name__ == "__main__":
    asyncio.run(process_documents())
```
You will see output like this until all pages are fetched and edges are created:

```
...
Fetching pages: 100%|##########| 1368/1368 [04:23<00:00, 5.19it/s]
...
Added 120 edges
96 (of 1368) URLs were not found
...
```
Create an application to execute RAG chains

- Create a new application in the same directory as the previous application.

- Import dependencies:

```python
import cassio
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from ragstack_langchain.graph_store import CassandraGraphStore
```
- Load environment variables and declare constants. This example uses the following `QUESTION` because the ideal answer is concise yet in-depth, grounded in how the vector indexing is actually implemented:

```python
SITE_PREFIX = "astra"

QUESTION = "What vector indexing algorithms does Astra use?"
```
- Initialize a session with the embeddings and graph store:

```python
load_dotenv()
cassio.init(auto=True)

embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)
```
- Define the LLM and prompt template:

```python
llm = ChatOpenAI(model="gpt-3.5-turbo")

template = """You are a helpful technical support bot. You should provide complete answers explaining the options the user has available to address their problem. Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
```
- Create a function to format the retrieved documents. This function also limits the number of documents and the length of each document's content to keep token usage down:

```python
def format_docs(docs, max_length=200, max_docs=5):
    # Keep at most max_docs documents and truncate long content.
    docs = docs[:max_docs]
    formatted = "\n\n".join(
        f"From {doc.metadata['content_id']}: {doc.page_content[:max_length]}..."
        if len(doc.page_content) > max_length
        else f"From {doc.metadata['content_id']}: {doc.page_content}"
        for doc in docs
    )
    return formatted
```
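To see the shape of the formatted context, here's a quick usage example with two hand-built `Document` objects (the `content_id` values are made up for illustration):

```python
from langchain_core.documents import Document

# Hypothetical documents standing in for retrieval results.
sample_docs = [
    Document(
        page_content="JVector is an embedded vector search engine. " * 20,
        metadata={"content_id": "https://example.com/jvector"},
    ),
    Document(
        page_content="Astra DB supports vector search.",
        metadata={"content_id": "https://example.com/astra"},
    ),
]

# The first entry is truncated to max_length characters; each entry is
# prefixed with its source content_id.
print(format_docs(sample_docs))
```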
Create and execute the RAG chains

Create a chain for each retrieval method.

- The notebook uses the IPython library to display the results in Markdown format, but this example just uses `print` to display the results, with some added text so you can see which retrieval method is being used:

```python
def run_and_render(chain, question, description):
    print(f"\nRunning chain: {description}")
    result = chain.invoke(question)
    print("Output:")
    print(result)
```
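If you are working in a notebook instead, a minimal sketch of the Markdown-rendering variant mentioned above (the helper name `render_markdown` is made up):

```python
from IPython.display import Markdown, display

def render_markdown(chain, question):
    # Render the chain's answer as formatted Markdown in a notebook cell.
    display(Markdown(chain.invoke(question)))
```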
- Create a vector retriever chain that uses only vector similarity:

```python
# Depth 0 doesn't traverse edges and is equivalent to vector similarity only.
vector_retriever = graph_store.as_retriever(search_kwargs={"depth": 0})

vector_rag_chain = (
    {"context": vector_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(vector_rag_chain, QUESTION, "Vector-Only Retrieval")
```
- Create a graph traversal retriever chain that uses vector similarity and then traverses one level of edges:

```python
# Depth 1 does vector similarity and then traverses 1 level of edges.
graph_retriever = graph_store.as_retriever(search_kwargs={"depth": 1})

graph_rag_chain = (
    {"context": graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(graph_rag_chain, QUESTION, "Depth 1 Retrieval")
```
- Create an MMR (maximal marginal relevance) graph traversal retriever chain that uses vector similarity and traverses two levels of edges. Here `fetch_k` is the number of candidate documents to fetch, and `k` is the number of diverse documents ultimately returned:

```python
mmr_graph_retriever = graph_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 4,
        "fetch_k": 10,
        "depth": 2,
        # "score_threshold": 0.2,
    },
)

mmr_graph_rag_chain = (
    {"context": mmr_graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(mmr_graph_rag_chain, QUESTION, "MMR Based Retrieval")
```
- Finally, compare the documents each retriever returns:

```python
print("\nDocument retrieval results:")
for i, doc in enumerate(vector_retriever.invoke(QUESTION)):
    print(f"Vector [{i}]: {doc.metadata['content_id']}")

for i, doc in enumerate(graph_retriever.invoke(QUESTION)):
    print(f"Graph [{i}]: {doc.metadata['content_id']}")

for i, doc in enumerate(mmr_graph_retriever.invoke(QUESTION)):
    print(f"MMR Graph [{i}]: {doc.metadata['content_id']}")
```
You will see output like this:

```
Running chain: Vector-Only Retrieval
Output:
Astra DB Serverless uses the Vector Search feature, which allows for vector indexing algorithms to be utilized for similarity searches within the database. The specific vector indexing algorithms used by Astra DB Serverless are not explicitly mentioned in the provided context. However, the Vector Search feature enables data to be compared by similarity within the database, even if it is not explicitly defined by a connection. This feature is particularly useful for machine learning models and AI applications that require similarity searches based on vectors.

Running chain: Depth 1 Retrieval
Output:
Astra DB Serverless uses the following vector indexing algorithms:
1. Locality Sensitive Hashing (LSH)
2. Product Quantization (PQ)
3. Hierarchical Navigable Small World Graphs (HNSW)

Running chain: MMR Based Retrieval
Output:
Astra DB Serverless offers both Serverless (Vector) and Serverless (Non-Vector) databases. The vector databases in Astra use vector indexing algorithms for efficient search operations. The specific vector indexing algorithms used by Astra are not explicitly mentioned in the provided context. However, vector databases typically utilize approximate nearest neighbor search algorithms for efficient searching in high-dimensional data spaces. These algorithms are designed to overcome the limitations of exact nearest neighbor search in higher dimensions. For more specific information on the vector indexing algorithms used by Astra, you may refer to the official Astra documentation or contact DataStax support for further assistance.

Document retrieval results:
Vector [0]: https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
Vector [1]: https://docs.datastax.com/en/cql/astra/getting-started/vector-search-quickstart.html
Vector [2]: https://docs.datastax.com/en/astra-db-serverless/databases/database-overview.html
Vector [3]: https://docs.datastax.com/en/astra-db-serverless/get-started/astra-db-introduction.html
Graph [0]: https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
Graph [1]: https://docs.datastax.com/en/cql/astra/getting-started/vector-search-quickstart.html
Graph [2]: https://docs.datastax.com/en/cql/astra/developing/indexing/indexing-concepts.html
Graph [3]: https://docs.datastax.com/en/astra-db-serverless/databases/database-overview.html
Graph [4]: https://docs.datastax.com/en/astra-db-serverless/databases/embedding-generation.html
Graph [5]: https://docs.datastax.com/en/astra-db-serverless/integrations/semantic-kernel.html
Graph [6]: https://docs.datastax.com/en/astra-db-serverless/tutorials/chatbot.html
Graph [7]: https://docs.datastax.com/en/astra-db-serverless/tutorials/recommendations.html
Graph [8]: https://docs.datastax.com/en/cql/astra/developing/indexing/sai/sai-overview.html
Graph [9]: https://docs.datastax.com/en/glossary/index.html
Graph [10]: https://github.com/jbellis/jvector
Graph [11]: https://docs.datastax.com/en/astra-db-serverless/administration/maintenance-schedule.html
Graph [12]: https://docs.datastax.com/en/astra-db-serverless/administration/support.html
Graph [13]: https://docs.datastax.com/en/astra-db-serverless/databases/backup-restore.html
Graph [14]: https://docs.datastax.com/en/astra-db-serverless/databases/database-limits.html
MMR Graph [0]: https://docs.datastax.com/en/astra-db-serverless/get-started/concepts.html
MMR Graph [1]: https://docs.datastax.com/en/astra-db-serverless/cli-reference/astra-cli.html
MMR Graph [2]: https://github.com/jbellis/jvector
MMR Graph [3]: https://docs.datastax.com/en/cql/astra/developing/indexing/indexing-concepts.html
```
Conclusion
With vector-only retrieval, you retrieved chunks from the Astra DB documentation explaining that it uses JVector. Because the retrieval didn't follow the link to JVector on GitHub, the LLM couldn't actually answer the question.

The graph retrieval started with the same set of chunks, but it followed the edge to the documents loaded from GitHub. This let the LLM read in more depth about how JVector is implemented, so it could answer the question more clearly and in more detail.

The MMR graph retrieval went even further, following two levels of edges. This let the LLM read even more about JVector and provide an even more detailed answer.
Complete code examples
Load
```python
import asyncio
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from markdownify import MarkdownConverter
import cassio
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from ragstack_knowledge_store.graph_store import CONTENT_ID
from ragstack_langchain.graph_store import CassandraGraphStore
from ragstack_langchain.graph_store.extractors import HtmlLinkEdgeExtractor
from typing import AsyncIterator, Iterable

SITEMAPS = [
    "https://docs.datastax.com/en/sitemap-astra-db-vector.xml",
]
EXTRA_URLS = ["https://github.com/jbellis/jvector"]

SITE_PREFIX = "astra"


def load_pages(sitemap_url):
    r = requests.get(
        sitemap_url,
        headers={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
        },
    )
    xml = r.text
    soup = BeautifulSoup(xml, features="xml")
    url_tags = soup.find_all("url")
    for url in url_tags:
        yield url.find("loc").text


URLS = [url for sitemap_url in SITEMAPS for url in load_pages(sitemap_url)] + EXTRA_URLS

markdown_converter = MarkdownConverter(heading_style="ATX")
html_link_extractor = HtmlLinkEdgeExtractor()


def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:
    # Each site keeps its main content in a different element.
    if url.startswith("https://docs.datastax.com/en/"):
        return soup.select_one("article.doc")
    elif url.startswith("https://github.com"):
        return soup.select_one("article.entry-content")
    else:
        return soup


async def load_and_process_pages(urls: Iterable[str]) -> AsyncIterator[Document]:
    loader = AsyncHtmlLoader(
        urls,
        requests_per_second=4,
        header_template={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 Firefox/58.0",
        },
    )
    async for html in loader.alazy_load():
        url = html.metadata["source"]
        html.metadata[CONTENT_ID] = url
        soup = BeautifulSoup(html.page_content, "html.parser")
        content = select_content(soup, url)
        html_link_extractor.extract_one(html, content)
        html.page_content = markdown_converter.convert_soup(content)
        yield html


# Set up environment and database
load_dotenv()
cassio.init(auto=True)

embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)


async def process_documents():
    not_found, found = 0, 0
    docs = []
    async for doc in load_and_process_pages(URLS):
        # Skip pages whose Markdown indicates a 404.
        if doc.page_content.startswith("\n# Page Not Found"):
            not_found += 1
            continue
        docs.append(doc)
        found += 1
        # Flush to the graph store every 50 documents.
        if len(docs) >= 50:
            graph_store.add_documents(docs)
            docs.clear()
    if docs:
        graph_store.add_documents(docs)
    print(f"{not_found} (of {not_found + found}) URLs were not found")


if __name__ == "__main__":
    asyncio.run(process_documents())
```
Retrieve
```python
import cassio
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from ragstack_langchain.graph_store import CassandraGraphStore

load_dotenv()

SITE_PREFIX = "astra"
QUESTION = "What vector indexing algorithms does Astra use?"

# Initialize embeddings and graph store
cassio.init(auto=True)

embeddings = OpenAIEmbeddings()
graph_store = CassandraGraphStore(
    embeddings, node_table=f"{SITE_PREFIX}_nodes", edge_table=f"{SITE_PREFIX}_edges"
)

llm = ChatOpenAI(model="gpt-3.5-turbo")
template = """You are a helpful technical support bot. You should provide complete answers explaining the options the user has available to address their problem. Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)


def format_docs(docs, max_length=200, max_docs=5):
    # Keep at most max_docs documents and truncate long content.
    docs = docs[:max_docs]
    formatted = "\n\n".join(
        f"From {doc.metadata['content_id']}: {doc.page_content[:max_length]}..."
        if len(doc.page_content) > max_length
        else f"From {doc.metadata['content_id']}: {doc.page_content}"
        for doc in docs
    )
    return formatted


def run_and_render(chain, question, description):
    print(f"\nRunning chain: {description}")
    result = chain.invoke(question)
    print("Output:")
    print(result)


# Vector-Only Retrieval
vector_retriever = graph_store.as_retriever(search_kwargs={"depth": 0})
vector_rag_chain = (
    {"context": vector_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(vector_rag_chain, QUESTION, "Vector-Only Retrieval")

# Depth 1 and MMR retrieval
graph_retriever = graph_store.as_retriever(search_kwargs={"depth": 1})
mmr_graph_retriever = graph_store.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 4,
        "fetch_k": 10,
        "depth": 2,
    },
)

graph_rag_chain = (
    {"context": graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(graph_rag_chain, QUESTION, "Depth 1 Retrieval")

mmr_graph_rag_chain = (
    {"context": mmr_graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
run_and_render(mmr_graph_rag_chain, QUESTION, "MMR Based Retrieval")

print("\nDocument retrieval results:")
for i, doc in enumerate(vector_retriever.invoke(QUESTION)):
    print(f"Vector [{i}]: {doc.metadata['content_id']}")
for i, doc in enumerate(graph_retriever.invoke(QUESTION)):
    print(f"Graph [{i}]: {doc.metadata['content_id']}")
for i, doc in enumerate(mmr_graph_retriever.invoke(QUESTION)):
    print(f"MMR Graph [{i}]: {doc.metadata['content_id']}")
```