Integrate Cohere with Astra DB Serverless
After embedding text with Cohere’s embedding models, Astra DB Serverless indexes the embeddings so you can perform similarity searches.
This tutorial uses a portion of the Stanford Question Answering Dataset (SQuAD), which consists of questions and answers, to set up a Retrieval-Augmented Generation (RAG) pipeline. You embed the questions and store them alongside the answers in the database. To retrieve relevant answers, you then embed the user’s query and run a similarity search against the stored questions.
Prerequisites
- An active Astra account.
- An active Serverless (Vector) database.
- An application token with the Database Administrator role and your database’s API endpoint. For more information, see Generate an application token for a database.
- A Cohere account and Cohere API key from your Cohere dashboard.
Install packages
- Install and import the required Python packages:

pip install -U cohere astrapy datasets python-dotenv

  - cohere is the interface for the Cohere models.
  - astrapy is the Astra DB Data API client.
  - datasets contains the SQuAD dataset with the question-and-answer data to be stored in Astra DB Serverless. The datasets package can access many datasets from the Hugging Face Datasets Hub.
  - python-dotenv allows the program to load the required credentials from a .env file.
- Create a Python script file to run the integration. In the following steps, you will add code to your integration script.

cohere-integration.py
import os

import cohere
from astrapy import DataAPIClient
from datasets import load_dataset
from dotenv import load_dotenv

# ...
Set environment variables
- Create a .env file with your Cohere and Astra DB Serverless credentials:

.env
ASTRA_DB_APPLICATION_TOKEN=TOKEN
ASTRA_DB_API_ENDPOINT=API_ENDPOINT
ASTRA_DB_KEYSPACE="default_keyspace"
ASTRA_DB_COLLECTION_NAME="cohere"
COHERE_API_KEY=COHERE_API_KEY
- Load the credentials in the Python script and create the required clients:

cohere-integration.py
# ...
load_dotenv()

cohere_api_key = os.environ["COHERE_API_KEY"]

client = DataAPIClient()
database = client.get_database(
    os.environ["ASTRA_DB_API_ENDPOINT"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    keyspace=os.environ["ASTRA_DB_KEYSPACE"],
)

cohere_client = cohere.Client(cohere_api_key)
# ...
Create a collection
To create a collection, call the create_collection method.
For a vector store application, you must create your collection with a specified embedding dimension. Only embeddings of this exact length can be inserted into the collection. Therefore, select your Cohere embedding model first, and then set the vector dimension accordingly.
This guide uses the "embed-english-v3.0" model, which has a dimension of 1024. For a list of available models and their dimensions, see the Cohere API Reference.
# ...
COHERE_MODEL_NAME = "embed-english-v3.0"
COHERE_MODEL_DIMENSION = 1024

collection = database.create_collection(
    os.environ["ASTRA_DB_COLLECTION_NAME"],
    dimension=COHERE_MODEL_DIMENSION,
)
# ...
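The model-to-dimension pairing can also be captured in a small lookup table so the collection dimension always follows the chosen model. This is a sketch, not part of the integration script; the dimensions listed are from Cohere's model documentation at the time of writing, so verify them against the current Cohere API Reference before relying on them:

```python
# Embedding dimensions for some Cohere v3 embedding models.
# Verify these values against the current Cohere API Reference.
COHERE_MODEL_DIMENSIONS = {
    "embed-english-v3.0": 1024,
    "embed-english-light-v3.0": 384,
    "embed-multilingual-v3.0": 1024,
    "embed-multilingual-light-v3.0": 384,
}

def dimension_for(model_name: str) -> int:
    """Return the embedding dimension for a known model, or raise."""
    try:
        return COHERE_MODEL_DIMENSIONS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None

print(dimension_for("embed-english-v3.0"))  # 1024
```

Deriving the dimension from the model name this way prevents the two constants from drifting apart if you switch models later.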
Prepare the data
- Load the SQuAD dataset:

cohere-integration.py
# ...
# Select the first 2,000 rows of the training set.
# These rows contain both questions and answers.
squad = load_dataset("squad", split="train[:2000]")

# Show some example question/answer pairs.
print("Sample entries in dataset:")
for question, answers in zip(squad["question"][:5], squad["answers"][:5]):
    print(f"\n - Question: {question}")
    print(f"   - Answers: {answers['text'][0]}")
# ...
- Ask Cohere for the embeddings of all of these questions. The model you select must produce embeddings that match the dimension of the collection.

cohere-integration.py
# ...
embeddings = cohere_client.embed(
    texts=squad["question"],
    model=COHERE_MODEL_NAME,
    input_type="search_document",
    truncate="END",
).embeddings

# Check that the embeddings have the correct dimension.
if len(embeddings[0]) != COHERE_MODEL_DIMENSION:
    print("Dimension mismatch")
    # You must adjust the dimension, then delete and re-create the collection.
# ...
When using Cohere for RAG, embed the documents with an input_type of "search_document", and then embed the query with an input_type of "search_query".

The truncate value of "END" means that if the provided text is too long to embed, the model cuts off the end of the text and returns an embedding of only the beginning. Other options for this parameter are "START", which cuts off the beginning of the text, and "NONE", which returns an error if the text is too long.
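The three truncate options can be pictured with a toy word-level sketch. Cohere actually truncates by tokens, not words, so this only illustrates the direction of the cut, not the real tokenizer behavior:

```python
def toy_truncate(text: str, max_words: int, mode: str) -> str:
    """Illustrate Cohere's truncate modes at word granularity."""
    words = text.split()
    if len(words) <= max_words:
        return text
    if mode == "END":    # keep the beginning, drop the end
        return " ".join(words[:max_words])
    if mode == "START":  # keep the end, drop the beginning
        return " ".join(words[-max_words:])
    if mode == "NONE":   # refuse to truncate over-length text
        raise ValueError("text too long to embed")
    raise ValueError(f"unknown mode: {mode}")

print(toy_truncate("a b c d e", 3, "END"))    # a b c
print(toy_truncate("a b c d e", 3, "START"))  # c d e
```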
- Combine each dictionary, which represents a row from the SQuAD dataset, with its generated embedding. This produces one dictionary per row, with the SQuAD keys and values untouched and the embedding stored under the "$vector" key. Embeddings must be top-level values associated with the "$vector" key to be valid vector search targets in Astra DB Serverless.

cohere-integration.py
# ...
to_insert = []
for doc_index, squad_document in enumerate(squad):
    to_insert.append({**squad_document, "$vector": embeddings[doc_index]})
# ...
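The required document shape can be shown with a single illustrative row (the field values and the short embedding below are stand-ins, not real dataset output): the embedding sits at the top level of the document under "$vector", alongside the untouched SQuAD fields.

```python
# A stand-in for one SQuAD row (real rows have more fields).
squad_row = {
    "question": "What is in front of the Notre Dame Main Building?",
    "answers": {"text": ["a copper statue of Christ"]},
}
embedding = [0.1, 0.2, 0.3]  # stand-in for a real 1024-dimensional vector

# "$vector" is a top-level key alongside the original fields,
# which is what makes the document a valid vector search target.
document = {**squad_row, "$vector": embedding}

print(sorted(document))  # ['$vector', 'answers', 'question']
```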
- Use the insert_many method to insert your documents into Astra DB:

cohere-integration.py
# ...
insert_result = collection.insert_many(to_insert)
print(f"\nInserted {len(insert_result.inserted_ids)} documents.")
# ...
Embed the query and get the answer
- Call cohere_client.embed again with the input_type of "search_query". Use the same model and truncate values. Replace the text in user_query to search for an answer to a different question.

cohere-integration.py
# ...
user_query = "What's in front of Notre Dame?"

embedded_query = cohere_client.embed(
    texts=[user_query],
    model=COHERE_MODEL_NAME,
    input_type="search_query",
    truncate="END",
).embeddings[0]
# ...
- Use the find method to retrieve the top documents whose question is most similar to the embedded query:

cohere-integration.py
# ...
results = collection.find(
    sort={"$vector": embedded_query},
    limit=5,
    include_similarity=True,
)
# ...
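The $similarity score returned with each document comes from the collection's vector metric (cosine, by default, for Astra DB collections). As intuition for what that score measures, plain cosine similarity between two vectors can be computed as below; the exact value the Data API reports may be a normalized variant of this, so treat the sketch as illustration only:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors: 1.0 = same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (identical direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
```

Sorting on {"$vector": embedded_query} lets the database perform this comparison against every stored embedding and return the closest matches first.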
- Extract the answer text and similarity score from those documents:

cohere-integration.py
# ...
print(f"\nQuery: {user_query}")
print("Answers:")
for idx, answer in enumerate(results):
    answer_text = answer["answers"]["text"][0]
    similarity = answer["$similarity"]
    print(f" - Answer {idx} (similarity: {similarity:.3f}): {answer_text}")
Result
This is the result of running the whole script provided above with python3 cohere-integration.py:

Sample entries in dataset:

 - Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
   - Answers: Saint Bernadette Soubirous

 - Question: What is in front of the Notre Dame Main Building?
   - Answers: a copper statue of Christ

 - Question: The Basilica of the Sacred heart at Notre Dame is beside to which structure?
   - Answers: the Main Building

 - Question: What is the Grotto at Notre Dame?
   - Answers: a Marian place of prayer and reflection

 - Question: What sits on top of the Main Building at Notre Dame?
   - Answers: a golden statue of the Virgin Mary

Inserted 2000 documents.

Query: What's in front of Notre Dame?
Answers:
 - Answer 0 (similarity: 0.883): a copper statue of Christ
 - Answer 1 (similarity: 0.779): a golden statue of the Virgin Mary
 - Answer 2 (similarity: 0.765): the City of South Bend
 - Answer 3 (similarity: 0.747): the Main Building
 - Answer 4 (similarity: 0.730): Old College