Build a Graph RAG system with LangChain and GraphRetriever
Graph RAG is an enhancement to retrieval-augmented generation (RAG). It uses vector search to find semantically similar documents, and then uses graph traversal to find connected documents through relationships like hyperlinks, citations, or references. This helps find documents that might not be semantically similar but are contextually connected. As in standard RAG, the retrieved documents serve as context for a large language model (LLM).
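To make the two-phase retrieval flow concrete before introducing the libraries, here is a minimal, library-free sketch. The vector_search and get_linked_docs callables are hypothetical stand-ins for a real vector store query and a metadata-based link lookup; the actual tutorial uses GraphRetriever for this:

# Minimal sketch of two-phase graph RAG retrieval (illustrative only).
# vector_search and get_linked_docs are hypothetical stand-ins for a real
# vector store query and a link lookup based on document metadata.
def graph_rag_retrieve(question, vector_search, get_linked_docs,
                       start_k=3, select_k=10):
    # Phase 1: vector search finds semantically similar documents.
    selected = list(vector_search(question, k=start_k))
    # Phase 2: graph traversal adds documents connected to those results,
    # for example through hyperlinks, citations, or references.
    for doc in list(selected):
        for linked in get_linked_docs(doc):
            if linked not in selected and len(selected) < select_k:
                selected.append(linked)
    # The combined set becomes the context passed to the LLM.
    return selected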
In this tutorial, you will build a simple graph RAG system. First, you will build a graph from a small set of cross-linked HTML pages. Then, you will use the graph during the retrieval step of RAG to provide extended context to the LLM.
Prerequisites
- A Serverless (vector) database. If you don’t already have one, you can create one.
- An OpenAI API key associated with a paid OpenAI account.
- Python 3.11 or later.
Install dependencies
Install the dependencies used in this tutorial. For example:
pip install \
langchain-astradb==1.0.0 \
langchain-openai==1.1.7 \
langchain-graph-retriever==0.8.0 \
beautifulsoup4==4.14.3
Although this tutorial doesn’t use beautifulsoup4 directly, langchain-graph-retriever requires it to parse the HTML pages in this tutorial.
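To confirm that all packages installed correctly, you can run a quick import check (beautifulsoup4 is imported as bs4):

python -c "import langchain_astradb, langchain_openai, langchain_graph_retriever, bs4; print('Dependencies OK')"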
Set environment variables
Set the following environment variables:
export API_ENDPOINT=API_ENDPOINT
export APPLICATION_TOKEN=APPLICATION_TOKEN
export OPENAI_API_KEY=OPENAI_API_KEY
Replace the following:
- API_ENDPOINT: Your database’s API endpoint.
- APPLICATION_TOKEN: An application token for your database.
- OPENAI_API_KEY: Your OpenAI API key.
Build the graph
- Download the graph_rag_dataset.json sample dataset. This dataset is a JSON array describing a small set of cross-linked HTML pages.
- Copy the following code into a Python file, and replace the PATH_TO_DATA_FILE placeholder with the path to the JSON data file.

  This code processes the raw JSON dataset into a list of documents. Each document includes a metadata.hyperlink field, which lists the links from that document’s HTML content, and a metadata.url field, which contains the URL of the document. These fields are used to build the graph during retrieval in the next section.

  Then, the code creates a vector store that uses Astra DB as the backend and OpenAI as the embedding service. Finally, the code inserts the processed documents into the vector store.

  import json
  import os

  from langchain_core.documents import Document
  from langchain_graph_retriever.transformers.html import HyperlinkTransformer
  from langchain_astradb import AstraDBVectorStore
  from langchain_openai import OpenAIEmbeddings

  endpoint = os.environ.get("API_ENDPOINT")  # (1)
  application_token = os.environ.get("APPLICATION_TOKEN")
  openai_api_key = os.environ.get("OPENAI_API_KEY")

  if not endpoint or not application_token or not openai_api_key:
      raise RuntimeError(
          "Environment variables API_ENDPOINT, APPLICATION_TOKEN, OPENAI_API_KEY must be defined."
      )

  data_file_path = "PATH_TO_DATA_FILE"  # (2)

  # Read the JSON file and parse it into a JSON array
  with open(data_file_path, "r", encoding="utf8") as file:
      json_data = json.load(file)

  # Convert the JSON array into LangChain Documents
  documents = [
      Document(page_content=data["html_doc"], metadata={"url": data["url"]})
      for data in json_data
  ]

  # Extract hyperlinks from the HTML and store them in the Document metadata
  html_transformer = HyperlinkTransformer()
  documents_with_links = html_transformer.transform_documents(documents)

  # Create a vector store that uses
  # Astra DB as the backend and
  # OpenAI as the embedding service
  print("Creating vector store...")
  vector_store = AstraDBVectorStore(
      collection_name="graph_rag_tutorial",
      token=application_token,
      api_endpoint=endpoint,
      embedding=OpenAIEmbeddings(api_key=openai_api_key),
  )

  # In case a collection with this name already existed,
  # delete any documents in the collection
  vector_store.clear()

  # Insert the documents into the collection
  print("Inserting documents...")
  vector_store.add_documents(documents_with_links)

  1 Store your database’s endpoint, application token, and OpenAI key in environment variables named API_ENDPOINT, APPLICATION_TOKEN, and OPENAI_API_KEY, as instructed in Set environment variables.
  2 Replace PATH_TO_DATA_FILE with the path to the JSON data file.
- Execute the code. You should see printed messages indicating that the vector store is being created and that documents are being inserted.
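Optionally, you can add the following line to the script to inspect the metadata of one transformed document and verify the fields that the graph traversal will use. The output shown in the comment is illustrative, not from the actual dataset:

# Optional: inspect the graph-related metadata of the first document.
print(documents_with_links[0].metadata)
# Illustrative output:
# {'url': 'https://example.com/space-needle',
#  'hyperlink': ['https://example.com/seattle-center', ...]}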
Use the graph for retrieval and generation
- Copy the following code into a Python file.

  This code uses a graph retriever that performs a vector search to find the documents that are most similar to a given string, then traverses the graph to find connected documents. The documents found by the graph retriever are passed along with the original question to the LLM.

  import os
  from pprint import pprint

  from langchain_core.prompts import ChatPromptTemplate
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnablePassthrough
  from graph_retriever.strategies import Eager
  from langchain_graph_retriever import GraphRetriever
  from langchain_astradb import AstraDBVectorStore
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings

  endpoint = os.environ.get("API_ENDPOINT")  # (1)
  application_token = os.environ.get("APPLICATION_TOKEN")
  openai_api_key = os.environ.get("OPENAI_API_KEY")

  if not endpoint or not application_token or not openai_api_key:
      raise RuntimeError(
          "Environment variables API_ENDPOINT, APPLICATION_TOKEN, OPENAI_API_KEY must be defined."
      )

  # Initialize AstraDBVectorStore
  # based on the collection that you created and populated earlier
  vector_store = AstraDBVectorStore(
      collection_name="graph_rag_tutorial",
      token=application_token,
      api_endpoint=endpoint,
      embedding=OpenAIEmbeddings(api_key=openai_api_key),
      autodetect_collection=True,
  )

  # Initialize the LLM
  llm = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)

  # Define the prompt template
  template = """Answer the question based only on the following context:
  {context}

  Question: {question}
  """
  prompt = ChatPromptTemplate.from_template(template)

  # Initialize GraphRetriever.
  # This retriever first uses vector search to find relevant documents,
  # then uses graph traversal to explore their connections.
  retriever = GraphRetriever(
      store=vector_store,
      edges=[("hyperlink", "url")],
      strategy=Eager(
          # Number of documents to fetch via vector search for starting the traversal.
          start_k=3,
          # Maximum total documents to retrieve during traversal.
          select_k=10,
          # Maximum traversal depth.
          # A value of 0 only performs vector search, but does not do any graph traversal.
          max_depth=1,
      ),
  )

  # Helper function to format the retrieved documents for LLM context
  def format_docs(docs):
      return "\n\n".join([d.page_content for d in docs])

  # Build the RAG chain:
  # 1. Take the input question, retrieve relevant documents from a vector store,
  #    and format the documents into a string.
  #    Create a dictionary with the formatted documents and the original question.
  # 2. Create the prompt by injecting the dictionary from step 1 into the prompt template.
  # 3. Pass the formatted prompt to the LLM.
  # 4. Clean and return the LLM output.
  chain = (
      {"context": retriever | format_docs, "question": RunnablePassthrough()}
      | prompt
      | llm
      | StrOutputParser()
  )

  # Try these questions to explore the knowledge graph:
  QUESTION = "What is close to the Space Needle?"
  # Alternative questions:
  # - "What is in the Lower Queen Anne neighborhood?"
  # - "What is in the same neighborhood as the Space Needle?"
  # - "What connects the 1962 World's Fair to modern Seattle?"
  # - "Where is Chihuly Garden and Glass?"

  print(f"\nQuestion: {QUESTION}\n")

  try:
      response = chain.invoke(QUESTION)
      print("Answer:")
      pprint(response)
  except Exception as e:
      print(f"Error during RAG query: {e}")

  1 Store your database’s endpoint, application token, and OpenAI key in environment variables named API_ENDPOINT, APPLICATION_TOKEN, and OPENAI_API_KEY, as instructed in Set environment variables.
- Execute the code. You should see the question printed to the console, followed by the answer from the LLM.
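If you want to see which documents the graph retriever selects before they reach the LLM, you can also invoke the retriever on its own. This is a small optional addition to the script above; the url metadata field is the one populated in the previous section:

# Optional: run the retriever by itself to inspect the selected documents.
docs = retriever.invoke(QUESTION)
for doc in docs:
    print(doc.metadata.get("url"))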
Next steps
- Ask different questions to see how the graph retriever performs.
- Tune the start_k, select_k, and max_depth parameters to see how this affects the results; a sketch with tuned values follows this list. Note that increasing these values increases the number of documents retrieved and passed to the LLM, which increases the cost of the operation.
  - start_k is the number of documents to retrieve via vector search for starting the graph traversal. Increasing start_k can help with questions that might match multiple documents.
  - select_k is the maximum total number of documents to retrieve during graph traversal. Increasing select_k can help with questions that require broad context.
  - max_depth is the maximum traversal depth. Increasing max_depth can help with questions that require more distant connections, but might also retrieve too many loosely related documents.
- Try using a larger dataset.
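As an example of such tuning, you could change the retriever initialization in the earlier script to start from more vector-search hits and follow links two hops away. The values below are illustrative, not recommendations:

retriever = GraphRetriever(
    store=vector_store,
    edges=[("hyperlink", "url")],
    strategy=Eager(
        start_k=5,    # start the traversal from more vector-search results
        select_k=20,  # allow more documents in total
        max_depth=2,  # follow hyperlinks up to two hops from the start set
    ),
)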
Cleanup
After completing the tutorial, you can erase the tutorial data from your Astra organization:
- You can delete the entire database.
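Alternatively, if you want to keep the database and remove only the tutorial data, here is a sketch that drops the graph_rag_tutorial collection created earlier, assuming the same environment variables from Set environment variables:

import os

from langchain_astradb import AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings

# Connect to the collection created in this tutorial.
vector_store = AstraDBVectorStore(
    collection_name="graph_rag_tutorial",
    token=os.environ["APPLICATION_TOKEN"],
    api_endpoint=os.environ["API_ENDPOINT"],
    embedding=OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"]),
    autodetect_collection=True,
)

# Delete the collection and all documents in it.
vector_store.delete_collection()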