RAGStack with CassIO

Large Language Models (LLMs) have a data freshness problem. The most powerful LLMs in the world, like GPT-4, have no idea about recent world events.

An LLM's knowledge is frozen in time: it is a static snapshot of the world as captured in its training data.

A solution to this problem is Retrieval Augmented Generation (RAG). The idea is to retrieve relevant information from an external knowledge base and pass that information to the LLM along with the question. This notebook shows how to do that: external or proprietary data is stored in Astra DB Serverless and used to provide more current LLM responses.
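
To make the "give that information to our LLM" step concrete, here is a minimal sketch of how retrieved passages might be combined with a question into a single prompt. The function name `build_augmented_prompt`, the prompt wording, and the hand-written passages are illustrative assumptions; LangChain assembles its own prompt internally later in this notebook.

```python
def build_augmented_prompt(question, passages):
    """Combine retrieved passages and a question into one prompt string."""
    # Number each passage so the individual pieces of context stay distinguishable.
    context = "\n".join(
        f"[{i}] {text}" for i, text in enumerate(passages, start=1)
    )
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}"
    )


# Hand-written placeholder passages standing in for retrieved text.
passages = [
    "Luchesi is a rival wine connoisseur mentioned by Montresor.",
    "Montresor lures Fortunato into the catacombs.",
]
prompt = build_augmented_prompt("Who is Luchesi?", passages)
print(prompt)
```

The LLM never sees the whole knowledge base, only the few passages placed in the prompt, which is why the quality of the retrieval step matters.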

Get started with this notebook

See Prerequisites for instructions on setting up your environment.

  1. Install the following libraries.

    pip install \
        "ragstack-ai" \
        "openai" \
        "pypdf" \
        "python-dotenv" \
        "datasets" \
        "pandas" \
        "google-cloud-aiplatform"

  2. Import dependencies.

    import os
    from dotenv import load_dotenv
    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider
    from cassandra.query import SimpleStatement
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from langchain_community.vectorstores import Cassandra
    from langchain.indexes.vectorstore import VectorStoreIndexWrapper
    from langchain_community.document_loaders import PyPDFLoader, TextLoader

    You will need a secure connect bundle and a user with Database Administrator permissions. More information about how to get the bundle can be found at https://docs.datastax.com/en/astra-serverless/docs/connect/secure-connect-bundle.html.

  3. Initialize the environment variables.

    # Load variables from a local .env file, if present, before reading them.
    load_dotenv()

    ASTRA_DB_SECURE_BUNDLE_PATH = os.getenv("ASTRA_DB_SECURE_BUNDLE_PATH")
    ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
    ASTRA_DB_APPLICATION_TOKEN_BASED_USERNAME = "token"
    ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_NAMESPACE")
    ASTRA_DB_TABLE_NAME = os.getenv("ASTRA_DB_COLLECTION")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

  4. Retrieve the text of the short story "The Cask of Amontillado" by Edgar Allan Poe and set it as the sample data to be indexed in the vector store.

    # Run in a shell, or prefix the command with ! in a notebook cell.
    curl https://raw.githubusercontent.com/CassioML/cassio-website/main/docs/frameworks/langchain/texts/amontillado.txt --output amontillado.txt

    # Then, in Python:
    SAMPLEDATA = ["amontillado.txt"]

  5. Connect to Astra DB Serverless. This step assumes that a vector-search-capable Astra DB Serverless database is available. If you don’t have one, create one by following the instructions at Create a Serverless (Vector) database.

    def getCQLSession(mode='astra_db'):
        if mode == 'astra_db':
            cluster = Cluster(
                cloud={
                    "secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH,
                },
                auth_provider=PlainTextAuthProvider(
                    ASTRA_DB_APPLICATION_TOKEN_BASED_USERNAME,
                    ASTRA_DB_APPLICATION_TOKEN,
                ),
            )
            astraSession = cluster.connect()
            return astraSession
        else:
            raise ValueError('Unsupported CQL Session mode')
    
    def getCQLKeyspace(mode='astra_db'):
        if mode == 'astra_db':
            return ASTRA_DB_KEYSPACE
        else:
            raise ValueError('Unsupported CQL Session mode')
    
    def getTableCount():
        # Build a query that counts the rows in the Astra DB table.
        query = SimpleStatement(
            f"SELECT COUNT(*) FROM {keyspace}.{ASTRA_DB_TABLE_NAME};"
        )

        # Execute the query and return the count from the single result row.
        results = session.execute(query)
        return results.one().count
    
    cqlMode = 'astra_db'
    session = getCQLSession(mode=cqlMode)
    keyspace = getCQLKeyspace(mode=cqlMode)

  6. Instantiate the LLM and embeddings model.

    llm = ChatOpenAI(temperature=0)
    myEmbedding = OpenAIEmbeddings()
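
The setup steps above depend on several environment variables, and `os.getenv` silently returns `None` for any that are unset. As a hedged sketch, a small check like the following could surface missing settings before a connection is attempted. The names `REQUIRED_SETTINGS` and `find_missing_settings` are hypothetical helpers introduced here for illustration; they are not part of RAGStack or CassIO.

```python
import os

# Variable names read in step 3 of this notebook.
REQUIRED_SETTINGS = [
    "ASTRA_DB_SECURE_BUNDLE_PATH",
    "ASTRA_DB_APPLICATION_TOKEN",
    "ASTRA_DB_NAMESPACE",
    "ASTRA_DB_COLLECTION",
    "OPENAI_API_KEY",
]


def find_missing_settings(env, required=REQUIRED_SETTINGS):
    """Return the required names that are unset or empty in env."""
    return [name for name in required if not env.get(name)]


missing = find_missing_settings(os.environ)
if missing:
    print("Missing environment variables: " + ", ".join(missing))
else:
    print("All required environment variables are set.")
```

Passing the environment in as a mapping, rather than reading `os.environ` inside the function, keeps the check easy to exercise with a plain dictionary.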

LangChain retrieval augmentation

The following is a minimal example of using the Cassandra vector store. The store is created and populated in a single call. It is then queried to retrieve the relevant parts of the indexed text, which are inserted ("stuffed") into the prompt that the LLM uses to answer a question. Setting SAMPLEDATA = [] at the end clears the list so that the same files aren’t indexed again if the code is rerun.

documents = []
for filename in SAMPLEDATA:
  path = os.path.join(os.getcwd(), filename)

  # Supported file types are pdf and txt
  if filename.endswith(".pdf"):
    loader = PyPDFLoader(path)
    new_docs = loader.load_and_split()
    print(f"Processed pdf file: {filename}")
  elif filename.endswith(".txt"):
    loader = TextLoader(path)
    new_docs = loader.load_and_split()
    print(f"Processed txt file: {filename}")
  else:
    print(f"Unsupported file type: {filename}")
    continue  # skip this file so new_docs is never undefined or stale

  if len(new_docs) > 0:
    documents.extend(new_docs)

cassVStore = Cassandra.from_documents(
  documents=documents,
  embedding=myEmbedding,  # reuse the embeddings model instantiated earlier
  session=session,
  keyspace=ASTRA_DB_KEYSPACE,
  table_name=ASTRA_DB_TABLE_NAME,
)

SAMPLEDATA = []
print("\nProcessing done.")
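
The `load_and_split` calls above break each file into smaller chunks before the chunks are embedded and stored. As a rough illustration of the idea only, the sketch below splits text into fixed-size character windows that overlap. It is a simplified stand-in, not LangChain's splitting algorithm, and `split_into_chunks`, `chunk_size`, and `overlap` are names assumed here for illustration.

```python
def split_into_chunks(text, chunk_size=200, overlap=40):
    """Split text into fixed-size character chunks that overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")

    step = chunk_size - overlap  # how far each new chunk advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a chunk has reached the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


sample = "abcdefghij" * 5  # 50 characters of placeholder text
chunks = split_into_chunks(sample, chunk_size=20, overlap=5)
print(len(chunks), [len(c) for c in chunks])
```

The overlap means that a sentence cut at a chunk boundary still appears intact in at least one neighboring chunk, at the cost of storing some text twice.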

Query proprietary store

Use VectorStoreIndexWrapper from langchain.indexes.vectorstore for querying.

index = VectorStoreIndexWrapper(vectorstore=cassVStore)

query = "Who is Luchesi?"
print(index.query(query, llm=llm))

query = "What motivates Montresor to seek revenge against Fortunato?"
print(index.query(query, llm=llm))

# We can also query the index for the relevant documents, which act as context for the LLM.
retriever = index.vectorstore.as_retriever(search_kwargs={
    "k": 2,  # retrieve 2 documents
})
retriever.invoke(
    "What motivates Montresor to seek revenge against Fortunato?"
)
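
Conceptually, a retriever with `k` set to 2 embeds the question and returns the two stored chunks whose vectors are closest to it. The following sketch shows that ranking step in plain Python with cosine similarity. The similarity metric, the helper names `cosine_similarity` and `top_k`, and the tiny made-up vectors are all assumptions for illustration; the actual search runs inside Astra DB Serverless.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat a zero vector as similar to nothing
    return dot / (norm_a * norm_b)


def top_k(query_vector, indexed, k=2):
    """Return the k (text, score) pairs most similar to query_vector."""
    scored = [
        (text, cosine_similarity(query_vector, vector))
        for text, vector in indexed
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Made-up 3-dimensional vectors; real embeddings have many more dimensions.
indexed = [
    ("chunk A", [1.0, 0.0, 0.0]),
    ("chunk B", [0.9, 0.1, 0.0]),
    ("chunk C", [0.0, 0.0, 1.0]),
]
results = top_k([1.0, 0.0, 0.0], indexed, k=2)
print([text for text, _ in results])
```

This brute-force scan compares the query against every stored vector; a vector database avoids that cost with an index built for approximate nearest-neighbor search.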

© 2024 DataStax