How to build a production-ready research assistant API using FastAPI and Tatry
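Before diving in: the example assumes Python 3.9+, a Redis server on localhost:6379, and these packages installed via pip: fastapi, uvicorn, redis, httpx, langchain-openai, and the Tatry client (assumed here to be published as tatry). Save the server code below as main.py and run it with uvicorn main:app --reload; the client script near the end of the article exercises the API.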
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from tatry import TatryRetriever
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import List, Optional
import asyncio
import json
import redis
import uuid

app = FastAPI(title="Research Assistant API")

# CORS: wide open here for brevity; restrict origins in production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
# Initialize TatryRetriever
retriever = TatryRetriever(
    api_key="your-api-key",
    model="gpt-3.5-turbo",  # Default model for retrieval
)

# Initialize the LLM used for summarization and key-point extraction
llm = ChatOpenAI(temperature=0)
# Pydantic models
class ResearchRequest(BaseModel):
    topic: str
    depth: str = "standard"  # quick, standard, or deep
    include_citations: bool = True
    model: Optional[str] = None  # optional retrieval model override (unused here)


class ResearchResponse(BaseModel):
    research_id: str
    status: str
    estimated_time: int


class ResearchResult(BaseModel):
    summary: str
    key_points: List[str]
    citations: List[dict]
    cost: float
class ResearchService:
    def __init__(self, retriever: TatryRetriever, llm):
        self.retriever = retriever
        self.llm = llm
        self.depth_settings = {
            "quick": {"max_sources": 3},
            "standard": {"max_sources": 5},
            "deep": {"max_sources": 10},
        }

    async def start_research(self, topic: str, depth: str) -> dict:
        # Look up the settings for the requested depth
        if depth not in self.depth_settings:
            raise ValueError(f"Unknown depth: {depth!r}")
        settings = self.depth_settings[depth]
        # Generate a collision-free research ID
        research_id = f"research_{uuid.uuid4().hex}"
        return {
            "research_id": research_id,
            "status": "processing",
            "estimated_time": settings["max_sources"] * 2,  # Rough estimate in seconds
        }

    async def conduct_research(self, topic: str, depth: str, include_citations: bool) -> dict:
        # Get relevant documents via the LangChain retriever interface,
        # capped at the source count for the requested depth
        docs = await self.retriever.ainvoke(topic)
        docs = docs[: self.depth_settings[depth]["max_sources"]]

        # Extract key information
        key_points = await self._extract_key_points(docs)
        summary = await self._generate_summary(docs)

        # Get citations if requested
        citations = await self._extract_citations(docs) if include_citations else []

        # Calculate total tokens used from metadata, then estimate cost;
        # the flat per-token rate below is a placeholder, so substitute
        # your model's actual pricing
        total_tokens = sum(doc.metadata.get("total_tokens", 0) for doc in docs)
        cost = total_tokens * 0.000002

        return {
            "summary": summary,
            "key_points": key_points,
            "citations": citations,
            "cost": cost,
        }
    async def _extract_key_points(self, docs: List[Document]) -> List[str]:
        # Ask the LLM for key points from each document
        points = []
        for doc in docs:
            response = await self.llm.ainvoke(
                f"Extract key points from this text:\n\n{doc.page_content}"
            )
            points.extend(
                line.strip() for line in response.content.split("\n") if line.strip()
            )
        return points

    async def _generate_summary(self, docs: List[Document]) -> str:
        # Combine document content
        combined_content = "\n\n".join(doc.page_content for doc in docs)
        # Generate summary
        response = await self.llm.ainvoke(
            f"Provide a comprehensive summary of this research:\n\n{combined_content}"
        )
        return response.content

    async def _extract_citations(self, docs: List[Document]) -> List[dict]:
        # Build citation entries from document metadata
        citations = []
        for doc in docs:
            citations.append({
                "title": doc.metadata.get("title", "Unknown"),
                "authors": doc.metadata.get("authors", []),
                "source": doc.metadata.get("source", "Unknown"),
                "published_date": doc.metadata.get("published_date", "Unknown"),
                "url": doc.metadata.get("url"),
            })
        return citations
# Initialize service
research_service = ResearchService(retriever, llm)


@app.post("/research", response_model=ResearchResponse)
async def start_research(
    request: ResearchRequest,
    background_tasks: BackgroundTasks,
):
    try:
        # Start research process
        research_info = await research_service.start_research(
            request.topic,
            request.depth,
        )

        # Store request parameters in Redis
        redis_client.setex(
            research_info["research_id"],
            3600,  # 1 hour expiration
            json.dumps({
                "topic": request.topic,
                "depth": request.depth,
                "include_citations": request.include_citations,
                "status": "processing",
            }),
        )

        # Start research in background
        background_tasks.add_task(
            conduct_research_task,
            research_info["research_id"],
            request.topic,
            request.depth,
            request.include_citations,
        )

        return research_info
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/research/{research_id}")
async def get_research_results(research_id: str):
# Check status in Redis
research_data = redis_client.get(research_id)
if not research_data:
raise HTTPException(status_code=404, detail="Research not found")
research_data = json.loads(research_data)
if research_data["status"] == "processing":
return {"status": "processing"}
return research_data
async def conduct_research_task(
    research_id: str,
    topic: str,
    depth: str,
    include_citations: bool,
):
    try:
        # Conduct research
        results = await research_service.conduct_research(
            topic,
            depth,
            include_citations,
        )

        # Store results in Redis
        redis_client.setex(
            research_id,
            3600,
            json.dumps({
                **results,
                "status": "completed",
            }),
        )
    except Exception as e:
        # Store error in Redis so the client can stop polling
        redis_client.setex(
            research_id,
            3600,
            json.dumps({
                "status": "error",
                "error": str(e),
            }),
        )
from fastapi import Request
from datetime import datetime


@app.middleware("http")
async def track_usage(request: Request, call_next):
    # Only count new research submissions, not status polling
    is_research_request = request.url.path == "/research" and request.method == "POST"
    today = datetime.now().strftime("%Y-%m-%d")

    if is_research_request:
        # Get current daily request count
        daily_requests = int(redis_client.get(f"daily_requests_{today}") or 0)
        # Enforce the daily limit; an HTTPException raised inside middleware
        # bypasses FastAPI's exception handlers, so return the 429 directly
        if daily_requests >= 1000:  # 1000 requests per day limit
            return JSONResponse(
                status_code=429,
                content={"detail": "Daily request limit exceeded"},
            )

    response = await call_next(request)

    # Update request count
    if is_research_request:
        redis_client.incr(f"daily_requests_{today}")

    return response
# Example client: run as a separate script while the API is up
import httpx


async def main():
    async with httpx.AsyncClient() as client:
        # Start research
        response = await client.post(
            "http://localhost:8000/research",
            json={
                "topic": "Recent advances in fusion energy",
                "depth": "standard",
                "include_citations": True,
            },
        )
        research_info = response.json()
        research_id = research_info["research_id"]

        # Poll for results
        while True:
            response = await client.get(
                f"http://localhost:8000/research/{research_id}"
            )
            data = response.json()

            if data["status"] == "completed":
                print("\nResearch Summary:")
                print(data["summary"])
                print("\nKey Points:")
                for point in data["key_points"]:
                    print(f"- {point}")
                print("\nCitations:")
                for citation in data["citations"]:
                    print(f"- {citation['title']} ({citation['source']})")
                break

            if data["status"] == "error":
                print(f"Research failed: {data['error']}")
                break

            await asyncio.sleep(2)


if __name__ == "__main__":
    asyncio.run(main())
The code above covers the core request flow. A few more areas deserve attention before real traffic hits the service; each heading below includes a brief sketch.

Scaling
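BackgroundTasks runs research jobs inside the web process, which limits throughput; for heavier loads, move conduct_research_task into a dedicated task queue such as Celery or ARQ. Because all shared state lives in Redis, you can also scale horizontally by running several worker processes. A minimal sketch of the multi-process option, assuming the app module is saved as main.py:

# run_server.py: a minimal scaling sketch; assumes the API above is in main.py.
# Worker processes share state through Redis, so no sticky sessions are needed.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)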
Monitoring
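At a minimum, expose a health check that verifies the Redis dependency; load balancers and orchestrators like Kubernetes can probe it. A minimal sketch:

# Health probe: verifies the API process and its Redis dependency are alive
@app.get("/health")
async def health():
    try:
        redis_client.ping()
    except redis.exceptions.ConnectionError:
        raise HTTPException(status_code=503, detail="Redis unavailable")
    return {"status": "ok"}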
Security
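The endpoints above are open to anyone. One lightweight option is an API-key header check as a FastAPI dependency; the hardcoded key below is a placeholder, so load real keys from a secret manager or database. A sketch:

from fastapi import Depends
from fastapi.security import APIKeyHeader

api_key_header = APIKeyHeader(name="X-API-Key")

async def verify_api_key(api_key: str = Depends(api_key_header)):
    # Placeholder check; load valid keys from a secret store in production
    if api_key != "expected-key":
        raise HTTPException(status_code=401, detail="Invalid API key")

# Attach to routes, e.g.:
# @app.post("/research", dependencies=[Depends(verify_api_key)])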
Cost Management
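The per-request cost estimate returned by conduct_research can feed a daily spend cap, mirroring the request counter in the middleware. A minimal sketch, assuming the placeholder flat-rate estimate used earlier:

DAILY_BUDGET_USD = 25.0  # example cap; tune to your actual spend

def record_cost(cost: float) -> None:
    # Accumulate today's spend in Redis and fail loudly once over budget
    today = datetime.now().strftime("%Y-%m-%d")
    total = redis_client.incrbyfloat(f"daily_cost_{today}", cost)
    if float(total) > DAILY_BUDGET_USD:
        raise RuntimeError("Daily research budget exceeded")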
Performance
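Identical topics at the same depth produce the same research, so completed results can be cached in Redis and returned instantly on repeat requests. A sketch, to be consulted before starting a new background task:

def get_cached_result(topic: str, depth: str) -> Optional[dict]:
    # Reuse a completed result for the same topic/depth, if one exists
    cached = redis_client.get(f"cache_{depth}_{topic}")
    return json.loads(cached) if cached else None

def cache_result(topic: str, depth: str, results: dict) -> None:
    # Cache completed research for a day
    redis_client.setex(f"cache_{depth}_{topic}", 86400, json.dumps(results))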
Error Handling
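Background failures are already persisted to Redis for the client to read; for the request path, a catch-all exception handler keeps stack traces out of API responses. A minimal sketch:

@app.exception_handler(Exception)
async def unhandled_exception(request: Request, exc: Exception):
    # Log exc with your logger of choice; never leak internals to clients
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )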