Contextual Retrieval by Anthropic

Exploring Contextual Retrieval with LangChain

🌟 Diving into Contextual Retrieval with LangChain 🌟

Welcome to an exciting journey into Contextual Retrieval! This blog post will guide you through implementing Anthropic’s Contextual Retrieval using LangChain in a simple, colorful, and beginner-friendly way. Let’s make complex retrieval systems easy to understand! 🚀

Contextual Retrieval Diagram

🎯 What is Contextual Retrieval?

Traditional Retrieval-Augmented Generation (RAG) systems can miss the mark because individual chunks, once split away from their document, lose the surrounding context needed to retrieve them reliably. Contextual Retrieval solves this by generating a short explanation for each chunk and prepending it before the chunk is embedded and indexed, making both semantic and keyword search more accurate. Let’s see how it works with code!
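To make the idea concrete, here is a tiny illustration (modeled on the example in Anthropic’s announcement; the text itself is made up) of what a chunk looks like before and after a situating context is prepended:

original_chunk = "The company's revenue grew by 3% over the previous quarter."

contextualized_chunk = (
    "This chunk is from an SEC filing on ACME corp's performance in Q2 2023; "
    "the previous quarter's revenue was $314 million. "
    "The company's revenue grew by 3% over the previous quarter."
)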

1️⃣ Setting Up the Environment

import logging
import os

# Silence INFO-level logs to keep notebook output clean
logging.disable(level=logging.INFO)
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Fill in your own credentials before running the rest of the notebook
os.environ["AZURE_OPENAI_API_KEY"] = ""
os.environ["AZURE_OPENAI_ENDPOINT"] = ""
os.environ["COHERE_API_KEY"] = ""

What’s happening here? We import the standard logging and os modules, silence INFO-level logging to keep notebook output clean, and set the environment variables that the Azure OpenAI and Cohere clients will read. Paste your own keys and endpoint into the empty strings before running anything else.
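If you’d rather not keep credentials in the notebook at all, a small optional alternative (not part of the original walkthrough) is to prompt for them at runtime:

import getpass

# Prompt for secrets instead of hardcoding them in the notebook
os.environ["AZURE_OPENAI_API_KEY"] = getpass.getpass("Azure OpenAI API key: ")
os.environ["COHERE_API_KEY"] = getpass.getpass("Cohere API key: ")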

2️⃣ Installing Required Libraries

!pip install -q langchain langchain-openai langchain-community faiss-cpu rank_bm25 langchain-cohere

Why this? This command installs LangChain and its companion packages: langchain-openai and langchain-community for the Azure and community integrations, faiss-cpu for vector similarity search, rank_bm25 for keyword ranking, and langchain-cohere for reranking. The -q flag keeps the output quiet.

3️⃣ Importing LangChain Tools

from langchain_community.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_cohere import CohereRerank
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

What’s this doing? We import LangChain modules for loading documents, creating prompts, retrieving data (BM25 and FAISS), reranking results, and connecting to Azure’s LLM and embedding models.

4️⃣ Downloading the Dataset

!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O './paul_graham_essay.txt'

What’s this? We download a sample dataset (Paul Graham’s essay) from a public URL and save it as a text file for processing.

5️⃣ Setting Up the LLM and Embeddings

llm = AzureChatOpenAI(
    deployment_name="gpt-4-32k-0613",
    openai_api_version="2023-08-01-preview",
    temperature=0.0,
)

embeddings = AzureOpenAIEmbeddings(
    deployment="text-embedding-ada-002",
    api_version="2023-08-01-preview",
)

What’s happening? We initialize Azure’s GPT-4 deployment for text generation (temperature 0.0 keeps its answers deterministic) and set up the text-embedding-ada-002 model to convert text into vectors for similarity search.
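A quick, optional sanity check before going further (a sketch; the deployment names must match what is configured in your own Azure resource):

print(llm.invoke("Reply with the single word: ready").content)

vector = embeddings.embed_query("contextual retrieval")
print(len(vector))  # text-embedding-ada-002 produces 1536-dimensional vectors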

6️⃣ Loading the Data

loader = TextLoader("./paul_graham_essay.txt")
documents = loader.load()
WHOLE_DOCUMENT = documents[0].page_content

Why this? We load the essay text file into a LangChain document object and store its content in WHOLE_DOCUMENT for further processing.

7️⃣ Creating Prompts for Contextual Chunks

prompt_document = PromptTemplate(
    input_variables=["WHOLE_DOCUMENT"], template="{WHOLE_DOCUMENT}"
)
prompt_chunk = PromptTemplate(
    input_variables=["CHUNK_CONTENT"],
    template="Here is the chunk we want to situate within the whole document\n\n{CHUNK_CONTENT}\n\n"
    "Please give a short succinct context to situate this chunk within the overall document for "
    "the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else.",
)

What’s this for? We define two prompt templates: one to pass the entire document and another to generate a concise context for each chunk, which will improve retrieval accuracy.
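Before wiring these templates into a pipeline, it may help to see how they are meant to be combined: the formatted document and the formatted chunk prompt are simply concatenated into one request. The sample chunk text below is made up; the real loop appears in step 11.

sample_chunk_text = "I worked on spam filters, and I also painted."  # made-up sample text
full_prompt = (
    prompt_document.format(WHOLE_DOCUMENT=WHOLE_DOCUMENT)
    + prompt_chunk.format(CHUNK_CONTENT=sample_chunk_text)
)
print(llm.invoke(full_prompt).content)  # a one- or two-sentence situating context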

8️⃣ Splitting Text into Chunks

from langchain_text_splitters import RecursiveCharacterTextSplitter

def split_text(texts):
    # Default chunk size (4,000 characters) with a 200-character overlap between chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=200)
    doc_chunks = text_splitter.create_documents(texts)
    # Tag every chunk with a stable ID so we can score retrieval later
    for i, doc in enumerate(doc_chunks):
        doc.metadata = {"doc_id": f"doc_{i}"}
    return doc_chunks

What’s this doing? This function splits the document into chunks (using the splitter’s default chunk size) with a 200-character overlap to preserve continuity between neighbouring chunks. Each chunk gets a unique doc_id in its metadata, which we’ll need later for evaluation.
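Applying it to our essay might look like this (a sketch; the exact number of chunks depends on the splitter’s default 4,000-character chunk size):

chunks = split_text([WHOLE_DOCUMENT])
print(f"Number of chunks: {len(chunks)}")
print(chunks[0].metadata)  # {'doc_id': 'doc_0'}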

9️⃣ Creating Retrievers

def create_embedding_retriever(documents_):
    # Dense retriever: embed the chunks into a FAISS index and return the top 4 matches
    vector_store = FAISS.from_documents(documents_, embedding=embeddings)
    return vector_store.as_retriever(search_kwargs={"k": 4})

def create_bm25_retriever(documents_):
    # Sparse retriever: classic BM25 keyword ranking over the same chunks
    retriever = BM25Retriever.from_documents(documents_, language="english")
    return retriever

What’s happening? We create two retrievers: one using FAISS for vector-based similarity search (returns top 4 results) and another using BM25 for keyword-based ranking.
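Using the two factories on the chunks from step 8 might look like the sketch below (BM25Retriever already defaults to returning 4 documents, but setting k explicitly keeps the two retrievers aligned):

embedding_retriever = create_embedding_retriever(chunks)
bm25_retriever = create_bm25_retriever(chunks)
bm25_retriever.k = 4  # match the FAISS retriever's top-4 setting

docs = embedding_retriever.invoke("What did the author work on before college?")
print([d.metadata["doc_id"] for d in docs])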

🔟 Combining Retrievers with Reranking

class EmbeddingBM25RerankerRetriever:
    """Hybrid retriever: merge dense (FAISS) and sparse (BM25) results, then rerank."""

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query):
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)
        # Union of both result lists, keeping vector order and dropping duplicates
        combined_docs = vector_docs + [doc for doc in bm25_docs if doc not in vector_docs]
        # Let Cohere's reranker order the merged list by relevance to the query
        reranked_docs = self.reranker.compress_documents(combined_docs, query)
        return reranked_docs

Why this? This class combines vector and BM25 retrievers, merges their results, and uses Cohere’s reranker to prioritize the most relevant documents for a given query.
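Wiring the pieces together might look like the sketch below. The Cohere model name is an assumption on my part; use whichever rerank model your account provides.

reranker = CohereRerank(model="rerank-english-v3.0", top_n=4)  # model name is an assumption

embedding_bm25_rerank_retriever = EmbeddingBM25RerankerRetriever(
    vector_retriever=embedding_retriever,
    bm25_retriever=bm25_retriever,
    reranker=reranker,
)

for doc in embedding_bm25_rerank_retriever.invoke("What did the author work on before college?"):
    print(doc.metadata["doc_id"], doc.metadata.get("relevance_score"))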

1️⃣1️⃣ Adding Context to Chunks

from tqdm import tqdm
from langchain_core.documents import Document

def create_contextual_chunks(chunks_):
    contextual_documents = []
    for chunk in tqdm(chunks_):
        # One LLM call per chunk: whole document + chunk, asking for a short situating context
        context = prompt_document.format(WHOLE_DOCUMENT=WHOLE_DOCUMENT)
        chunk_context = prompt_chunk.format(CHUNK_CONTENT=chunk.page_content)
        llm_response = llm.invoke(context + chunk_context).content
        # Store the original chunk text together with its generated context
        page_content = f"""Text: {chunk.page_content}\n\n\nContext: {llm_response}"""
        doc = Document(page_content=page_content, metadata=chunk.metadata)
        contextual_documents.append(doc)
    return contextual_documents

What’s this? This function adds a contextual explanation to each chunk using the LLM and prompt templates, creating new documents that include both the chunk text and its context.
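With that helper in place, the contextual variants of the three retrievers are built by re-running the same factories over the enriched chunks (a sketch; each chunk costs one LLM call over the full essay, so this step takes a while):

contextual_chunks = create_contextual_chunks(chunks)

contextual_embedding_retriever = create_embedding_retriever(contextual_chunks)
contextual_bm25_retriever = create_bm25_retriever(contextual_chunks)
contextual_embedding_bm25_rerank_retriever = EmbeddingBM25RerankerRetriever(
    vector_retriever=contextual_embedding_retriever,
    bm25_retriever=contextual_bm25_retriever,
    reranker=reranker,
)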

1️⃣2️⃣ Generating Question-Context Pairs

import re
import uuid

def generate_question_context_pairs(documents, llm, num_questions_per_chunk=2):
    doc_dict = {doc.metadata["doc_id"]: doc.page_content for doc in documents}
    queries = {}
    relevant_docs = {}
    for doc_id, text in tqdm(doc_dict.items()):
        # Ask the LLM for a fixed number of questions answerable from this chunk
        query = DEFAULT_QA_GENERATE_PROMPT_TMPL.format(context_str=text, num_questions_per_chunk=num_questions_per_chunk)
        response = llm.invoke(query).content
        # Split the response into lines and strip any leading numbering ("1.", "2)", ...)
        questions = [re.sub(r"^\d+[\).\s]", "", q).strip() for q in re.split(r"\n+", response.strip()) if q.strip()]
        questions = questions[:num_questions_per_chunk]
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [doc_id]
    return QuestionContextEvalDataset(queries=queries, corpus=doc_dict, relevant_docs=relevant_docs)

What’s this for? This generates questions for each chunk (2 per chunk) using the LLM and creates a dataset mapping questions to their relevant document IDs for evaluation.
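This function relies on two names defined elsewhere in the notebook: DEFAULT_QA_GENERATE_PROMPT_TMPL and QuestionContextEvalDataset. The originals aren’t shown in this post, so the following is only a plausible sketch of what they might look like:

from dataclasses import dataclass

# Assumed prompt: ask the LLM for N answerable questions, one per line
DEFAULT_QA_GENERATE_PROMPT_TMPL = (
    "Context information is below.\n\n"
    "{context_str}\n\n"
    "Given the context information and no prior knowledge, generate "
    "{num_questions_per_chunk} questions that can be answered using only this context. "
    "Return the questions one per line, with no numbering or extra commentary."
)

@dataclass
class QuestionContextEvalDataset:
    """Holds generated questions, the chunk corpus, and which chunk answers which question."""
    queries: dict        # question_id -> question text
    corpus: dict         # doc_id -> chunk text
    relevant_docs: dict  # question_id -> list of relevant doc_ids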

1️⃣3️⃣ Evaluating Retrievers

import numpy as np
import pandas as pd

def evaluate(retriever, dataset):
    mrr_result = []
    hit_rate_result = []
    ndcg_result = []
    for i in tqdm(range(len(dataset.queries))):
        # Retrieve documents for the i-th generated question
        context = retriever.invoke(extract_queries(dataset)[i])
        expected_ids = dataset.relevant_docs[list(dataset.queries.keys())[i]]
        retrieved_ids = extract_doc_ids(context)
        mrr = compute_mrr(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
        hit_rate = compute_hit_rate(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
        ndcg = compute_ndcg(expected_ids=expected_ids, retrieved_ids=retrieved_ids)
        mrr_result.append(mrr)
        hit_rate_result.append(hit_rate)
        ndcg_result.append(ndcg)
    # Average each metric over all queries and return a 3x1 DataFrame
    results_df = pd.DataFrame(np.mean([mrr_result, hit_rate_result, ndcg_result], axis=1), index=["MRR", "Hit Rate", "nDCG"])
    return results_df

What’s happening? This function evaluates a retriever by computing three metrics (MRR, Hit Rate, nDCG) for each query in the dataset and returns the average scores in a DataFrame.
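evaluate() also leans on several small helpers the post doesn’t show: extract_queries, extract_doc_ids, and the three metric functions. Here is one plausible way they could be implemented, assuming each question has exactly one relevant chunk:

import math

def extract_queries(dataset):
    # Question texts in insertion order, matching list(dataset.queries.keys())
    return list(dataset.queries.values())

def extract_doc_ids(docs):
    return [doc.metadata["doc_id"] for doc in docs]

def compute_mrr(expected_ids, retrieved_ids):
    # Reciprocal rank of the first relevant document, 0 if none was retrieved
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id in expected_ids:
            return 1.0 / rank
    return 0.0

def compute_hit_rate(expected_ids, retrieved_ids):
    # 1 if any relevant document appears anywhere in the retrieved list
    return float(any(doc_id in expected_ids for doc_id in retrieved_ids))

def compute_ndcg(expected_ids, retrieved_ids):
    # Binary-relevance nDCG; with a single relevant doc the ideal DCG is 1
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, doc_id in enumerate(retrieved_ids, start=1)
        if doc_id in expected_ids
    )
    ideal_dcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, len(expected_ids) + 1))
    return dcg / ideal_dcg if ideal_dcg else 0.0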

1️⃣4️⃣ Displaying Results

def display_results(name, eval_results):
    metrics = ["MRR", "Hit Rate", "nDCG"]
    columns = {"Retrievers": [name], **{metric: val for metric, val in zip(metrics, eval_results.values)}}
    return pd.DataFrame(columns)

pd.concat([
    display_results("Embedding Retriever", embedding_retriever_results),
    display_results("BM25 Retriever", bm25_results),
    display_results("Embedding + BM25 Retriever + Reranker", embedding_bm25_rerank_results),
], ignore_index=True)

Why this? This code displays the evaluation results for non-contextual retrievers in a neat table, comparing their performance across MRR, Hit Rate, and nDCG metrics.
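The three result variables come from evaluation runs the post doesn’t show explicitly; a sketch of how they could be produced is below: build the question set once from the plain chunks, then score each retriever.

qa_dataset = generate_question_context_pairs(chunks, llm, num_questions_per_chunk=2)

embedding_retriever_results = evaluate(embedding_retriever, qa_dataset)
bm25_results = evaluate(bm25_retriever, qa_dataset)
embedding_bm25_rerank_results = evaluate(embedding_bm25_rerank_retriever, qa_dataset)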

1️⃣5️⃣ Contextual Retriever Results

pd.concat([
    display_results("Contextual Embedding Retriever", contextual_embedding_retriever_results),
    display_results("Contextual BM25 Retriever", contextual_bm25_results),
    display_results("Contextual Embedding + BM25 Retriever + Reranker", contextual_embedding_bm25_rerank_results),
], ignore_index=True)

What’s this? Similar to the previous step, this displays results for contextual retrievers, allowing us to compare their performance against non-contextual ones.
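Because create_contextual_chunks preserves each chunk’s doc_id, the same question set can be reused for the contextual retrievers; a sketch:

contextual_embedding_retriever_results = evaluate(contextual_embedding_retriever, qa_dataset)
contextual_bm25_results = evaluate(contextual_bm25_retriever, qa_dataset)
contextual_embedding_bm25_rerank_results = evaluate(
    contextual_embedding_bm25_rerank_retriever, qa_dataset
)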

🎉 Wrapping Up

Congratulations! You’ve just explored how to implement Contextual Retrieval using LangChain. By adding context to chunks, combining retrievers, and evaluating performance, you’ve built a smarter search system. Try tweaking the code and experimenting with your own datasets to see how it performs! 🌈
