VectorStore QA with MMR


This page demonstrates how to use RAGStack and a vector-enabled Astra DB Serverless database to perform vector search with the Maximal Marginal Relevance (MMR) algorithm.

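Plain similarity search returns the top k stored documents most similar to the provided query. MMR works differently. It first fetches a larger pool of relevant candidates. It then selects k documents from this pool one at a time, favoring documents that are relevant to the query but dissimilar to the documents already selected. As a result, MMR returns results with more diverse information.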
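The selection step can be sketched in plain Python. This is a minimal, illustrative implementation of the standard greedy MMR formulation, not RAGStack's or LangChain's own code. The function names `cosine` and `mmr_select` are hypothetical. The parameter names `fetch_k` (candidate pool size) and `lambda_mult` (weight on relevance versus diversity) are borrowed from LangChain's MMR search options. At each step, a candidate's score is `lambda_mult * relevance - (1 - lambda_mult) * redundancy`, where redundancy is the candidate's highest similarity to any document already selected.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr_select(
    query: Sequence[float],
    docs: List[Sequence[float]],
    k: int = 2,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Return the indices of up to k docs chosen by Maximal Marginal Relevance."""
    # Step 1: the candidate pool is the fetch_k docs most similar to the query.
    ranked = sorted(
        range(len(docs)), key=lambda i: cosine(query, docs[i]), reverse=True
    )
    pool = ranked[:fetch_k]
    selected: List[int] = []

    # Step 2: greedily pick the candidate with the best relevance/redundancy balance.
    while pool and len(selected) < k:

        def score(i: int) -> float:
            relevance = cosine(query, docs[i])
            # Highest similarity to anything already picked (0 for the first pick).
            redundancy = max(
                (cosine(docs[i], docs[j]) for j in selected), default=0.0
            )
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy

        best = max(pool, key=score)
        selected.append(best)
        pool.remove(best)
    return selected
```

For example, consider the query `[1.0, 1.0]` and the vectors `[[1.0, 0.2], [1.0, 0.25], [1.0, 0.15], [0.0, 1.0]]`. Plain similarity ranking picks indices 1 and 0, which are two near-duplicates. `mmr_select(query, docs, k=2)` returns `[1, 3]`, because the last vector points in a different direction. Setting `lambda_mult=1.0` removes the diversity penalty and reproduces the similarity ranking.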

Prerequisites

  1. You need a vector-enabled Astra DB Serverless database.

    1. Create an Astra vector database.

    2. Within your database, create an Astra DB Access Token with Database Administrator permissions.

    3. Copy the Astra DB Serverless API Endpoint for your Astra DB Serverless database.

  2. Set the following environment variables in a .env file in the root of your project:

    ASTRA_DB_ID=aad075g999-8ab4-4d81-aa7d-7f58dbed3ead
    ASTRA_DB_APPLICATION_TOKEN=AstraCS:...
    OPENAI_API_KEY=sk-...
    ASTRA_DB_KEYSPACE=default_keyspace #optional

    The ASTRA_DB_ID can be found in the Astra DB Serverless API Endpoint that’s displayed for your vector-enabled database in Astra Portal. If your API Endpoint is https://aad075g999-8ab4-4d81-aa7d-7f58dbed3ead-us-east-2.apps.astra.datastax.com, then your ASTRA_DB_ID is aad075g999-8ab4-4d81-aa7d-7f58dbed3ead.

  3. Install the following dependencies:

    pip install -qU ragstack-ai python-dotenv

    See the Prerequisites page for more details.
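The ASTRA_DB_ID extraction described above can also be done in code. The following is a minimal sketch. It assumes that the API Endpoint hostname follows the `<database_id>-<region>.apps.astra.datastax.com` pattern shown in the example, and that the database ID is made up of five hyphen-separated groups. `parse_astra_endpoint` is a hypothetical helper, not part of RAGStack or cassio.

```python
from typing import Tuple
from urllib.parse import urlparse

# Assumed hostname suffix, taken from the example endpoint above.
ASTRA_SUFFIX = ".apps.astra.datastax.com"


def parse_astra_endpoint(endpoint: str) -> Tuple[str, str]:
    """Split an Astra DB API Endpoint into (database_id, region).

    Assumes the host is <database_id>-<region>.apps.astra.datastax.com and
    that <database_id> has five hyphen-separated groups (UUID layout).
    """
    host = urlparse(endpoint).hostname or ""
    if not host.endswith(ASTRA_SUFFIX):
        raise ValueError(f"unexpected Astra DB endpoint host: {host!r}")
    parts = host[: -len(ASTRA_SUFFIX)].split("-")
    if len(parts) < 6:
        raise ValueError(f"cannot find a database ID and region in {host!r}")
    # First five groups form the database ID; the rest is the region.
    return "-".join(parts[:5]), "-".join(parts[5:])
```

With the example endpoint, the helper returns `aad075g999-8ab4-4d81-aa7d-7f58dbed3ead` as the database ID and `us-east-2` as the region.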

Create embedding model and vector store

  1. Import dependencies and load environment variables.

    import os
    import cassio
    from dotenv import load_dotenv
    from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
    from langchain_openai import OpenAI, OpenAIEmbeddings
    from langchain.indexes.vectorstore import VectorStoreIndexWrapper
    from langchain_community.vectorstores import Cassandra
    
    load_dotenv()
  2. Initialize the OpenAI model and embeddings.

    llm = OpenAI(temperature=0)
    myEmbedding = OpenAIEmbeddings()
  3. Initialize the vector store.

    cassio.init(
        database_id=os.environ["ASTRA_DB_ID"],
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        keyspace=os.environ.get("ASTRA_DB_KEYSPACE"),  # this is optional
    )
    
    # session=None and keyspace=None tell the vector store to use the
    # connection configured by cassio.init() above
    myCassandraVStore = Cassandra(
        embedding=myEmbedding,
        session=None,
        keyspace=None,
        table_name='vs_test2',
    )
    index = VectorStoreIndexWrapper(vectorstore=myCassandraVStore)
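The `cassio.init()` call reads its arguments with `os.environ[...]`, so a missing variable surfaces as a bare `KeyError` that names only the first missing variable. A small check placed right after `load_dotenv()` can report every missing variable at once. This is a hypothetical helper that assumes the variable names from the .env file in the prerequisites. ASTRA_DB_KEYSPACE is left out because it is optional.

```python
import os
from typing import Iterable, List

# Assumed required names, matching the .env file in the prerequisites.
REQUIRED_VARS = ("ASTRA_DB_ID", "ASTRA_DB_APPLICATION_TOKEN", "OPENAI_API_KEY")


def missing_env_vars(names: Iterable[str] = REQUIRED_VARS) -> List[str]:
    """Return the names of required variables that are unset or empty."""
    return [name for name in names if not os.environ.get(name)]


def require_env_vars(names: Iterable[str] = REQUIRED_VARS) -> None:
    """Raise one error listing every missing variable, instead of a KeyError."""
    missing = missing_env_vars(names)
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
```

Calling `require_env_vars()` between `load_dotenv()` and `cassio.init()` stops the script with a single message before any connection is attempted.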

Populate the vector store

  1. Create a list of sentences, with their sources stored as metadata. Note that the last sentence’s content is considerably different from the others.

    # declare data
    
    BASE_SENTENCE_0 =     ('The frogs and the toads were meeting in the night '
                           'for a party under the moon.')
    
    BASE_SENTENCE_1 =     ('There was a party under the moon, that all toads, '
                           'with the frogs, decided to throw that night.')
    
    BASE_SENTENCE_2 =     ('And the frogs and the toads said: "Let us have a party '
                           'tonight, as the moon is shining".')
    
    BASE_SENTENCE_3 =     ('I remember that night... toads, along with frogs, '
                           'were all busy planning a moonlit celebration.')
    
    DIFFERENT_SENTENCE =  ('For the party, frogs and toads set a rule: '
                           'everyone was to wear a purple hat.')
    
    # collect the texts and their matching source metadata
    texts = [
        BASE_SENTENCE_0,
        BASE_SENTENCE_1,
        BASE_SENTENCE_2,
        BASE_SENTENCE_3,
        DIFFERENT_SENTENCE,
    ]
    metadatas = [
        {'source': 'Barney\'s story at the pub'},
        {'source': 'Barney\'s story at the pub'},
        {'source': 'Barney\'s story at the pub'},
        {'source': 'Barney\'s story at the pub'},
        {'source': 'The chronicles at the village library'},
    ]
  2. Load the sentences into the vector store and print their IDs.

    ids = myCassandraVStore.add_texts(
        texts,
        metadatas=metadatas,
    )
    print('\n'.join(ids))
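The `texts` and `metadatas` lists must stay aligned, because `add_texts()` attaches the metadata at each position to the text at the same position. One way to make that pairing explicit is to keep each sentence next to its source and derive both lists from one sequence. This is an optional, hypothetical restructuring of the same data, not something the guide requires. The names `PUB`, `LIBRARY`, and `records` are illustrative.

```python
# Hypothetical restructuring: each sentence sits next to its source, so the
# two lists passed to add_texts() are built from a single sequence.
PUB = "Barney's story at the pub"
LIBRARY = "The chronicles at the village library"

records = [
    ("The frogs and the toads were meeting in the night "
     "for a party under the moon.", PUB),
    ("There was a party under the moon, that all toads, "
     "with the frogs, decided to throw that night.", PUB),
    ('And the frogs and the toads said: "Let us have a party '
     'tonight, as the moon is shining".', PUB),
    ("I remember that night... toads, along with frogs, "
     "were all busy planning a moonlit celebration.", PUB),
    ("For the party, frogs and toads set a rule: "
     "everyone was to wear a purple hat.", LIBRARY),
]

# Same shapes as the texts and metadatas lists used with add_texts().
texts = [text for text, _ in records]
metadatas = [{"source": source} for _, source in records]
```

The resulting lists can be passed to `myCassandraVStore.add_texts(texts, metadatas=metadatas)` exactly as before.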

Create and compare retrievers

Create one retriever that uses similarity search and another that uses MMR search.

Both retrievers return the top 2 results (k=2), with the source metadata included. Ask both the same question, and compare the MMR response with the similarity response.

  1. Set the question.

    QUESTION = 'Tell me about the party that night.'
  2. Create a retriever with similarity search.

    retrieverSim = myCassandraVStore.as_retriever(
        search_type='similarity',
        search_kwargs={
            'k': 2,
        },
    )
    
    chainSimSrc = RetrievalQAWithSourcesChain.from_chain_type(
        llm,
        retriever=retrieverSim,
    )
    
    responseSimSrc = chainSimSrc.invoke({chainSimSrc.question_key: QUESTION})
    print('Similarity-based chain:')
    print(f'  ANSWER : {responseSimSrc["answer"].strip()}')
    print(f'  SOURCES: {responseSimSrc["sources"].strip()}')
  3. Create a retriever with MMR search.

    retrieverMMR = myCassandraVStore.as_retriever(
        search_type='mmr',
        search_kwargs={
            'k': 2,
        },
    )
    
    chainMMRSrc = RetrievalQAWithSourcesChain.from_chain_type(
        llm,
        retriever=retrieverMMR,
    )
    
    responseMMRSrc = chainMMRSrc.invoke({chainMMRSrc.question_key: QUESTION})
    print('MMR-based chain:')
    print(f'  ANSWER : {responseMMRSrc["answer"].strip()}')
    print(f'  SOURCES: {responseMMRSrc["sources"].strip()}')
  4. Run the code and observe the differences in the responses.

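    Similarity search returns the two sentences most similar to the question. Both are near-duplicates from Barney's story at the pub, so the answer cites a single source. MMR also returns DIFFERENT_SENTENCE, which is considerably different from the others. As a result, the MMR answer mentions the purple-hat rule and cites both sources.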

    Similarity-based chain:
      ANSWER : The party was thrown by all the toads and frogs under the moon that night.
      SOURCES: Barney's story at the pub
    MMR-based chain:
      ANSWER : The party that night was thrown by the frogs and toads, and the rule was for everyone to wear a purple hat.
      SOURCES: Barney's story at the pub, The chronicles at the village library
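The MMR retriever above relies on LangChain's defaults for its two MMR-specific options:

* `fetch_k` is the number of candidates fetched before selection. It defaults to 20.
* `lambda_mult` is a value between 0 and 1, where 1 corresponds to minimum diversity (pure relevance) and 0 to maximum diversity. It defaults to 0.5.

The following minimal sketch sets both options explicitly through `search_kwargs`. `make_mmr_retriever` is a hypothetical helper, and any particular values you choose are assumptions to tune for your own data.

```python
from typing import Any


def make_mmr_retriever(
    vectorstore: Any,
    k: int = 2,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
) -> Any:
    """Build an MMR retriever with an explicit candidate pool and diversity weight."""
    return vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,                      # documents returned to the chain
            "fetch_k": fetch_k,          # candidates considered before MMR selection
            "lambda_mult": lambda_mult,  # 1 = relevance only, 0 = maximum diversity
        },
    )
```

For example, `make_mmr_retriever(myCassandraVStore, k=2, fetch_k=5, lambda_mult=0.3)` would consider all five stored sentences and weight diversity more heavily than the default does.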

Complete code example

import os
import cassio
from dotenv import load_dotenv
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain_openai import OpenAI, OpenAIEmbeddings
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_community.vectorstores import Cassandra

# Load environment variables
load_dotenv()

# Initialize OpenAI and embeddings
llm = OpenAI(temperature=0)
myEmbedding = OpenAIEmbeddings()

cassio.init(
    database_id=os.environ["ASTRA_DB_ID"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    keyspace=os.environ.get("ASTRA_DB_KEYSPACE"),  # this is optional
)

myCassandraVStore = Cassandra(
    embedding=myEmbedding,
    session=None,
    keyspace=None,
    table_name='vs_test2',
)
index = VectorStoreIndexWrapper(vectorstore=myCassandraVStore)

# declare data

BASE_SENTENCE_0 =     ('The frogs and the toads were meeting in the night '
                       'for a party under the moon.')

BASE_SENTENCE_1 =     ('There was a party under the moon, that all toads, '
                       'with the frogs, decided to throw that night.')

BASE_SENTENCE_2 =     ('And the frogs and the toads said: "Let us have a party '
                       'tonight, as the moon is shining".')

BASE_SENTENCE_3 =     ('I remember that night... toads, along with frogs, '
                       'were all busy planning a moonlit celebration.')

DIFFERENT_SENTENCE =  ('For the party, frogs and toads set a rule: '
                       'everyone was to wear a purple hat.')

# collect the texts and their matching source metadata
texts = [
    BASE_SENTENCE_0,
    BASE_SENTENCE_1,
    BASE_SENTENCE_2,
    BASE_SENTENCE_3,
    DIFFERENT_SENTENCE,
]
metadatas = [
    {'source': 'Barney\'s story at the pub'},
    {'source': 'Barney\'s story at the pub'},
    {'source': 'Barney\'s story at the pub'},
    {'source': 'Barney\'s story at the pub'},
    {'source': 'The chronicles at the village library'},
]

# add texts to the vector store and print their IDs
ids = myCassandraVStore.add_texts(
    texts,
    metadatas=metadatas,
)
print('\n'.join(ids))

# query the index

QUESTION = 'Tell me about the party that night.'

# manual creation of the "retriever" with the 'similarity' search type
retrieverSim = myCassandraVStore.as_retriever(
    search_type='similarity',
    search_kwargs={
        'k': 2,
    },
)

chainSimSrc = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    retriever=retrieverSim,
)

# Run the chain and print results with sources
responseSimSrc = chainSimSrc.invoke({chainSimSrc.question_key: QUESTION})
print('Similarity-based chain:')
print(f'  ANSWER : {responseSimSrc["answer"].strip()}')
print(f'  SOURCES: {responseSimSrc["sources"].strip()}')


# mmr search with sources

# manual creation of the "retriever" with the 'MMR' search type
retrieverMMR = myCassandraVStore.as_retriever(
    search_type='mmr',
    search_kwargs={
        'k': 2,
    },
)

chainMMRSrc = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    retriever=retrieverMMR,
)

# Run the chain and print results with sources
responseMMRSrc = chainMMRSrc.invoke({chainMMRSrc.question_key: QUESTION})
print('MMR-based chain:')
print(f'  ANSWER : {responseMMRSrc["answer"].strip()}')
print(f'  SOURCES: {responseMMRSrc["sources"].strip()}')

© 2024 DataStax | Privacy policy | Terms of use