|
|
|
|
|
|
|
|
from langgraph.graph import StateGraph, START, END |
|
|
from typing import TypedDict |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGState(TypedDict):
    """State dict that flows through the retrieval graph.

    ``run_rag`` seeds it with ``query``; the ``retrieve`` node fills in
    ``context``.
    """

    # The user's search text.
    query: str
    # Retrieved passages joined with newlines; "" when RAG is disabled.
    context: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Embedding model used to encode queries for the FAISS similarity search.
# NOTE(review): presumably this must be the same model that built
# "faiss_index_fast" — confirm against the indexing script.
embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# Load the prebuilt FAISS index and expose it as a top-3 retriever. If
# loading fails for any reason, RAG is disabled (``retriever`` stays None)
# instead of making this module impossible to import.
try:
    # SECURITY: allow_dangerous_deserialization=True lets load_local unpickle
    # data stored next to the index. Only load index directories this project
    # produced itself — never untrusted or downloaded files.
    vectorstore = FAISS.load_local(
        "faiss_index_fast",
        embedding,
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
except Exception as exc:  # deliberate best-effort fallback
    # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit, and
    # labelling every failure "missing" hid corrupt or incompatible indexes.
    # Keep the fallback and the original warning, but surface the real cause.
    retriever = None
    print("⚠️ WARNING: FAISS index missing → RAG disabled")
    print(f"   reason: {type(exc).__name__}: {exc}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve(state: RAGState):
    """Fetch context passages for ``state["query"]``.

    LangGraph node: reads the query from the graph state, asks the
    module-level ``retriever`` for matching documents, and joins their
    ``page_content`` with newlines.

    Args:
        state: Graph state; only the ``"query"`` key is read.

    Returns:
        A state update ``{"query": <same query>, "context": <joined text>}``.
        ``context`` is ``""`` when no retriever is available
        (``retriever is None``) or when no documents come back.
    """
    query = state["query"]

    # Compare against the None sentinel explicitly: a truthiness test would
    # also skip retrieval for any retriever object that happens to be falsy.
    if retriever is not None:
        docs = retriever.invoke(query)
        context = "\n".join(doc.page_content for doc in docs)
    else:
        context = ""

    return {"query": query, "context": context}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single-node LangGraph pipeline: START -> "retrieve" -> END.
# The graph carries a RAGState dict; the "retrieve" node fills in "context".
workflow = StateGraph(RAGState)
workflow.add_node("retrieve", retrieve)
workflow.set_entry_point("retrieve")   # equivalent to add_edge(START, "retrieve")
workflow.set_finish_point("retrieve")  # equivalent to add_edge("retrieve", END)

# Compiled, invokable graph.
rag_graph = workflow.compile()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_rag(query: str):
    """Run the retrieval graph for a single query.

    Intended to be called from TeacherAgent when it needs context.

    Args:
        query: Free-text question to look up.

    Returns:
        The final graph state produced by ``rag_graph.invoke``; the
        ``"context"`` key holds whatever the retrieve node found.
    """
    initial_state = {"query": query}
    final_state = rag_graph.invoke(initial_state)
    return final_state
|
|
|