RAGStack with CassIO
Large Language Models (LLMs) have a data freshness problem. Even the most capable models, such as GPT-4, know nothing about recent events: an LLM's knowledge is frozen in time, a static snapshot of the world as it appeared in its training data.
A solution to this problem is Retrieval Augmented Generation (RAG): retrieve relevant information from an external knowledge base and pass it to the LLM alongside the question. In this notebook, external or proprietary data is stored in Astra DB Serverless and used to provide more current LLM responses.
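To make the pattern concrete, here is a minimal conceptual sketch of the RAG flow. The rag_answer helper and its parameters are illustrative only, not part of this notebook's API:

# Conceptual sketch only: names below are illustrative, not the notebook's API.
def rag_answer(question, vector_store, llm):
    # 1. Retrieve the documents most similar to the question.
    context_docs = vector_store.similarity_search(question, k=2)
    # 2. Stuff the retrieved text into the prompt as context.
    context = "\n\n".join(doc.page_content for doc in context_docs)
    prompt = f"Answer using only this context:\n{context}\n\nQuestion: {question}"
    # 3. Let the LLM answer with the fresh context in hand.
    return llm.invoke(prompt)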
Get started with this notebook
See Prerequisites for instructions on setting up your environment.
- Install the following libraries.
pip install \
  "ragstack-ai" \
  "openai" \
  "pypdf" \
  "python-dotenv" \
  "datasets" \
  "pandas" \
  "google-cloud-aiplatform"
- Import dependencies.
import os
from dotenv import load_dotenv
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import SimpleStatement
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import Cassandra
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain.chat_models import ChatOpenAI
You will need a secure connect bundle and a user with Database Administrator permissions. For instructions on downloading the bundle, see https://docs.datastax.com/en/astra-serverless/docs/connect/secure-connect-bundle.html.
- Initialize the environment variables.
load_dotenv()  # load variables from a local .env file, if present

ASTRA_DB_SECURE_BUNDLE_PATH = os.getenv("ASTRA_DB_SECURE_BUNDLE_PATH")
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_APPLICATION_TOKEN_BASED_USERNAME = "token"
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_NAMESPACE")
ASTRA_DB_TABLE_NAME = os.getenv("ASTRA_DB_COLLECTION")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
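As a small convenience that is not part of the original notebook, you can fail fast if any required variable is missing:

# Optional sanity check (not in the original notebook): fail fast on missing variables.
required = [
    "ASTRA_DB_SECURE_BUNDLE_PATH",
    "ASTRA_DB_APPLICATION_TOKEN",
    "ASTRA_DB_NAMESPACE",
    "ASTRA_DB_COLLECTION",
    "OPENAI_API_KEY",
]
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise EnvironmentError(f"Missing environment variables: {missing}")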
- Retrieve the text of a short story to index in the vector store and set it as the sample data. The story is "The Cask of Amontillado" by Edgar Allan Poe.
curl https://raw.githubusercontent.com/CassioML/cassio-website/main/docs/frameworks/langchain/texts/amontillado.txt --output amontillado.txt

SAMPLEDATA = ["amontillado.txt"]
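If you prefer to stay in Python, the standard library can fetch the same file; this is an equivalent alternative to the curl command above, not an extra step:

# Alternative to curl: download the story with Python's standard library.
import urllib.request

urllib.request.urlretrieve(
    "https://raw.githubusercontent.com/CassioML/cassio-website/main/docs/frameworks/langchain/texts/amontillado.txt",
    "amontillado.txt",
)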
- Connect to Astra DB Serverless. The following assumes that a vector-search-capable Astra DB Serverless instance is available. If you don't have one, create one by following the instructions at Create a Serverless (Vector) database.
def getCQLSession(mode='astra_db'):
    if mode == 'astra_db':
        cluster = Cluster(
            cloud={
                "secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH,
            },
            auth_provider=PlainTextAuthProvider(
                ASTRA_DB_APPLICATION_TOKEN_BASED_USERNAME,
                ASTRA_DB_APPLICATION_TOKEN,
            ),
        )
        astraSession = cluster.connect()
        return astraSession
    else:
        raise ValueError('Unsupported CQL Session mode')

def getCQLKeyspace(mode='astra_db'):
    if mode == 'astra_db':
        return ASTRA_DB_KEYSPACE
    else:
        raise ValueError('Unsupported CQL Session mode')

def getTableCount():
    # Count the number of records in the Astra DB vector table
    query = SimpleStatement(
        f"""SELECT COUNT(*) FROM {ASTRA_DB_KEYSPACE}.{ASTRA_DB_TABLE_NAME};"""
    )
    results = session.execute(query)
    return results.one().count

cqlMode = 'astra_db'
session = getCQLSession(mode=cqlMode)
keyspace = getCQLKeyspace(mode=cqlMode)
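Optionally, you can verify the connection before moving on. This quick check against Cassandra's standard system.local table is a sketch, not part of the original flow:

# Optional sanity check: query the cluster's metadata table.
row = session.execute("SELECT release_version FROM system.local").one()
print(f"Connected (Cassandra release {row.release_version})")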
- Instantiate the LLM and embeddings model.
llm = ChatOpenAI(temperature=0)
myEmbedding = OpenAIEmbeddings()
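If you want to confirm that the OpenAI credentials work before indexing anything, a single embedding call is enough. The dimension in the comment is an assumption about the default embedding model:

# Optional smoke test: embed a short string and inspect the vector size.
vector = myEmbedding.embed_query("hello world")
print(len(vector))  # typically 1536 for OpenAI's default embedding model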
LangChain retrieval augmentation
The following is a minimal usage of the Cassandra vector store: the store is created and filled in one step, then queried to retrieve relevant parts of the indexed text, which are stuffed into a prompt that the LLM finally uses to answer a question.
Setting SAMPLEDATA = [] at the end of the cell clears the list so the same files aren't indexed multiple times on a re-run.
documents = []
for filename in SAMPLEDATA:
    path = os.path.join(os.getcwd(), filename)

    # Supported file types are pdf and txt
    if filename.endswith(".pdf"):
        loader = PyPDFLoader(path)
        new_docs = loader.load_and_split()
        print(f"Processed pdf file: {filename}")
    elif filename.endswith(".txt"):
        loader = TextLoader(path)
        new_docs = loader.load_and_split()
        print(f"Processed txt file: {filename}")
    else:
        print(f"Unsupported file type: {filename}")
        continue
    if len(new_docs) > 0:
        documents.extend(new_docs)

cass_vstore = Cassandra.from_documents(
    documents=documents,
    embedding=OpenAIEmbeddings(),
    session=session,
    keyspace=ASTRA_DB_KEYSPACE,
    table_name=ASTRA_DB_TABLE_NAME,
)

# Clear the list so the same files aren't indexed again on a re-run
SAMPLEDATA = []
print("\nProcessing done.")
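To confirm that the embeddings landed in the table, you can reuse the getTableCount() helper defined earlier; this check is optional:

# Optional: confirm that rows were written to the vector table.
print(f"Rows in {ASTRA_DB_KEYSPACE}.{ASTRA_DB_TABLE_NAME}: {getTableCount()}")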
Query proprietary store
Use VectorStoreIndexWrapper from langchain.indexes.vectorstore for querying.
index = VectorStoreIndexWrapper(vectorstore=cass_vstore)
query = "Who is Luchesi?"
index.query(query, llm=llm)

query = "What motivates Montresor to seek revenge against Fortunato?"
index.query(query, llm=llm)
# We can query the index for the relevant documents, which act as context for the LLM.
retriever = index.vectorstore.as_retriever(search_kwargs={
'k': 2, # retrieve 2 documents
})
retriever.get_relevant_documents(
    "What motivates Montresor to seek revenge against Fortunato?"
)
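To see how close each retrieved chunk is to the query, you can also ask the vector store for raw similarity scores. This uses the standard LangChain vector-store interface and is a sketch rather than part of the original notebook:

# Optional: retrieve documents together with their similarity scores.
for doc, score in cass_vstore.similarity_search_with_score("Who is Luchesi?", k=2):
    print(f"{score:.3f}  {doc.page_content[:80]}...")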