Spaces:
Running
Running
import gc
import html
import logging
import os
import tempfile
import uuid

import streamlit as st
import torch
from groq import Groq, APIError
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
# ---------------- CONFIGURATION ----------------
logging.basicConfig(level=logging.INFO)

# ---------------- STREAMLIT UI SETUP ----------------
# Page configuration is issued before any other Streamlit command: the status
# banners below (st.success / st.error / st.warning) are Streamlit commands,
# and Streamlit documents set_page_config as having to come first.
st.set_page_config(page_title="PDF Assistant", page_icon="π", layout="wide")

GROQ_MODEL = "llama-3.1-8b-instant"

# Load the API key: Streamlit secrets first, then the process environment.
# NOTE(review): reading st.secrets appears to raise on some Streamlit versions
# when no secrets file exists (the exception type differs between versions),
# so the lookup is guarded and the environment variable is still consulted.
try:
    GROQ_API_KEY = st.secrets.get("GROQ_API_KEY")
except Exception:
    logging.getLogger(__name__).info(
        "Streamlit secrets unavailable; falling back to the environment."
    )
    GROQ_API_KEY = None
GROQ_API_KEY = GROQ_API_KEY or os.environ.get("GROQ_API_KEY")

# Initialize Groq client; `client` stays None when no key is available or
# construction fails, and the rest of the script checks for that.
client = None
if GROQ_API_KEY:
    try:
        client = Groq(api_key=GROQ_API_KEY)
        st.success("β Groq client initialized successfully.")
    except Exception as e:
        st.error(f"β Failed to initialize Groq client: {e}")
        client = None
else:
    st.warning("β οΈ GROQ_API_KEY not found. Please add it to Hugging Face secrets.")
# ---------------- CSS ----------------
# Stylesheet injected once per run. It defines the theme variables and the
# rules for the chat bubbles (.chat-user / .chat-bot), the per-answer
# .sources line and the fixed .footer bar.
_PAGE_CSS = """
<style>
:root {
--primary-color: #1e3a8a;
--background-color: #0e1117;
--secondary-background-color: #1a1d29;
--text-color: #f0f2f6;
}
.chat-user {
background: #2d3748;
padding: 12px;
border-radius: 10px 10px 2px 10px;
margin: 6px 0 6px auto;
max-width: 85%;
text-align: right;
color: var(--text-color);
}
.chat-bot {
background: var(--primary-color);
padding: 12px;
border-radius: 10px 10px 10px 2px;
margin: 6px auto 6px 0;
max-width: 85%;
text-align: left;
color: #ffffff;
}
.sources {
font-size: 0.8em;
opacity: 0.7;
margin-top: 10px;
border-top: 1px solid rgba(255, 255, 255, 0.1);
padding-top: 5px;
}
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: var(--secondary-background-color);
color: var(--text-color);
text-align: center;
padding: 10px;
font-size: 0.85em;
border-top: 1px solid rgba(255, 255, 255, 0.1);
}
.footer a {
color: var(--primary-color);
text-decoration: none;
font-weight: bold;
}
.footer a:hover {
text-decoration: underline;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# ---------------- SESSION STATE ----------------
# Seed each key the script reads later. A key that already exists (set on an
# earlier rerun of this session) is left untouched.
for _state_key, _initial_value in (
    ("chat", []),
    ("vectorstore", None),
    ("retriever", None),
    ("uploaded_file_name", None),
    ("uploader_key", 0),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
| # ---------------- FUNCTIONS ---------------- | |
def clear_chat_history():
    """Replace the stored conversation with a new empty list."""
    st.session_state["chat"] = list()
def clear_memory():
    """Drop the loaded PDF from session state and reset the uploader widget.

    Sets the vector store, retriever and active file name to None, increments
    ``uploader_key`` (the value used as the file uploader's widget key), runs
    the garbage collector, empties the CUDA cache when CUDA is available, and
    shows a confirmation message.
    """
    state = st.session_state
    for stale_key in ("vectorstore", "retriever", "uploaded_file_name"):
        state[stale_key] = None
    state["uploader_key"] = state["uploader_key"] + 1

    gc.collect()
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()

    st.success("Memory cleared. Please upload a new PDF.")
def process_pdf(uploaded_file):
    """Index an uploaded PDF and store the resulting retriever in session state.

    The upload is written to a temporary file, loaded with PyPDFLoader, split
    into overlapping chunks, embedded with a MiniLM sentence-transformer on
    CPU, and indexed in a Chroma vector store. On success the vector store and
    a top-3 retriever are saved to ``st.session_state``.

    Args:
        uploaded_file: Object exposing ``getvalue()`` that returns the PDF
            bytes (the value returned by ``st.file_uploader``).

    Returns:
        The number of chunks indexed, or ``None`` when processing failed (an
        error message is shown through ``st.error`` in that case).
    """
    path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            path = tmp.name
            tmp.write(uploaded_file.getvalue())

        # Load PDF
        docs = PyPDFLoader(path).load()

        # Split into chunks
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=50,
        )
        chunks = splitter.split_documents(docs)
        if not chunks:
            # Nothing to index (e.g. a PDF with no extractable text); report
            # it plainly instead of letting the vector store fail on it.
            st.error("Error processing PDF: no extractable text was found in the file.")
            return None

        # Create embeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # Create vectorstore.
        # NOTE(review): without an explicit name every upload targets Chroma's
        # default collection, and the in-process client appears to keep
        # collections alive between calls, so chunks from earlier PDFs can
        # leak into later searches. A unique name per upload isolates them;
        # confirm against the installed chromadb version.
        vectorstore = Chroma.from_documents(
            chunks,
            embeddings,
            collection_name=f"pdf-{uuid.uuid4().hex}",
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

        # Swap the new index in first, then release the one it replaces, so a
        # failure above never leaves the session pointing at a deleted index.
        previous = st.session_state.get("vectorstore")
        st.session_state.vectorstore = vectorstore
        st.session_state.retriever = retriever
        if previous is not None:
            try:
                previous.delete_collection()
            except Exception:
                # Best effort: the new index is already usable.
                logging.getLogger(__name__).warning(
                    "Could not delete the previous vector store collection.",
                    exc_info=True,
                )

        return len(chunks)
    except Exception as e:
        st.error(f"Error processing PDF: {str(e)}")
        return None
    finally:
        # Always remove the temporary copy, including on failure.
        if path is not None and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                logging.getLogger(__name__).warning(
                    "Could not remove temporary file %s", path, exc_info=True
                )
def ask_question(question):
    """Answer a question from the indexed PDF using retrieval plus the Groq model.

    Args:
        question: The user's question text.

    Returns:
        A ``(answer, source_count, error)`` tuple. On success ``answer`` is the
        stripped model reply, ``source_count`` is the number of retrieved
        chunks and ``error`` is ``None``. On failure ``answer`` is ``None``,
        ``source_count`` is ``0`` and ``error`` is a human-readable message.
        Exceptions from retrieval or the API call are converted to the error
        form rather than raised.
    """
    if not client:
        return None, 0, "Groq client is not initialized. Check API key setup."
    if not st.session_state.retriever:
        return None, 0, "Upload PDF first to initialize the knowledge base."
    if not question or not question.strip():
        # Avoid spending a retrieval and an API call on blank input.
        return None, 0, "Please enter a question."
    try:
        # Retrieve relevant chunks
        docs = st.session_state.retriever.invoke(question)
        context = "\n\n".join(d.page_content for d in docs)

        # Build prompt
        prompt = f"""
You are a strict RAG Q&A assistant.
Use ONLY the context provided. If the answer is not found, reply:
"I cannot find this in the PDF."
---------------- CONTEXT ----------------
{context}
-----------------------------------------
QUESTION: {question}
FINAL ANSWER:
"""

        # Call Groq API
        response = client.chat.completions.create(
            model=GROQ_MODEL,
            messages=[
                {"role": "system",
                 "content": "Use only the PDF content. If answer not found, say: 'I cannot find this in the PDF.'"},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,
        )

        # The message content may be missing or blank; report that explicitly
        # so callers never receive an empty answer paired with error=None.
        content = response.choices[0].message.content
        answer = (content or "").strip()
        if not answer:
            return None, 0, "The model returned an empty response. Please try again."
        return answer, len(docs), None
    except APIError as e:
        return None, 0, f"Groq API Error: {str(e)}"
    except Exception as e:
        return None, 0, f"General error: {str(e)}"
# ---------------- UI COMPONENTS ----------------
st.title("π PDF Assistant")

# Sidebar: two reset buttons, a divider, then the current PDF status.
_sidebar = st.sidebar
_sidebar.header("Controls")
_sidebar.button("ποΈ Clear Chat History", on_click=clear_chat_history, use_container_width=True)
_sidebar.button("π₯ Clear PDF Memory", on_click=clear_memory, use_container_width=True)
_sidebar.markdown("---")

_active_pdf = st.session_state.uploaded_file_name
if not _active_pdf:
    _sidebar.warning("β¬οΈ Upload a PDF to start chatting!")
else:
    _sidebar.success(f"β **Active PDF:**\n `{_active_pdf}`")
# File Upload
# The widget key comes from session state; clear_memory() increments it so a
# fresh, empty uploader is rendered after the PDF memory is cleared.
uploaded = st.file_uploader(
    "Upload your PDF",
    type=["pdf"],
    key=st.session_state.uploader_key,
)

_is_new_upload = bool(uploaded) and uploaded.name != st.session_state.uploaded_file_name
# Identify an attempt by uploader generation + file name. A file that failed
# stays selected in the widget, so without this marker every rerun would
# process it again (and the old unconditional st.rerun() made that an endless
# loop that also hid the error message).
_upload_id = (st.session_state.uploader_key, uploaded.name) if uploaded else None

if _is_new_upload and st.session_state.get("failed_upload") == _upload_id:
    # Already tried this exact selection; keep the failure visible.
    st.error("β Failed to process PDF")
elif _is_new_upload:
    st.session_state.uploaded_file_name = None
    st.session_state.chat = []
    with st.spinner(f"Processing '{uploaded.name}'..."):
        chunks_count = process_pdf(uploaded)
    if chunks_count is not None:
        st.session_state.pop("failed_upload", None)
        st.success(f"β PDF processed successfully! {chunks_count} chunks created.")
        st.session_state.uploaded_file_name = uploaded.name
        # Rerun so the sidebar status and the chat input pick up the new PDF.
        st.rerun()
    else:
        # No rerun here: the error shown by process_pdf stays on screen.
        st.session_state["failed_upload"] = _upload_id
        st.error("β Failed to process PDF")
        st.session_state.uploaded_file_name = None
# Chat Input
# Asking is only possible once a PDF is indexed and the Groq client exists.
disabled_input = st.session_state.uploaded_file_name is None or client is None
question = st.text_input(
    "Ask a question about the loaded PDF:",
    key="question_input",
    disabled=disabled_input,
)
if st.button("Send", disabled=disabled_input) and question:
    # Add user query to chat history (stored raw; escaped when rendered).
    st.session_state.chat.append(("user", question))
    # Get answer
    with st.spinner("Thinking..."):
        answer, sources, error = ask_question(question)
    if answer:
        # Bot entries are stored as HTML because of the trailing sources <div>,
        # so the model text is escaped here before being embedded in markup.
        bot_message = (
            f"{html.escape(answer)}"
            f"<div class='sources'>Context Chunks Used: {sources}</div>"
        )
        st.session_state.chat.append(("bot", bot_message))
    else:
        st.session_state.chat.append(("bot", f"π΄ **Error:** {html.escape(str(error))}"))
    st.rerun()

# Display Chat History
st.markdown("## Chat History")
for role, msg in st.session_state.chat:
    if role == "user":
        # User text is untrusted and rendered with unsafe_allow_html=True:
        # escape it so typed markup is displayed rather than interpreted.
        st.markdown(f"<div class='chat-user'>{html.escape(msg)}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='chat-bot'>{msg}</div>", unsafe_allow_html=True)
# Footer
# Attribution bar; it uses the .footer class, so it needs raw HTML rendering.
footer_html = "\n".join(
    [
        "",
        '<div class="footer">',
        'Created by <a href="https://www.linkedin.com/in/abhishek-iitr/" target="_blank">Abhishek Saxena</a>',
        "</div>",
        "",
    ]
)
st.markdown(footer_html, unsafe_allow_html=True)