# (extraction artifact — not source code) File size: 1,731 Bytes, commit 73fbc5b
# rag_graph.py ----------------------------------------------------
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# -------------------------------
# Graph State Definition
# -------------------------------
# State carried through the retrieval graph: the user's query plus the
# retrieved document context (empty string when retrieval is disabled).
RAGState = TypedDict("RAGState", {"query": str, "context": str})
# -------------------------------
# Load FAISS + Embeddings
# -------------------------------
# Embedding model used both to build and to query the FAISS index.
embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
try:
    vectorstore = FAISS.load_local(
        "faiss_index_fast",
        embedding,
        # The index is a locally-created artifact; FAISS requires an explicit
        # opt-in because loading deserializes pickled data.
        allow_dangerous_deserialization=True
    )
    # Top-3 nearest chunks per query.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
except Exception as exc:  # was a bare `except:` — that also swallowed
    # KeyboardInterrupt/SystemExit and hid the failure cause entirely.
    # A missing/corrupt index disables RAG instead of crashing import.
    retriever = None
    print(f"⚠️ WARNING: FAISS index missing → RAG disabled ({exc})")
# -------------------------------
# Retrieval Node
# -------------------------------
def retrieve(state: RAGState):
    """Graph node: look up context for the query held in *state*.

    Returns a new state dict; ``context`` is the newline-joined page
    content of the retrieved documents, or ``""`` when no retriever
    is configured.
    """
    question = state["query"]
    # Guard clause: RAG disabled (index failed to load at import time).
    if not retriever:
        return {"query": question, "context": ""}
    documents = retriever.invoke(question)
    joined = "\n".join(doc.page_content for doc in documents)
    return {"query": question, "context": joined}
# -------------------------------
# Build Retrieval LangGraph
# -------------------------------
# Minimal single-node graph: START → retrieve → END.  Compiling produces a
# runnable whose .invoke() executes the node and returns the final state.
workflow = StateGraph(RAGState)
workflow.add_node("retrieve", retrieve)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", END)
rag_graph = workflow.compile()
# -------------------------------
# External API
# -------------------------------
def run_rag(query: str):
    """Run the compiled retrieval graph for *query*.

    Entry point for callers (e.g. TeacherAgent); returns the final
    graph state, a dict with ``query`` and ``context`` keys.
    """
    initial_state = {"query": query}
    return rag_graph.invoke(initial_state)
# end of file