|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import streamlit as st |
|
|
import time |
|
|
from dotenv import load_dotenv |
|
|
# Populate os.environ from a local .env file (when one exists) so the
# os.getenv("OPENAI_API_KEY") lookups in the LLM helpers can find the key.
load_dotenv()
|
|
|
|
|
|
|
|
from rag.loader import load_rag_index |
|
|
from rag.retriever import RAGRetriever |
|
|
from agents.student import StudentAgent |
|
|
from agents.teacher import TeacherAgent |
|
|
from graph.classroom_graph import build_classroom_graph |
|
|
|
|
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Streamlit re-executes this script on every interaction, so per-session
# values are only seeded when absent; existing values are left untouched.
# The defaults tuple is rebuilt on every run, so the list default is a fresh
# object rather than one shared across sessions.
for _state_key, _state_default in (
    ("conversation", []),   # list of (role, message) pairs from the lesson
    ("transcript", ""),     # "Role: message" lines fed to summary/quiz prompts
    ("quiz_raw", None),     # raw quiz text returned by generate_quiz
    ("summary", ""),        # text returned by generate_summary
    ("state", None),        # NOTE(review): not assigned elsewhere in this view
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_quiz(topic, transcript):
    """Ask the chat model for a five-question multiple-choice quiz.

    Args:
        topic: Lesson topic interpolated into the prompt.
        transcript: Lesson transcript the questions should be drawn from.

    Returns:
        The raw model text (``.content`` of the reply). It is expected to
        follow the ``Q1:`` / ``A)``..``D)`` / ``Answer:`` layout requested in
        the prompt; parsing is left to the caller.
    """
    quiz_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0.2,
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    quiz_prompt = f"""
Create a quiz of EXACTLY 5 MCQs based on this topic: {topic}.

SOURCE TRANSCRIPT:
{transcript}

FORMAT STRICTLY LIKE THIS:
Q1: <question>
A) option
B) option
C) option
D) option
Answer: <B|C|A|D>
"""

    chat_messages = [{"role": "user", "content": quiz_prompt}]
    reply = quiz_model.invoke(chat_messages)
    return reply.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_teacher_response(question, rag_answer, nonrag_answer):
    """Ask an LLM judge to compare a RAG-backed and a plain teacher answer.

    Args:
        question: The student question both teachers answered.
        rag_answer: The answer produced by the retrieval-augmented teacher.
        nonrag_answer: The answer produced without retrieval.

    Returns:
        The judge model's free-text evaluation (``.content`` of the reply).
    """
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0.0,  # deterministic judging: same inputs, same verdict
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    # The judge only sees what is in the prompt, so the question and both
    # answers must be interpolated explicitly.
    prompt = f"""
Evaluate two teacher responses to the same student question.

STUDENT QUESTION:
{question}

RESPONSE A (RAG teacher):
{rag_answer}

RESPONSE B (non-RAG teacher):
{nonrag_answer}

Compare the two responses on factual accuracy, relevance to the question,
clarity, and completeness. Score each response from 1 to 10 on every
criterion, briefly justify each score, and finish by stating which response
is better overall and why.
"""

    response = llm.invoke([{"role": "user", "content": prompt}])
    return response.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_summary(transcript, topic):
    """Ask the chat model for a structured summary of a lesson.

    Args:
        transcript: Lesson transcript to summarize.
        topic: Lesson topic interpolated into the prompt.

    Returns:
        The model's summary text (``.content`` of the reply).
    """
    summary_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0.3,
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    summary_prompt = f"""
Create a structured summary for topic: {topic}

TRANSCRIPT:
{transcript}
"""

    chat_messages = [{"role": "user", "content": summary_prompt}]
    reply = summary_model.invoke(chat_messages)
    return reply.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page chrome: browser-tab title/icon, wide layout, heading and intro line.
st.set_page_config(page_title="AI Classroom", page_icon="π", layout="wide")
st.title("π AI Classroom: RAG vs Non-RAG + Quiz + Summary + Evaluation")
st.write("This app simulates an AI teacher using RAG and compares it with a normal teacher.")

# Lesson controls. Inside ``with st.sidebar`` every element call is placed in
# the sidebar, exactly like calling it through ``st.sidebar.<element>``.
with st.sidebar:
    st.header("Settings")
    topic = st.text_input(label="Enter Topic:", value="Quantum Computing")
    turns = st.slider(
        label="Number of Dialogue Turns",
        min_value=1,
        max_value=10,
        value=3,
    )
    run_button = st.button("Start Lesson")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Chat avatars, written as escapes so the source file's encoding cannot
# garble them: "student" (person + ZWJ + graduation cap) and
# "man teacher" (man + ZWJ + school). st.chat_message only accepts valid
# emoji here; any other string is loaded as an image and fails.
_STUDENT_AVATAR = "\U0001F9D1\u200D\U0001F393"
_TEACHER_AVATAR = "\U0001F468\u200D\U0001F3EB"


def _split_message(msg_data):
    """Normalise one conversation entry into a ``(role, message)`` pair.

    Handles the shapes the classroom graph may emit: a ``(role, message)``
    tuple, or a dict with "role"/"message" keys (role gets capitalised).
    Anything else is attributed to "Unknown" using its ``str()`` form.
    """
    if isinstance(msg_data, tuple):
        role, msg = msg_data
    elif isinstance(msg_data, dict):
        role = (msg_data.get("role") or "").capitalize()
        msg = msg_data.get("message", "")
    else:
        role, msg = "Unknown", str(msg_data)
    # Coerce so the role.lower() checks downstream never fail on a
    # non-string role.
    return str(role), msg


def _show_live_message(role, msg):
    """Render one message: students as the user bubble, everyone else as the teacher."""
    if role.lower() == "student":
        st.chat_message("user", avatar=_STUDENT_AVATAR).write(msg)
    else:
        st.chat_message("assistant", avatar=_TEACHER_AVATAR).write(msg)


if run_button:
    # Start each lesson from a clean slate. The transcript is cleared too, so
    # a lesson that fails midway cannot leave the previous lesson's
    # transcript feeding the summary/quiz buttons.
    st.session_state.conversation = []
    st.session_state.quiz_raw = None
    st.session_state.summary = ""
    st.session_state.transcript = ""

    st.subheader(f"π Topic: **{topic}**")

    with st.spinner("Loading RAG index..."):
        vectorstore = load_rag_index()
        rag = RAGRetriever(vectorstore)

    student = StudentAgent()
    rag_teacher = TeacherAgent()
    # NOTE(review): presumably switches TeacherAgent to retrieval-backed
    # answers; the flag is consumed outside this file — confirm.
    rag_teacher.rag_enabled = True

    graph = build_classroom_graph(student, rag_teacher, rag)

    state = {
        "topic": topic,
        "last_teacher_reply": "I am ready to teach.",
        "last_student_question": "",
        "conversation": [],
    }

    st.info("Classroom session started...")

    shown_count = 0
    for _turn in range(turns):
        state = graph.invoke(state)
        all_messages = state["conversation"]

        # Only entries past shown_count are rendered. This assumes
        # state["conversation"] accumulates across invocations (TODO confirm
        # against build_classroom_graph).
        for msg_data in all_messages[shown_count:]:
            role, msg = _split_message(msg_data)
            st.session_state.conversation.append((role, msg))
            _show_live_message(role, msg)
            time.sleep(0.5)  # pace output so the dialogue reads as a live exchange

        shown_count = len(all_messages)

    # One "Role: message" line per entry; used as prompt input for the
    # summary and quiz generators.
    st.session_state.transcript = "".join(
        f"{role}: {msg}\n" for role, msg in st.session_state.conversation
    )
|
|
|
|
|
|
|
|
|
|
|
# Replay the stored lesson on every rerun so it stays visible while the
# summary/quiz buttons are used. Students go on the "user" side; every other
# speaker is shown as the "assistant".
if st.session_state.conversation:
    st.subheader("π Full Transcript")

    for speaker, text in st.session_state.conversation:
        bubble = "user" if speaker.lower() == "student" else "assistant"
        st.chat_message(bubble).write(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lesson summary: generated on demand from the stored transcript and kept in
# session state so it survives reruns.
st.subheader("π Lesson Summary")

if st.button("Generate Summary"):
    if not st.session_state.transcript:
        # Without a transcript the model would only have the topic to go on,
        # so don't spend an API call on an ungrounded summary.
        st.warning("Run a lesson first so there is a transcript to summarize.")
    else:
        with st.spinner("Generating summary..."):
            st.session_state.summary = generate_summary(
                st.session_state.transcript, topic
            )

if st.session_state.summary:
    st.text_area("Lesson Summary", st.session_state.summary, height=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_quiz(raw_quiz):
    """Parse quiz text into question dicts.

    Recognises lines starting with "Q" (question), "A)".."D)" (options) and
    "Answer" (correct letter). Lines are stripped of surrounding whitespace
    and Markdown ``*`` first. A question is kept only when its Answer line
    names a letter A-D and it has at least one option.

    Returns:
        list of dicts with keys "question" (str), "options" (list of str)
        and "answer" (single upper-case letter).
    """
    import re

    parsed = []
    current = None
    for raw_line in raw_quiz.splitlines():
        line = raw_line.strip().strip("*").strip()
        if line.startswith("Q"):
            current = {"question": line, "options": []}
        elif line.startswith(("A)", "B)", "C)", "D)")):
            # Ignore options that appear before any question line.
            if current is not None:
                current["options"].append(line)
        elif line.startswith("Answer"):
            if current is None:
                continue
            # Accept "Answer: B", "Answer: B) text", "Answer: **B**", ...
            _, _, answer_text = line.partition(":")
            letter = re.search(r"\b([A-D])\b", answer_text.upper())
            if letter and current["options"]:
                current["answer"] = letter.group(1)
                parsed.append(current)
            # Close the question so a repeated Answer line can't re-append it.
            current = None
    return parsed


st.subheader("π§ Quiz Time!")

if st.button("Generate Quiz"):
    if not st.session_state.transcript:
        st.warning("Run a lesson first so the quiz has a transcript to draw from.")
    else:
        with st.spinner("Generating quiz..."):
            st.session_state.quiz_raw = generate_quiz(
                topic, st.session_state.transcript
            )

if st.session_state.quiz_raw:
    st.text_area("Generated Quiz", st.session_state.quiz_raw, height=260)

    questions = _parse_quiz(st.session_state.quiz_raw)

    st.subheader("π Take the Quiz")

    if not questions:
        st.warning("Could not read any questions from the generated quiz. Try generating it again.")
    else:
        user_answers = []
        for i, q in enumerate(questions):
            st.write(f"**{q['question']}**")
            # A non-empty label is kept for accessibility; it is hidden
            # because the question is already shown in bold above.
            choice = st.radio(
                q["question"],
                q["options"],
                key=f"quiz{i}",
                label_visibility="collapsed",
            )
            # Options look like "B) ...", so the first character is the letter.
            user_answers.append(choice[0] if choice else None)

        if st.button("Submit Answers"):
            score = sum(
                given == q["answer"] for given, q in zip(user_answers, questions)
            )
            st.success(f"π Your Score: **{score}/{len(questions)}**")
|
|
|