Retrieval Augmented Generation
Overview
This comprehensive cookbook demonstrates how to build a production-ready Retrieval-Augmented Generation (RAG) system using the Gravix Layer API. You will learn how to create an intelligent application that can answer questions based on your private documents with high accuracy and contextual understanding.
What is RAG?
Retrieval-Augmented Generation (RAG) is an advanced AI technique that significantly enhances Large Language Models (LLMs) by connecting them to external knowledge bases. This approach addresses one of the fundamental limitations of traditional LLMs: their knowledge cutoff and inability to access real-time or private information.
How RAG Works:
- Retrieval Phase: When you ask a question, the system searches through your document collection to find the most relevant text chunks
- Augmentation Phase: The retrieved context is combined with your question to create an enriched prompt
- Generation Phase: The LLM generates an answer based on both its training knowledge and the specific retrieved context (a minimal sketch of this loop follows below)
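The three phases map onto a small amount of orchestration code. The sketch below is illustrative only: embed, vector_index, and llm are hypothetical stand-ins for the real components built later in this guide.

def answer_with_rag(question: str, vector_index, embed, llm, k: int = 3) -> str:
    # Retrieval phase: embed the question and fetch the k most similar chunks
    query_vector = embed(question)
    relevant_chunks = vector_index.search(query_vector, k=k)

    # Augmentation phase: combine the retrieved context with the question
    context = "\n\n".join(chunk.text for chunk in relevant_chunks)
    prompt = (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}"
    )

    # Generation phase: the LLM answers, grounded in the retrieved context
    return llm(prompt)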
Key Benefits:
- Reduced Hallucinations: Answers are grounded in actual document content
- Up-to-date Information: Can work with the latest documents and data
- Domain Expertise: Becomes an expert on your specific documents and use cases
- Transparency: Shows which sources were used to generate each answer
- Cost-Effective: No need to fine-tune expensive models
Architecture Overview
Our RAG system consists of several interconnected components:
Document → Text Extraction → Chunking → Embedding → Vector Store
↓
User Question → Embedding → Similarity Search → Context Retrieval
↓
Retrieved Context + Question → LLM → Generated Answer
In this comprehensive guide, you will learn to:
- Environment Setup: Configure your development environment with all necessary dependencies
- API Configuration: Securely set up and configure the Gravix Layer client
- Document Processing: Extract and preprocess text from various document formats
- Text Chunking: Implement intelligent text segmentation strategies
- Embedding Generation: Create high-quality vector representations of text
- Vector Storage: Build and optimize a searchable vector database
- Retrieval System: Implement semantic search and ranking algorithms
- RAG Chain Construction: Assemble all components into a cohesive pipeline
- Query Processing: Handle user questions and generate contextual responses
- Performance Optimization: Fine-tune the system for better accuracy and speed
- Error Handling: Implement robust error handling and debugging strategies
- Deployment Considerations: Prepare the system for production use
Prerequisites
Technical Requirements
- Python 3.8+: Modern Python version with async support
- Memory: At least 4GB RAM (8GB+ recommended for larger documents)
- Storage: Sufficient space for document storage and vector indices
- Internet Connection: For downloading embedding models and API calls
Required Resources
- Gravix Layer API Key: Sign up at Gravix Layer to get your API key
- Documents: PDF files, text documents, or other text-based content for your knowledge base
- Development Environment: Jupyter Notebook, VS Code, or your preferred Python IDE
Implementation
1. Environment Setup and Dependencies
Setting up the right environment is crucial for a successful RAG implementation. We'll install and configure all necessary components step by step.
Installing Core Dependencies
First, let's install all the necessary Python libraries with detailed explanations:
# Core RAG and LangChain components
pip install langchain==0.1.0 # Framework for building LLM applications
pip install langchain-community==0.0.10 # Community integrations for LangChain
pip install langchain-openai==0.0.2 # OpenAI integrations (for API compatibility)
pip install tiktoken==0.5.2 # Tokenizer backend used by TokenTextSplitter in the chunking section
# Embedding and Vector Store
pip install sentence-transformers==2.2.2 # High-quality text embeddings
pip install faiss-cpu==1.7.4 # Efficient similarity search (CPU version)
# For GPU acceleration (if available): pip install faiss-gpu
# Document Processing
pip install pymupdf==1.23.8 # PDF text extraction
pip install python-docx==0.8.11 # Word document processing (optional)
pip install beautifulsoup4==4.12.2 # HTML parsing (optional)
# API and Environment Management
pip install openai==1.3.7 # OpenAI-compatible client
pip install python-dotenv==1.0.0 # Environment variable management
pip install requests==2.31.0 # HTTP requests
# Interactive and Development Tools
pip install ipywidgets==8.1.1 # Jupyter widgets for file upload
pip install tqdm==4.66.1 # Progress bars
pip install pandas==2.1.4 # Data manipulation (for analysis)
Verifying Installation
# Verify all critical dependencies are installed correctly
import sys
import importlib
def check_installation():
"""Check if all required packages are installed."""
required_packages = [
'langchain', 'langchain_community', 'sentence_transformers',
'faiss', 'fitz', 'openai', 'dotenv', 'tqdm'
]
missing_packages = []
for package in required_packages:
try:
importlib.import_module(package)
print(f"[OK] {package} - OK")
except ImportError:
missing_packages.append(package)
print(f"[MISSING] {package} - Missing")
if missing_packages:
print(f"\nWarning: Missing packages: {', '.join(missing_packages)}")
print("Please install them using pip install <package_name>")
else:
print("\nAll required packages are installed!")
check_installation()
Import Libraries with Error Handling
import os
import sys
import warnings
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import logging
# Configure logging for better debugging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
try:
# Document processing
import fitz # PyMuPDF for PDF text extraction
logger.info("PyMuPDF loaded successfully")
except ImportError as e:
logger.error(f"Failed to import PyMuPDF: {e}")
sys.exit(1)
try:
# Interactive widgets (for Jupyter environments)
import ipywidgets as widgets
from IPython.display import display, clear_output
logger.info("IPython widgets loaded successfully")
except ImportError:
logger.warning("IPython widgets not available (running outside Jupyter)")
widgets = None
try:
# LangChain components for building the RAG chain
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.chat_models import ChatOpenAI
from langchain.schema import Document
logger.info("LangChain components loaded successfully")
except ImportError as e:
logger.error(f"Failed to import LangChain components: {e}")
sys.exit(1)
try:
# Additional utilities
import numpy as np
from tqdm.auto import tqdm
import json
from datetime import datetime
logger.info("Utility libraries loaded successfully")
except ImportError as e:
logger.error(f"Failed to import utilities: {e}")
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
print("All libraries imported successfully!")
print(f"Current working directory: {os.getcwd()}")
print(f"Python version: {sys.version}")
2. Advanced Gravix Layer Client Configuration
Configure the client with comprehensive error handling, retry logic, and optimization settings.
Secure API Key Management
from dotenv import load_dotenv
import os
from pathlib import Path
def setup_api_key():
"""
Securely load API key from environment variables with multiple fallback options.
Priority order:
1. Environment variable GRAVIXLAYER_API_KEY
2. .env file in current directory
3. .env file in parent directory
4. Manual input (for interactive environments)
"""
# Load environment variables from .env file
env_files = [".env", "../.env", "../../.env"]
for env_file in env_files:
if Path(env_file).exists():
load_dotenv(env_file)
logger.info(f"[SUCCESS] Loaded environment from {env_file}")
break
# Try to get API key from environment
api_key = os.environ.get("GRAVIXLAYER_API_KEY")
if not api_key:
logger.warning("[WARNING] GRAVIXLAYER_API_KEY not found in environment variables")
# Interactive input as fallback (only in Jupyter/interactive environments)
try:
if widgets and hasattr(widgets, 'Password'):
print("Please enter your Gravix Layer API key:")
api_key_widget = widgets.Password(
description='API Key:',
style={'description_width': 'initial'}
)
display(api_key_widget)
# Note: In production, you should handle this more securely
else:
# Fallback for non-interactive environments
api_key = input("Enter your Gravix Layer API key: ").strip()
        except Exception:
logger.error("[ERROR] Failed to get API key interactively")
return None
if api_key and api_key != "YOUR_API_KEY_HERE":
# Mask the API key for logging (show only first 8 and last 4 characters)
masked_key = f"{api_key[:8]}...{api_key[-4:]}" if len(api_key) > 12 else "***"
logger.info(f"[SUCCESS] API key loaded: {masked_key}")
return api_key
else:
logger.error("[ERROR] No valid API key provided")
return None
# Load API key
api_key = setup_api_key()
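If you rely on a .env file, the loader above looks for a single variable. A minimal example (placeholder value; replace it with your real key):

GRAVIXLAYER_API_KEY=your_api_key_here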
Enhanced Client Configuration
from langchain_community.chat_models import ChatOpenAI
import time
from typing import Optional
class GravixLayerClient:
"""Enhanced Gravix Layer client with error handling and optimization."""
def __init__(self, api_key: str, model: str = "llama3.1:8b-instruct-fp16"):
"""
Initialize the Gravix Layer client with comprehensive configuration.
Args:
api_key: Your Gravix Layer API key
model: The model to use for text generation
"""
self.api_key = api_key
self.model = model
self.base_url = "https://api.gravixlayer.com/v1/inference"
# Initialize the client with optimal settings
self.llm = ChatOpenAI(
api_key=self.api_key,
base_url=self.base_url,
model=self.model,
temperature=0.1, # Low temperature for factual, consistent answers
max_tokens=1024, # Increased for more detailed responses
request_timeout=60, # 60 second timeout
max_retries=3, # Retry failed requests
)
# Test the connection
self._test_connection()
def _test_connection(self):
"""Test the API connection with a simple query."""
try:
logger.info("[PROCESSING] Testing Gravix Layer API connection...")
test_response = self.llm.invoke("Say 'Hello, RAG system!' if you can read this.")
logger.info("[SUCCESS] Gravix Layer API connection successful")
logger.info(f"[DATA] Model response: {test_response.content[:50]}...")
except Exception as e:
logger.error(f"[ERROR] Failed to connect to Gravix Layer API: {e}")
raise ConnectionError(f"Cannot connect to Gravix Layer API: {e}")
def get_model_info(self) -> Dict:
"""Get information about the current model."""
return {
"model": self.model,
"base_url": self.base_url,
"temperature": self.llm.temperature,
"max_tokens": self.llm.max_tokens,
}
# Initialize the enhanced client
if api_key:
try:
gravix_client = GravixLayerClient(api_key)
llm = gravix_client.llm
print("Gravix Layer client configured successfully")
print(f"Model info: {gravix_client.get_model_info()}")
except Exception as e:
logger.error(f"[ERROR] Failed to initialize Gravix Layer client: {e}")
llm = None
else:
logger.error("[ERROR] Cannot proceed without valid API key")
llm = None
3. Advanced Document Ingestion and Processing
Document processing is a critical component that significantly impacts the quality of your RAG system. We'll implement robust text extraction with support for multiple formats and error handling.
Comprehensive Document Processor
import mimetypes
from pathlib import Path
import hashlib
from typing import Union, List, Dict
class DocumentProcessor:
"""Advanced document processing with multi-format support and optimization."""
def __init__(self):
"""Initialize with supported formats."""
self.supported_formats = {
'.pdf': self._extract_pdf_text,
'.txt': self._extract_txt_text,
'.md': self._extract_txt_text, # Markdown files
}
self.processed_docs = {} # Cache for processed documents
def extract_text_from_file(self, file_path: Union[str, Path]) -> Dict:
"""
Extract text from a file with comprehensive error handling and metadata.
Args:
file_path: Path to the document file
Returns:
Dictionary containing extracted text, metadata, and processing info
"""
file_path = Path(file_path)
# Validate file existence
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# Check file size (warn if > 50MB)
file_size = file_path.stat().st_size
if file_size > 50 * 1024 * 1024: # 50MB
logger.warning(f"[WARNING] Large file detected: {file_size / 1024 / 1024:.1f}MB")
# Generate file hash for caching
file_hash = self._get_file_hash(file_path)
if file_hash in self.processed_docs:
logger.info("[INFO] Using cached document processing result")
return self.processed_docs[file_hash]
# Determine file format
file_extension = file_path.suffix.lower()
if file_extension not in self.supported_formats:
raise ValueError(f"Unsupported file format: {file_extension}")
# Extract text using appropriate method
try:
logger.info(f"[PROCESSING] Processing {file_path.name} ({file_extension})...")
text_content = self.supported_formats[file_extension](file_path)
# Prepare result with metadata
result = {
'text': text_content,
'metadata': {
'filename': file_path.name,
'file_path': str(file_path),
'file_size': file_size,
'file_hash': file_hash,
'format': file_extension,
'character_count': len(text_content),
'word_count': len(text_content.split()),
'processed_at': datetime.now().isoformat()
}
}
# Cache the result
self.processed_docs[file_hash] = result
logger.info(f"[SUCCESS] Successfully processed {file_path.name}")
logger.info(f"[DATA] Extracted {len(text_content):,} characters, {len(text_content.split()):,} words")
return result
except Exception as e:
logger.error(f"[ERROR] Failed to process {file_path.name}: {e}")
raise
def _extract_pdf_text(self, file_path: Path) -> str:
"""Extract text from PDF using PyMuPDF with advanced options."""
text_content = []
with fitz.open(file_path) as doc:
logger.info(f"[DOCUMENT] Processing PDF with {len(doc)} pages")
for page_num in tqdm(range(len(doc)), desc="Extracting pages"):
page = doc[page_num]
# Extract text with formatting preservation
page_text = page.get_text("text")
# Clean up common PDF artifacts
page_text = self._clean_pdf_text(page_text)
if page_text.strip(): # Only add non-empty pages
text_content.append(f"--- Page {page_num + 1} ---\n{page_text}")
return "\n\n".join(text_content)
def _extract_txt_text(self, file_path: Path) -> str:
"""Extract text from plain text files with encoding detection."""
# Try different encodings
encodings = ['utf-8', 'utf-16', 'latin-1', 'cp1252']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as file:
content = file.read()
logger.info(f"[TEXT] Successfully read text file with {encoding} encoding")
return content
except UnicodeDecodeError:
continue
        raise ValueError(f"Could not decode {file_path} with any supported encoding")
def _clean_pdf_text(self, text: str) -> str:
"""Clean common PDF text extraction artifacts."""
import re
# Remove excessive whitespace
text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
# Fix broken words at line endings
text = re.sub(r'(\w)-\s*\n\s*(\w)', r'\1\2', text)
# Normalize whitespace
text = re.sub(r'[ \t]+', ' ', text)
return text.strip()
def _get_file_hash(self, file_path: Path) -> str:
"""Generate MD5 hash of file for caching."""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
# Initialize the document processor
doc_processor = DocumentProcessor()
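A quick usage sketch; the path is a placeholder for your own document:

# Extract text and inspect the metadata (replace the path with a real file)
result = doc_processor.extract_text_from_file("docs/sample.pdf")
print(f"Words extracted: {result['metadata']['word_count']:,}")
print(result['text'][:300])  # Preview the first 300 characters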
Interactive File Upload for Jupyter Environments
def create_file_upload_widget():
"""Create an interactive file upload widget for Jupyter environments."""
if not widgets:
logger.warning("[WARNING] File upload widget not available outside Jupyter")
return None
upload_widget = widgets.FileUpload(
accept='.pdf,.txt,.md', # Accepted file types
multiple=False,
description='Upload Document',
style={'description_width': 'initial'},
layout={'width': '300px'}
)
status_output = widgets.Output()
def on_upload_change(change):
"""Handle file upload events."""
with status_output:
clear_output()
if upload_widget.value:
file_info = upload_widget.value[0]
print(f"Uploaded: {file_info['name']}")
print(f"Size: {len(file_info['content']):,} bytes")
print(f"Type: {file_info['type']}")
upload_widget.observe(on_upload_change, names='value')
return upload_widget, status_output
# Create upload widget if in Jupyter environment
if widgets:
print("File upload widget available:")
upload_widget, upload_status = create_file_upload_widget()
display(widgets.VBox([upload_widget, upload_status]))
else:
print("Running in non-Jupyter environment - use file paths directly")
Advanced Text Extraction Function
def extract_text_from_source(source: Union[str, Path, dict]) -> Dict:
"""
Extract text from various sources with comprehensive error handling.
Args:
source: File path, uploaded file dict, or raw text
Returns:
Dictionary with extracted text and metadata
"""
try:
        if isinstance(source, Path) or (isinstance(source, str) and Path(source).exists()):
            # Path to an existing file provided
            return doc_processor.extract_text_from_file(source)
elif isinstance(source, dict) and 'content' in source:
# Uploaded file from widget
filename = source.get('name', 'uploaded_file')
content = source['content']
            # Save to a temporary file and process
            import tempfile
            temp_path = Path(tempfile.gettempdir()) / filename
with open(temp_path, 'wb') as f:
f.write(content)
try:
result = doc_processor.extract_text_from_file(temp_path)
return result
finally:
# Clean up temporary file
if temp_path.exists():
temp_path.unlink()
elif isinstance(source, str):
# Raw text provided
return {
'text': source,
'metadata': {
'filename': 'raw_text',
'character_count': len(source),
'word_count': len(source.split()),
'format': 'text',
'processed_at': datetime.now().isoformat()
}
}
else:
raise ValueError(f"Unsupported source type: {type(source)}")
except Exception as e:
logger.error(f"[ERROR] Text extraction failed: {e}")
raise
# Example usage
def demo_text_extraction():
"""Demonstrate text extraction capabilities."""
sample_text = """
This is a sample document for testing the RAG system.
It contains multiple paragraphs with different topics.
The RAG system will split this text into chunks and create embeddings
for semantic search capabilities.
"""
result = extract_text_from_source(sample_text)
print("Sample extraction result:")
print(f" Characters: {result['metadata']['character_count']}")
print(f" Words: {result['metadata']['word_count']}")
print(f" Format: {result['metadata']['format']}")
# Run demo
demo_text_extraction()
4. Intelligent Text Processing and Chunking
Proper text chunking is crucial for RAG performance. We'll implement sophisticated chunking strategies that preserve context and optimize for semantic search.
Advanced Text Chunking Strategies
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    TokenTextSplitter  # Requires the tiktoken package installed above
)
from typing import List, Dict, Any
import re
class AdvancedTextProcessor:
"""Advanced text processing with multiple chunking strategies."""
def __init__(self):
"""Initialize with different text splitters."""
self.splitters = {
'recursive': RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
),
'token': TokenTextSplitter(
chunk_size=1000,
chunk_overlap=200
),
'semantic': RecursiveCharacterTextSplitter(
chunk_size=1500,
chunk_overlap=300,
separators=["\n\n", "\n", ".", "!", "?"]
)
}
def preprocess_text(self, text: str) -> str:
"""
Preprocess text to improve chunking quality.
Args:
text: Raw extracted text
Returns:
Cleaned and preprocessed text
"""
logger.info("[CLEAN] Preprocessing text...")
# Remove excessive whitespace
text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
# Fix common OCR errors
text = re.sub(r'(\w)-\s*\n\s*(\w)', r'\1\2', text) # Fix hyphenated words
text = re.sub(r'([.!?])\s*\n\s*([A-Z])', r'\1 \2', text) # Fix sentence breaks
# Normalize quotation marks
        text = re.sub(r'[\u201c\u201d\u201e]', '"', text)  # Curly double quotes -> straight
        text = re.sub(r'[\u2018\u2019]', "'", text)  # Curly single quotes/apostrophes -> straight
# Remove page numbers and headers/footers (common patterns)
text = re.sub(r'\n\s*\d+\s*\n', '\n', text) # Page numbers
text = re.sub(r'\n\s*Page \d+.*?\n', '\n', text) # "Page X" patterns
# Normalize whitespace
text = re.sub(r'[ \t]+', ' ', text)
text = re.sub(r'\n\s*\n', '\n\n', text)
logger.info(f"[SUCCESS] Text preprocessing complete")
return text.strip()
def create_chunks(self, text: str, strategy: str = 'recursive',
custom_params: Dict = None) -> List[Document]:
"""
Create text chunks using the specified strategy.
Args:
text: Input text to chunk
strategy: Chunking strategy ('recursive', 'token', 'semantic')
custom_params: Custom parameters for the text splitter
Returns:
List of Document objects with chunks and metadata
"""
if strategy not in self.splitters:
raise ValueError(f"Unknown strategy: {strategy}")
# Apply custom parameters if provided
splitter = self.splitters[strategy]
if custom_params:
for param, value in custom_params.items():
if hasattr(splitter, param):
setattr(splitter, param, value)
logger.info(f"[TEXT] Creating chunks using '{strategy}' strategy...")
# Preprocess text
processed_text = self.preprocess_text(text)
# Split text into chunks
chunks = splitter.split_text(processed_text)
# Create Document objects with metadata
documents = []
for i, chunk in enumerate(chunks):
doc = Document(
page_content=chunk,
metadata={
'chunk_id': i,
'chunk_size': len(chunk),
'word_count': len(chunk.split()),
'strategy': strategy,
'created_at': datetime.now().isoformat()
}
)
documents.append(doc)
logger.info(f"[SUCCESS] Created {len(documents)} chunks")
self._analyze_chunks(documents)
return documents
def _analyze_chunks(self, documents: List[Document]):
"""Analyze chunk quality and provide insights."""
if not documents:
return
sizes = [len(doc.page_content) for doc in documents]
word_counts = [len(doc.page_content.split()) for doc in documents]
logger.info("[DATA] Chunk Analysis:")
logger.info(f" Total chunks: {len(documents)}")
logger.info(f" Average size: {np.mean(sizes):.0f} chars")
logger.info(f" Size range: {min(sizes)} - {max(sizes)} chars")
logger.info(f" Average words: {np.mean(word_counts):.0f}")
logger.info(f" Word range: {min(word_counts)} - {max(word_counts)}")
# Warn about problematic chunks
very_small = [i for i, size in enumerate(sizes) if size < 100]
very_large = [i for i, size in enumerate(sizes) if size > 2000]
if very_small:
logger.warning(f"[WARNING] {len(very_small)} chunks are very small (<100 chars)")
if very_large:
logger.warning(f"[WARNING] {len(very_large)} chunks are very large (>2000 chars)")
def optimize_chunks(self, documents: List[Document],
min_size: int = 100, max_size: int = 2000) -> List[Document]:
"""
Optimize chunks by merging small ones and splitting large ones.
Args:
documents: List of document chunks
min_size: Minimum chunk size in characters
max_size: Maximum chunk size in characters
Returns:
Optimized list of document chunks
"""
logger.info("[CONFIG] Optimizing chunks...")
optimized_docs = []
current_chunk = ""
current_metadata = None
for doc in documents:
chunk_size = len(doc.page_content)
if chunk_size < min_size and current_chunk:
# Merge with previous chunk
current_chunk += f"\n\n{doc.page_content}"
else:
# Save previous chunk if exists
if current_chunk:
optimized_docs.append(Document(
page_content=current_chunk,
metadata=current_metadata
))
if chunk_size > max_size:
# Split large chunk
sub_chunks = self._split_large_chunk(doc.page_content, max_size)
for i, sub_chunk in enumerate(sub_chunks):
optimized_docs.append(Document(
page_content=sub_chunk,
metadata={**doc.metadata, 'sub_chunk': i}
))
current_chunk = ""
else:
current_chunk = doc.page_content
current_metadata = doc.metadata
# Add final chunk
if current_chunk:
optimized_docs.append(Document(
page_content=current_chunk,
metadata=current_metadata
))
logger.info(f"[SUCCESS] Optimized {len(documents)} chunks to {len(optimized_docs)} chunks")
return optimized_docs
def _split_large_chunk(self, text: str, max_size: int) -> List[str]:
"""Split a large chunk into smaller ones at sentence boundaries."""
sentences = re.split(r'[.!?]+', text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk + sentence) < max_size:
current_chunk += sentence + ". "
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + ". "
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
# Initialize the advanced text processor
text_processor = AdvancedTextProcessor()
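A quick usage sketch with placeholder text; in practice you would pass in the text extracted from your own document:

# Chunk the text with the default recursive strategy, then tidy up outliers
sample_text = ("Retrieval-Augmented Generation grounds model answers in your own documents. "
               "Each chunk becomes a searchable unit in the vector store. ") * 40
chunks = text_processor.create_chunks(sample_text, strategy='recursive')
chunks = text_processor.optimize_chunks(chunks, min_size=100, max_size=2000)
print(f"First chunk preview: {chunks[0].page_content[:150]}...")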
Embedding Model Management
class EmbeddingManager:
"""Manage different embedding models and their configurations."""
def __init__(self):
"""Initialize with available embedding models."""
self.models = {
'all-MiniLM-L6-v2': {
'model_name': 'all-MiniLM-L6-v2',
'dimension': 384,
'max_seq_length': 256,
'description': 'Fast, lightweight model for general use'
},
'all-mpnet-base-v2': {
'model_name': 'all-mpnet-base-v2',
'dimension': 768,
'max_seq_length': 384,
'description': 'High-quality model for better accuracy'
},
'multi-qa-MiniLM-L6-cos-v1': {
'model_name': 'multi-qa-MiniLM-L6-cos-v1',
'dimension': 384,
'max_seq_length': 512,
'description': 'Optimized for question-answering tasks'
}
}
self.current_model = None
self.embeddings = None
def load_model(self, model_name: str = 'all-MiniLM-L6-v2',
cache_folder: str = None) -> HuggingFaceEmbeddings:
"""
Load an embedding model with caching and optimization.
Args:
model_name: Name of the model to load
cache_folder: Custom cache folder for models
Returns:
HuggingFaceEmbeddings instance
"""
if model_name not in self.models:
raise ValueError(f"Unknown model: {model_name}")
if self.current_model == model_name and self.embeddings:
logger.info(f"[INFO] Using cached model: {model_name}")
return self.embeddings
model_config = self.models[model_name]
logger.info(f"[DOWNLOAD] Loading embedding model: {model_name}")
logger.info(f" Description: {model_config['description']}")
logger.info(f" Dimension: {model_config['dimension']}")
logger.info(f" Max sequence length: {model_config['max_seq_length']}")
# Configure model parameters
        model_kwargs = {
            'device': 'cpu',  # Use 'cuda' if a GPU is available
        }
if cache_folder:
model_kwargs['cache_folder'] = cache_folder
try:
self.embeddings = HuggingFaceEmbeddings(
model_name=model_config['model_name'],
model_kwargs=model_kwargs,
encode_kwargs={'normalize_embeddings': True}
)
self.current_model = model_name
# Test the model
test_embedding = self.embeddings.embed_query("Test embedding")
logger.info(f"[SUCCESS] Model loaded successfully (dimension: {len(test_embedding)})")
return self.embeddings
except Exception as e:
logger.error(f"[ERROR] Failed to load model {model_name}: {e}")
raise
def get_model_info(self) -> Dict:
"""Get information about the current model."""
if not self.current_model:
return {"status": "No model loaded"}
return {
"current_model": self.current_model,
"config": self.models[self.current_model],
"status": "loaded"
}
def benchmark_models(self, test_texts: List[str]) -> Dict:
"""Benchmark different models for performance comparison."""
results = {}
for model_name in self.models:
logger.info(f"[PROCESSING] Benchmarking {model_name}...")
try:
start_time = time.time()
embeddings = self.load_model(model_name)
# Test embedding generation
test_embeddings = embeddings.embed_documents(test_texts)
end_time = time.time()
results[model_name] = {
'embedding_time': end_time - start_time,
'dimension': len(test_embeddings[0]),
'status': 'success'
}
except Exception as e:
results[model_name] = {
'status': 'failed',
'error': str(e)
}
return results
# Initialize embedding manager
embedding_manager = EmbeddingManager()
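A quick sanity check that the default model loads and produces sensible similarity scores (the two sentences are illustrative):

# Load the default model and compare two semantically related questions
embeddings = embedding_manager.load_model('all-MiniLM-L6-v2')
vec_a = embeddings.embed_query("How do I reset my password?")
vec_b = embeddings.embed_query("Steps to recover account access")
similarity = float(np.dot(vec_a, vec_b))  # Vectors are normalized, so the dot product equals cosine similarity
print(f"Dimension: {len(vec_a)}, similarity: {similarity:.3f}")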
5. Build the RAG Pipeline
Tie everything together by building the RetrievalQA chain:
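The run_pipeline function further down calls two helpers, extract_text_from_pdf and prepare_vectorstore, which are not defined above. Here is a minimal sketch of both, built from the doc_processor, text_processor, and embedding_manager components created earlier; the chunking strategy and embedding model are illustrative defaults:

def extract_text_from_pdf(file_path):
    """Thin wrapper around the DocumentProcessor defined earlier."""
    return doc_processor.extract_text_from_file(file_path)['text']

def prepare_vectorstore(doc_text):
    """Chunk the text, embed the chunks, and build a FAISS index."""
    documents = text_processor.create_chunks(doc_text, strategy='recursive')
    embeddings = embedding_manager.load_model('all-MiniLM-L6-v2')
    return FAISS.from_documents(documents, embedding=embeddings)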
def build_rag_chain(vectorstore, llm):
"""Builds the complete RAG chain."""
# The retriever's job is to fetch relevant documents from the vector store
retriever = vectorstore.as_retriever(
search_type="similarity",
search_kwargs={"k": 3} # Retrieve the top 3 most relevant chunks
)
# The RetrievalQA chain combines the retriever and the LLM
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff", # 'stuff' puts all retrieved text directly into the prompt
retriever=retriever,
return_source_documents=True # We want to see which chunks were used
)
return qa_chain
def run_pipeline(file_path, llm):
"""Orchestrates the full pipeline from file to ready-to-use QA chain."""
try:
print("\nStarting RAG pipeline...")
print("="*40)
# Step 1: Extract Text
doc_text = extract_text_from_pdf(file_path)
# Step 2: Prepare Vector Store
vector_store = prepare_vectorstore(doc_text)
# Step 3: Build RAG Chain
qa_chain = build_rag_chain(vector_store, llm)
print("="*40)
print("[SUCCESS] Pipeline complete! The 'qa_chain' is ready.")
return qa_chain
except Exception as e:
print(f"\n[ERROR] An error occurred: {e}")
return None
# Run the pipeline with your PDF file
qa_chain = run_pipeline("path/to/your/document.pdf", llm)
6. Ask Questions About Your Document
The RAG pipeline is now ready. You can ask questions and get contextual answers:
def ask_question(qa_chain, question):
"""Ask a question and get an answer with sources."""
if qa_chain:
print(f"Question: Question: {question}")
print("-"*50)
print("[AI] Thinking...")
# Use the .invoke method to run the chain
result = qa_chain.invoke(question)
print("\nAnswer:")
print(result['result'])
print("\nSources Used:")
for i, doc in enumerate(result.get('source_documents', [])):
print(f" > Source {i+1}: \"{doc.page_content[:120].strip()}...\"")
else:
print("QA chain not found. Please make sure the pipeline ran successfully.")
# Example usage
ask_question(qa_chain, "What is the main topic of this document?")
Complete Working Example
Here's a simplified but production-ready implementation that combines all the concepts:
import os
import fitz
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.chat_models import ChatOpenAI
from dotenv import load_dotenv
import logging
# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
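# --- Minimal stand-in helpers (assumptions, not part of the core pipeline) ---
# The ProductionRAGSystem below references a @monitor_performance decorator, a
# security_manager, and a monitor object, which are assumed to come from
# monitoring/security utilities defined elsewhere. The lightweight sketches here
# simply keep this example self-contained and runnable.
import time
import functools

class _SimpleMonitor:
    """Records per-query latencies and reports basic statistics."""
    def __init__(self):
        self.latencies = []

    def performance_report(self) -> str:
        if not self.latencies:
            return "No queries recorded yet."
        avg = sum(self.latencies) / len(self.latencies)
        return f"{len(self.latencies)} queries, average latency {avg:.2f}s"

monitor = _SimpleMonitor()

def monitor_performance(func):
    """Decorator that records how long each wrapped call takes."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        try:
            return func(*args, **kwargs)
        finally:
            monitor.latencies.append(time.time() - start)
    return wrapper

class _SimpleSecurityManager:
    """Very basic input validation and output sanitization."""
    def validate_input(self, text: str) -> bool:
        # Reject empty or implausibly long queries
        return bool(text and text.strip()) and len(text) < 2000

    def sanitize_output(self, text: str) -> str:
        return text.strip()

security_manager = _SimpleSecurityManager()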
class ProductionRAGSystem:
"""Production-ready RAG system with all optimizations."""
def __init__(self, api_key: str):
"""Initialize the RAG system."""
self.api_key = api_key
self.llm = None
self.qa_chain = None
self.vectorstore = None
self._initialize_llm()
def _initialize_llm(self):
"""Initialize the Gravix Layer client."""
try:
self.llm = ChatOpenAI(
api_key=self.api_key,
base_url="https://api.gravixlayer.com/v1/inference",
model="llama3.1:8b-instruct-fp16",
temperature=0.1,
max_tokens=1024,
request_timeout=60,
max_retries=3
)
logger.info("[SUCCESS] Gravix Layer client initialized")
except Exception as e:
logger.error(f"[ERROR] Failed to initialize LLM: {e}")
raise
def process_document(self, file_path: str) -> bool:
"""Process a document and build the RAG system."""
try:
# Extract text
logger.info(f"[DOCUMENT] Processing document: {file_path}")
text = self._extract_text_from_pdf(file_path)
if not text.strip():
raise ValueError("No text extracted from document")
# Split into optimized chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=800,
chunk_overlap=100,
length_function=len,
separators=["\n\n", "\n", ". ", " ", ""]
)
chunks = text_splitter.split_text(text)
logger.info(f"[TEXT] Created {len(chunks)} text chunks")
# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings(
model_name="all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
self.vectorstore = FAISS.from_texts(
chunks,
embedding=embeddings,
metadatas=[{"chunk": i} for i in range(len(chunks))]
)
logger.info("[SEARCH] Vector store created")
# Build optimized retrieval chain
retriever = self.vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": 5,
"score_threshold": 0.7
}
)
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=retriever,
return_source_documents=True,
verbose=False
)
logger.info("[SUCCESS] RAG system ready")
return True
except Exception as e:
logger.error(f"[ERROR] Document processing failed: {e}")
return False
def _extract_text_from_pdf(self, file_path: str) -> str:
"""Extract text from PDF with error handling."""
try:
with fitz.open(file_path) as doc:
text = ""
for page_num, page in enumerate(doc):
page_text = page.get_text()
if page_text.strip():
text += f"\n--- Page {page_num + 1} ---\n{page_text}"
return text
except Exception as e:
logger.error(f"[ERROR] PDF extraction failed: {e}")
raise
@monitor_performance # Apply monitoring decorator
def query(self, question: str) -> dict:
"""Query the RAG system with comprehensive error handling."""
if not self.qa_chain:
raise ValueError("RAG system not initialized. Process a document first.")
try:
# Validate input
if not security_manager.validate_input(question):
raise ValueError("Invalid input detected")
logger.info(f"[SEARCH] Processing query: {question[:50]}...")
# Process the query
result = self.qa_chain.invoke(question)
# Sanitize output
sanitized_answer = security_manager.sanitize_output(result['result'])
# Structure the response
response = {
'question': question,
'answer': sanitized_answer,
'sources': [
{
'content': doc.page_content[:200] + "...",
'metadata': doc.metadata
}
for doc in result.get('source_documents', [])
],
'confidence': self._calculate_confidence(result)
}
logger.info("[SUCCESS] Query processed successfully")
return response
except Exception as e:
logger.error(f"[ERROR] Query processing failed: {e}")
raise
def _calculate_confidence(self, result: dict) -> float:
"""Calculate confidence score based on retrieval quality."""
if not result.get('source_documents'):
return 0.0
# Simple confidence based on number of sources and text length
num_sources = len(result['source_documents'])
answer_length = len(result['result'])
base_confidence = min(num_sources / 5.0, 1.0) # More sources = higher confidence
length_factor = min(answer_length / 200.0, 1.0) # Longer answers = higher confidence
return (base_confidence + length_factor) / 2.0
# Usage example
def main():
"""Main execution function."""
try:
# Initialize the system
api_key = os.environ.get("GRAVIXLAYER_API_KEY")
if not api_key:
raise ValueError("GRAVIXLAYER_API_KEY not found in environment variables")
rag_system = ProductionRAGSystem(api_key)
# Process a document
document_path = "your_document.pdf" # Replace with your document
if rag_system.process_document(document_path):
# Interactive query loop
print("\nRAG System Ready! Ask questions about your document.")
print("Type 'exit' to quit, 'metrics' to see performance stats.\n")
while True:
question = input("Question: ").strip()
if question.lower() == 'exit':
break
elif question.lower() == 'metrics':
print(monitor.performance_report())
continue
elif not question:
continue
try:
result = rag_system.query(question)
print(f"\nAnswer: {result['answer']}")
print(f"Confidence: {result['confidence']:.2%}")
print(f"Sources: {len(result['sources'])} documents")
# Optionally show quality assessment
if 'quality_assessor' in globals():
quality_scores = quality_assessor.assess_answer_quality(
question, result['answer'],
"\n".join([s['content'] for s in result['sources']])
)
if 'overall' in quality_scores:
print(f"Quality Score: {quality_scores['overall']}/10")
print("-" * 50)
except Exception as e:
print(f"Error: {e}")
else:
print("Failed to process document")
except Exception as e:
logger.error(f"[ERROR] System initialization failed: {e}")
if __name__ == "__main__":
main()
Performance Benchmarking
Benchmark Your RAG System
import time
import statistics
from typing import List, Dict
class RAGBenchmark:
"""Comprehensive benchmarking suite for RAG systems."""
def __init__(self, rag_system):
self.rag_system = rag_system
self.test_queries = [
"What is the main topic of this document?",
"Can you provide a summary of the key points?",
"What are the most important facts mentioned?",
"How does this relate to current industry trends?",
"What recommendations are provided?"
]
def run_performance_test(self, num_iterations: int = 10) -> Dict:
"""Run comprehensive performance tests."""
print(f"[RUNNING] Running performance benchmark ({num_iterations} iterations)...")
results = {
'response_times': [],
'success_rate': 0,
'average_response_time': 0,
'queries_per_second': 0,
'error_count': 0
}
start_time = time.time()
for i in range(num_iterations):
for query in self.test_queries:
try:
query_start = time.time()
result = self.rag_system.query(query)
query_time = time.time() - query_start
results['response_times'].append(query_time)
except Exception as e:
results['error_count'] += 1
print(f"[ERROR] Query failed: {e}")
# Calculate statistics
total_time = time.time() - start_time
total_queries = num_iterations * len(self.test_queries)
successful_queries = total_queries - results['error_count']
if results['response_times']:
results['average_response_time'] = statistics.mean(results['response_times'])
results['median_response_time'] = statistics.median(results['response_times'])
results['95th_percentile'] = statistics.quantiles(results['response_times'], n=20)[18]
results['success_rate'] = successful_queries / total_queries
results['queries_per_second'] = successful_queries / total_time
self._print_benchmark_results(results)
return results
def _print_benchmark_results(self, results: Dict):
"""Print formatted benchmark results."""
print("\n" + "="*50)
print("[DATA] RAG SYSTEM BENCHMARK RESULTS")
print("="*50)
print(f"[SUCCESS] Success Rate: {results['success_rate']:.2%}")
print(f"[FAST] Queries per Second: {results['queries_per_second']:.2f}")
print(f"[TIME] Average Response Time: {results['average_response_time']:.2f}s")
if 'median_response_time' in results:
print(f"[DATA] Median Response Time: {results['median_response_time']:.2f}s")
print(f"[METRICS] 95th Percentile: {results['95th_percentile']:.2f}s")
print(f"[ERROR] Error Count: {results['error_count']}")
# Performance recommendations
print("\n[TARGET] Performance Analysis:")
if results['average_response_time'] > 3.0:
print("[WARNING] High response times detected - consider optimization")
elif results['average_response_time'] < 1.0:
print("[INFO] Excellent response times!")
else:
print("[SUCCESS] Good response times")
if results['success_rate'] < 0.95:
print("[WARNING] Low success rate - check error handling")
else:
print("[SUCCESS] High success rate")
# Run benchmark if RAG system is available
if 'rag_system' in locals():
benchmark = RAGBenchmark(rag_system)
benchmark_results = benchmark.run_performance_test(5)
Key Takeaways and Best Practices
Production-Ready Principles
- Robust Error Handling: Always implement comprehensive error handling at every stage of the pipeline
- Security First: Validate inputs, sanitize outputs, and secure API keys properly
- Performance Monitoring: Track metrics, response times, and system health continuously
- Modular Design: Build components that can be easily swapped, upgraded, or scaled
- Documentation: Maintain clear documentation for configuration, deployment, and troubleshooting
Performance Optimization Strategies
- Chunk Size Optimization: Find the sweet spot between context and specificity (typically 500-1000 tokens)
- Embedding Model Selection: Balance accuracy vs. speed based on your use case
- Vector Store Indexing: Use proper indexing strategies for large document collections
- Caching: Implement intelligent caching for frequent queries and embeddings (a sketch follows this list)
- Async Processing: Use asynchronous processing for better concurrency
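As referenced in the caching point above, here is a minimal sketch of query-level caching layered on top of ProductionRAGSystem; it assumes exact-match reuse is acceptable for your workload:

import hashlib

class CachedRAG:
    """Wraps a RAG system so identical questions reuse the previous answer."""
    def __init__(self, rag_system, max_entries: int = 256):
        self.rag_system = rag_system
        self.cache = {}
        self.max_entries = max_entries

    def query(self, question: str) -> dict:
        key = hashlib.sha256(question.strip().lower().encode()).hexdigest()
        if key in self.cache:
            return self.cache[key]
        result = self.rag_system.query(question)
        if len(self.cache) >= self.max_entries:
            self.cache.pop(next(iter(self.cache)))  # Evict the oldest entry (insertion order)
        self.cache[key] = result
        return result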
Security Considerations
- API Key Management: Use environment variables and secure storage solutions
- Input Validation: Always validate and sanitize user inputs
- Rate Limiting: Implement rate limiting to prevent abuse
- Audit Logging: Log all security-relevant events for compliance and monitoring
- Access Control: Implement proper authentication and authorization
Scaling Strategies
- Horizontal Scaling: Design for multiple instances with load balancing
- Database Optimization: Use appropriate vector databases for your scale
- Content Delivery: Implement CDN for document delivery if needed
- Monitoring: Set up comprehensive monitoring and alerting
- Backup and Recovery: Implement proper backup strategies for vector stores
Advanced Use Cases and Extensions
1. Multi-Modal RAG
Extend the system to handle images, tables, and structured data alongside text.
2. Conversational RAG
Add conversation memory and context tracking for multi-turn conversations.
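A minimal sketch of this, assuming the llm client configured earlier and a vectorstore variable like the one used in the complete example; ConversationalRetrievalChain and ConversationBufferMemory ship with the LangChain release pinned above:

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Wrap the existing retriever with conversation memory so follow-up questions
# are interpreted against the chat history
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
conversational_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
    memory=memory,
)
response = conversational_chain.invoke({"question": "What is the main topic of this document?"})
print(response["answer"])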
3. Real-time RAG
Implement streaming responses and real-time document updates.
4. Multi-Language RAG
Support multiple languages with appropriate embedding models and LLMs.
5. Domain-Specific RAG
Customize for specific domains like legal, medical, or technical documentation.
Next Steps and Further Learning
Recommended Learning Path
- Master the Basics: Ensure you understand vector embeddings and semantic search
- Experiment with Models: Try different embedding models and LLMs for your use case
- Production Deployment: Practice deploying with Docker, Kubernetes, or cloud services
- Advanced Techniques: Explore query expansion, re-ranking, and hybrid search
- Domain Expertise: Specialize in your specific domain requirements
Additional Resources
- Gravix Layer Documentation: https://docs.gravixlayer.com
- LangChain Documentation: https://python.langchain.com
- Vector Database Guides: Explore Pinecone, Weaviate, or Chroma for production use
- RAG Research Papers: Stay updated with latest academic research
- Community Forums: Join RAG and LLM communities for best practices
Related Resources
- Multi-Modal AI - Working with images and text
- Function Calling - Integrate with external tools
- Cookbooks - Practical examples and GitHub repositories
Congratulations! You've built a comprehensive, production-ready RAG system with Gravix Layer. This foundation can be extended and customized for virtually any use case requiring intelligent document Q&A capabilities.
Advanced RAG Techniques and Optimization
1. Multi-Vector Retrieval Strategies
Implement sophisticated retrieval methods beyond simple semantic search:
class AdvancedRetriever:
"""Advanced retrieval system with multiple strategies."""
def __init__(self, vectorstore, llm, embeddings):
self.vectorstore = vectorstore
self.llm = llm
self.embeddings = embeddings
def hybrid_retrieval(self, query: str, top_k: int = 10) -> List[Document]:
"""
Combine semantic search with keyword matching for better results.
Args:
query: User query
top_k: Number of documents to retrieve
Returns:
List of relevant documents
"""
# Semantic search
semantic_docs = self.vectorstore.similarity_search(query, k=top_k)
# Simple keyword matching (can be enhanced with BM25)
keyword_docs = self._keyword_search(query, top_k)
# Combine and re-rank results
combined_docs = self._combine_and_rerank(semantic_docs, keyword_docs, query)
return combined_docs[:top_k]
def _keyword_search(self, query: str, top_k: int) -> List[Document]:
"""Simple keyword-based search."""
# This is a simplified version - in production, use BM25 or Elasticsearch
query_words = set(query.lower().split())
all_docs = self.vectorstore.similarity_search("", k=100) # Get more docs
scored_docs = []
for doc in all_docs:
content_words = set(doc.page_content.lower().split())
overlap = len(query_words.intersection(content_words))
if overlap > 0:
scored_docs.append((doc, overlap))
# Sort by keyword overlap
scored_docs.sort(key=lambda x: x[1], reverse=True)
return [doc for doc, _ in scored_docs[:top_k]]
def _combine_and_rerank(self, semantic_docs: List[Document],
keyword_docs: List[Document], query: str) -> List[Document]:
"""Combine and re-rank documents using advanced scoring."""
# Simple combination - in production, use learning-to-rank models
doc_scores = {}
# Score semantic results
for i, doc in enumerate(semantic_docs):
doc_id = id(doc)
doc_scores[doc_id] = doc_scores.get(doc_id, 0) + (10 - i) * 0.6
# Score keyword results
for i, doc in enumerate(keyword_docs):
doc_id = id(doc)
doc_scores[doc_id] = doc_scores.get(doc_id, 0) + (10 - i) * 0.4
# Get unique documents and sort by combined score
unique_docs = {}
for doc in semantic_docs + keyword_docs:
unique_docs[id(doc)] = doc
sorted_docs = sorted(unique_docs.values(),
key=lambda x: doc_scores.get(id(x), 0),
reverse=True)
return sorted_docs
# Initialize advanced retriever
if 'vectorstore' in locals() and 'embeddings' in locals() and vectorstore:
advanced_retriever = AdvancedRetriever(vectorstore, llm, embeddings)
print("[SUCCESS] Advanced retriever initialized")
2. Context-Aware Query Expansion
Enhance user queries with context and related terms:
class QueryExpander:
"""Intelligent query expansion for better retrieval."""
def __init__(self, llm):
self.llm = llm
def expand_query(self, query: str, context: str = "") -> str:
"""
Expand user query with related terms and context.
Args:
query: Original user query
context: Additional context (e.g., conversation history)
Returns:
Expanded query string
"""
expansion_prompt = f"""
You are a query expansion expert. Your task is to expand the following query with related terms, synonyms, and context to improve information retrieval.
Original Query: {query}
Context: {context}
Please provide an expanded query that includes:
1. The original query
2. Related terms and synonyms
3. Different ways to phrase the same question
4. Technical terms that might be relevant
Expanded Query:
"""
try:
response = self.llm.invoke(expansion_prompt)
expanded_query = response.content.strip()
logger.info(f"[METRICS] Query expanded: '{query}' -> '{expanded_query[:100]}...'")
return expanded_query
except Exception as e:
logger.error(f"[ERROR] Query expansion failed: {e}")
return query # Return original query if expansion fails
# Initialize query expander
if llm:
query_expander = QueryExpander(llm)
print("[SUCCESS] Query expander initialized")
3. Answer Quality Assessment
Implement automatic quality assessment for generated answers:
class AnswerQualityAssessor:
"""Assess the quality of generated answers."""
def __init__(self, llm):
self.llm = llm
def assess_answer_quality(self, question: str, answer: str,
context: str) -> Dict[str, float]:
"""
Assess the quality of a generated answer.
Args:
question: Original question
answer: Generated answer
context: Retrieved context used for generation
Returns:
Dictionary with quality scores
"""
assessment_prompt = f"""
You are an expert evaluator of question-answering systems. Please assess the quality of the following answer based on the given question and context.
Question: {question}
Context: {context}
Answer: {answer}
Please rate the answer on a scale of 1-10 for each of the following criteria:
1. Relevance: How well does the answer address the question?
2. Accuracy: Is the information correct based on the context?
3. Completeness: Does the answer fully address the question?
4. Clarity: Is the answer clear and well-structured?
5. Grounding: Is the answer based on the provided context?
Please respond in the following format:
Relevance: [score]
Accuracy: [score]
Completeness: [score]
Clarity: [score]
Grounding: [score]
Overall: [average score]
"""
try:
response = self.llm.invoke(assessment_prompt)
scores = self._parse_assessment_response(response.content)
return scores
except Exception as e:
logger.error(f"[ERROR] Answer quality assessment failed: {e}")
return {"error": "Assessment failed"}
def _parse_assessment_response(self, response: str) -> Dict[str, float]:
"""Parse assessment response into structured scores."""
scores = {}
lines = response.strip().split('\n')
for line in lines:
if ':' in line:
key, value = line.split(':', 1)
try:
scores[key.strip().lower()] = float(value.strip())
except ValueError:
continue
return scores
# Initialize answer quality assessor
if llm:
quality_assessor = AnswerQualityAssessor(llm)
print("[SUCCESS] Answer quality assessor initialized")
Conclusion
This guide has walked through the full pipeline: environment setup, document processing and chunking, embedding generation, vector storage, retrieval, answer generation, benchmarking, and advanced retrieval techniques.
Congratulations! You have successfully built a fully functional Retrieval-Augmented Generation pipeline. You can now use this notebook as a template to build applications that reason about your own private documents.
Key Takeaways:
- RAG is powerful: It grounds LLMs in factual data, improving accuracy and relevance.
- The stack is modular: You can easily swap out components. For example, you could use a different embedding model, a persistent vector database like ChromaDB, or a different LLM from the Gravix Layer platform.
- Security is important: Always handle API keys securely using environment variables, especially in production applications.