diff --git a/api/api.py b/api/api.py
index d40e73f9..25b94ac2 100644
--- a/api/api.py
+++ b/api/api.py
@@ -537,6 +537,88 @@ async def delete_wiki_cache(
         logger.warning(f"Wiki cache not found, cannot delete: {cache_path}")
         raise HTTPException(status_code=404, detail="Wiki cache not found")
 
+class RetrieveRequest(BaseModel):
+    """Request body for pure RAG retrieval (no LLM)."""
+    repo_url: str = Field(..., description="Full repository URL")
+    query: str = Field(..., description="Search query")
+    type: str = Field(default="github", description="Repository type: github, gitlab, bitbucket")
+    token: Optional[str] = Field(default=None, description="Access token for private repos")
+    top_k: int = Field(default=5, ge=1, le=20, description="Number of chunks to return")
+
+@app.post("/api/retrieve")
+async def retrieve(request: RetrieveRequest):
+    """Pure vector retrieval — returns relevant code chunks without calling any LLM."""
+    from api.rag import RAG
+    from api.data_pipeline import DatabaseManager
+    from api.tools.embedder import get_embedder
+    from api.config import configs, get_embedder_type
+    from adalflow.components.retriever.faiss_retriever import FAISSRetriever
+
+    try:
+        embedder_type = get_embedder_type()
+        is_ollama = (embedder_type == 'ollama')
+
+        # Prepare database (loads cached .pkl if available)
+        db_manager = DatabaseManager()
+        transformed_docs = db_manager.prepare_database(
+            request.repo_url, request.type, request.token, embedder_type=embedder_type
+        )
+        if not transformed_docs:
+            raise HTTPException(status_code=404, detail="No indexed data found for this repo. Index it via the web UI first.")
+
+        # Reuse RAG's embedding validation; __new__ skips __init__ so no full RAG setup runs
+        rag_instance = RAG.__new__(RAG)
+        valid_docs = rag_instance._validate_and_filter_embeddings(transformed_docs)
+        if not valid_docs:
+            raise HTTPException(status_code=404, detail="No valid embeddings found for this repo.")
+
+        # Build embedder for query
+        embedder = get_embedder(embedder_type=embedder_type)
+        if is_ollama:
+            # Ollama embeds one string at a time, so unwrap list-shaped queries.
+            def query_embedder(query):
+                if isinstance(query, list):
+                    query = query[0]
+                return embedder(input=query)
+        else:
+            query_embedder = embedder
+
+        # Build FAISS retriever
+        retriever_config = {**configs["retriever"], "top_k": request.top_k}
+        retriever = FAISSRetriever(
+            **retriever_config,
+            embedder=query_embedder,
+            documents=valid_docs,
+            document_map_func=lambda doc: doc.vector,
+        )
+
+        # Retrieve
+        results = retriever(request.query)
+        docs = [valid_docs[i] for i in results[0].doc_indices]
+
+        return {
+            "query": request.query,
+            "total_chunks": len(valid_docs),
+            "results": [
+                {
+                    "text": doc.text,
+                    "file_path": doc.meta_data.get("file_path", ""),
+                    "is_code": doc.meta_data.get("is_code", False),
+                    "token_count": doc.meta_data.get("token_count", 0),
+                }
+                for doc in docs
+            ]
+        }
+    except HTTPException:
+        raise
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        # exception() captures the traceback; lazy %-args avoid f-string formatting cost
+        logger.exception("Retrieve error: %s", e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/health")
 async def health_check():
     """Health check endpoint for Docker and monitoring"""