Integrate Cohere with Astra DB Serverless

After you embed text with Cohere’s embedding models, Astra DB Serverless indexes the embeddings so that you can perform similarity searches.

This tutorial uses a portion of the Stanford Question Answering Dataset (SQuAD), which consists of question-and-answer pairs, to set up a Retrieval-Augmented Generation (RAG) pipeline. You embed the questions and store them alongside their answers in the database. Then, to retrieve relevant answers, you embed the user’s query and run a similarity search against the stored questions.

Prerequisites

Install packages

  1. Install the required Python packages:

    pip install -U cohere astrapy datasets python-dotenv
    • cohere is the interface for the Cohere models.

    • astrapy is the Astra DB Data API client.

    • datasets contains the SQuAD dataset with the question-and-answer data to be stored in Astra DB Serverless. The datasets package can access many datasets from the Hugging Face Datasets Hub.

    • python-dotenv allows the program to load the required credentials from a .env file.

  2. Create a Python script file to run the integration. In the following steps, you will add code to your integration script.

    cohere-integration.py
    import os
    
    import cohere
    from astrapy import DataAPIClient
    from datasets import load_dataset
    from dotenv import load_dotenv
    
    # ...

Set environment variables

  1. Create a .env file with your Cohere and Astra DB Serverless credentials:

    .env
    ASTRA_DB_APPLICATION_TOKEN=TOKEN
    ASTRA_DB_API_ENDPOINT=API_ENDPOINT
    ASTRA_DB_KEYSPACE="default_keyspace"
    ASTRA_DB_COLLECTION_NAME="cohere"
    COHERE_API_KEY=COHERE_API_KEY
  2. Load the credentials in the Python script and create the required clients:

    cohere-integration.py
    # ...
    
    load_dotenv()
    
    cohere_api_key = os.environ["COHERE_API_KEY"]
    
    client = DataAPIClient()
    database = client.get_database(
        os.environ["ASTRA_DB_API_ENDPOINT"],
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        keyspace=os.environ["ASTRA_DB_KEYSPACE"],
    )
    
    cohere_client = cohere.Client(cohere_api_key)
    
    # ...

Create a collection

To create a collection, call the create_collection method.

For a vector store application, you must create your collection with a specified embedding dimension. Only embeddings of this specific length can be inserted into the collection. Therefore, you must select your Cohere embedding model first, and then adjust the vector dimension accordingly.

This guide uses the "embed-english-v3.0" model, which has a dimension of 1024. For a list of available models and their dimensions, see the Cohere API Reference.

cohere-integration.py
# ...

COHERE_MODEL_NAME = "embed-english-v3.0"
COHERE_MODEL_DIMENSION = 1024
collection = database.create_collection(
    os.environ["ASTRA_DB_COLLECTION_NAME"],
    dimension=COHERE_MODEL_DIMENSION,
)

# ...
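
Because the collection rejects vectors whose length differs from the configured dimension, it can be worth validating the embeddings before you try to insert anything. A minimal sketch; the `check_dimensions` helper is hypothetical, not part of astrapy:

```python
# Hypothetical helper: flag the index of any embedding whose length
# does not match the dimension the collection was created with.
def check_dimensions(embeddings, expected_dim):
    return [i for i, emb in enumerate(embeddings) if len(emb) != expected_dim]

# Vectors of the configured length pass; the shorter vector is flagged.
assert check_dimensions([[0.1] * 1024, [0.2] * 1024], 1024) == []
assert check_dimensions([[0.1] * 1024, [0.2] * 512], 1024) == [1]
```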

Prepare the data

  1. Load the SQuAD dataset:

    cohere-integration.py
    # ...
    
    # Select the first 2,000 rows of the training set.
    # These rows contain both questions and answers.
    squad = load_dataset("squad", split="train[:2000]")
    
    # Show some example question/answer pairs.
    print("Sample entries in dataset:")
    for question, answers in zip(squad["question"][:5], squad["answers"][:5]):
        print(f"\n  - Question: {question}")
        print(f"  - Answers: {answers['text'][0]}")
    
    # ...
  2. Ask Cohere for the embeddings of all of these questions. Use the model you selected earlier so that the embeddings match the dimension of the collection.

    cohere-integration.py
    # ...
    
    embeddings = cohere_client.embed(
        texts=squad["question"],
        model=COHERE_MODEL_NAME,
        input_type="search_document",
        truncate="END",
    ).embeddings
    
    # Check that the embeddings have the correct dimension.
    if len(embeddings[0]) != COHERE_MODEL_DIMENSION:
        print("Dimension mismatch")
        # You must adjust the dimension, then delete and re-create the collection.
    
    # ...

    When using Cohere for RAG, embed the documents with an input_type of "search_document", and then embed the query with an input_type of "search_query".

    The truncate value of "END" means that if the provided text is too long to embed, the model cuts off the end of the offending text and returns an embedding of only the beginning part. Other options for this parameter include "START", which cuts off the beginning of the text, and "NONE", which returns an error message if the text is too long.
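    The behavior of the three truncate modes can be illustrated with a plain-Python sketch. This is conceptual only: the real truncation happens on Cohere’s side at the token level, not on whitespace-split words as here, and the helper below is hypothetical.

```python
# Conceptual illustration of Cohere's truncate modes. In reality the
# model truncates by tokens server-side; words stand in for tokens here.
def truncate_tokens(tokens, max_len, mode):
    if len(tokens) <= max_len:
        return tokens
    if mode == "END":       # keep the beginning, drop the end
        return tokens[:max_len]
    if mode == "START":     # drop the beginning, keep the end
        return tokens[-max_len:]
    raise ValueError("text too long")  # mode == "NONE": refuse to embed

words = "one two three four five".split()
assert truncate_tokens(words, 3, "END") == ["one", "two", "three"]
assert truncate_tokens(words, 3, "START") == ["three", "four", "five"]
```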

  3. Combine each dictionary, which represents a row from the SQuAD dataset, with its generated embedding.

    This process creates one dictionary per row, keeping the SQuAD dataset keys and values untouched and adding the embedding under the "$vector" key. Embeddings must be top-level values associated with the "$vector" key to be valid vector search targets in Astra DB Serverless.

    cohere-integration.py
    # ...
    
    to_insert = []
    for doc_index, squad_document in enumerate(squad):
        to_insert.append({**squad_document, "$vector": embeddings[doc_index]})
    
    # ...

  4. Use the insert_many method to insert your documents into Astra DB:

    cohere-integration.py
    # ...
    
    insert_result = collection.insert_many(to_insert)
    print(f"\nInserted {len(insert_result.inserted_ids)} documents.")
    
    # ...
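
astrapy’s insert_many handles batching internally, but if you want to report progress per batch yourself, you could split the 2,000 documents into chunks before calling it. A sketch; the chunks helper and the batch size of 512 are arbitrary choices, not astrapy APIs:

```python
# Hypothetical helper: yield successive fixed-size slices of a list,
# so each slice can be passed to insert_many and reported on.
def chunks(seq, size):
    for i in range(0, len(seq), size):
        yield seq[i:i + size]

# 2,000 documents split into batches of 512 leaves a final batch of 464.
assert [len(c) for c in chunks(list(range(2000)), 512)] == [512, 512, 512, 464]
```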

Embed the query and get the answer

  1. Call cohere_client.embed again with the input_type of "search_query". Use the same model and truncate values. Replace the text in user_query to search for an answer to a different question.

    cohere-integration.py
    # ...
    
    user_query = "What's in front of Notre Dame?"
    embedded_query = cohere_client.embed(
        texts=[user_query],
        model=COHERE_MODEL_NAME,
        input_type="search_query",
        truncate="END",
    ).embeddings[0]
    
    # ...
  2. Use the find method to extract the top documents whose question is similar to the embedded query:

    cohere-integration.py
    # ...
    
    results = collection.find(
        sort={"$vector": embedded_query},
        limit=5,
        include_similarity=True,
    )
    
    # ...
  3. Extract the answer value and similarity score from those documents.

    cohere-integration.py
    # ...
    
    print(f"\nQuery: {user_query}")
    print("Answers:")
    for idx, answer in enumerate(results):
        answer_text = answer["answers"]["text"][0]
        similarity = answer["$similarity"]
        print(f"  - Answer {idx} (similarity: {similarity:.3f}): {answer_text}")
    Result

    This is the result of running the whole script provided above with python3 cohere-integration.py.

    Sample entries in dataset:
    
      - Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
      - Answers: Saint Bernadette Soubirous
    
      - Question: What is in front of the Notre Dame Main Building?
      - Answers: a copper statue of Christ
    
      - Question: The Basilica of the Sacred heart at Notre Dame is beside to which structure?
      - Answers: the Main Building
    
      - Question: What is the Grotto at Notre Dame?
      - Answers: a Marian place of prayer and reflection
    
      - Question: What sits on top of the Main Building at Notre Dame?
      - Answers: a golden statue of the Virgin Mary
    
    Inserted 2000 documents.
    
    Query: What's in front of Notre Dame?
    Answers:
      - Answer 0 (similarity: 0.883): a copper statue of Christ
      - Answer 1 (similarity: 0.779): a golden statue of the Virgin Mary
      - Answer 2 (similarity: 0.765): the City of South Bend
      - Answer 3 (similarity: 0.747): the Main Building
      - Answer 4 (similarity: 0.730): Old College
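
The $similarity values in this output come from the collection’s similarity metric, which defaults to cosine when no metric is specified at creation time. Assuming that default, and assuming the Data API’s rescaling of cosine similarity into the [0, 1] range, the score can be sketched as:

```python
import math

# Plain cosine similarity between two vectors, in [-1, 1].
def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Assumed rescaling into [0, 1], matching the range of $similarity.
def astra_similarity(u, v):
    return (1 + cosine(u, v)) / 2

assert astra_similarity([1, 0], [1, 0]) == 1.0   # identical direction
assert astra_similarity([1, 0], [-1, 0]) == 0.0  # opposite direction
assert astra_similarity([1, 0], [0, 1]) == 0.5   # orthogonal
```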


© 2024 DataStax | Privacy policy | Terms of use