Spaces:
Runtime error
Runtime error
| import asyncio | |
| import openai | |
| import chainlit as cl # importing chainlit for our app | |
| from chainlit.prompt import Prompt, PromptMessage # importing prompt tools | |
| import os | |
| import getpass | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file into os.environ.
load_dotenv()
# Default the Pinecone environment to "gcp-starter", but let an explicitly
# configured PINECONE_ENV (shell or .env) win instead of being overwritten
# right after it was loaded.
os.environ.setdefault("PINECONE_ENV", "gcp-starter")
import arxiv

arxiv_client = arxiv.Client()
paper_urls = []

# Discovering papers is only needed when (re)building the Pinecone index, so it
# is opt-in: set INGEST_PAPERS=1 (or "true"/"yes"). Otherwise paper_urls stays
# empty and the app just queries the index that already exists.
if os.environ.get("INGEST_PAPERS", "").strip().lower() in {"1", "true", "yes"}:
    search = arxiv.Search(
        query="Retrieval Augmented Generation",
        max_results=5,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    paper_urls.extend(result.pdf_url for result in arxiv_client.results(search))
    print(paper_urls)
from langchain.document_loaders import PyPDFLoader

# One entry per paper URL: the list of Documents returned by PyPDFLoader.load().
docs = []

# Opt-in, like paper discovery: only download/parse PDFs when re-ingesting.
# (The old debug print of docs[0][6] is gone; it raised IndexError when no
# paper, or a paper with fewer than 7 pages, was loaded.)
if os.environ.get("INGEST_PAPERS", "").strip().lower() in {"1", "true", "yes"}:
    for paper_url in paper_urls:
        docs.append(PyPDFLoader(paper_url).load())
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Splitter used to chunk page text before embedding. Lengths are measured with
# the builtin len, so chunk_size/chunk_overlap count characters.
_splitter_settings = dict(chunk_size=1000, chunk_overlap=100, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(**_splitter_settings)
| import pinecone | |
| from pinecone.core.client.configuration import Configuration as OpenApiConfiguration | |
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
index_name = 'arxiv-paper-index2'

pinecone.init(api_key=YOUR_API_KEY, environment=YOUR_ENV)

if index_name not in pinecone.list_indexes():
    # First run: create the index.
    # NOTE(review): dimension=1536 presumably matches the OpenAIEmbeddings
    # model configured below; change both together.
    pinecone.create_index(name=index_name, metric='cosine', dimension=1536)

# pinecone.GRPCIndex is only exported when pinecone-client is installed with
# its gRPC extra; without it `pinecone.GRPCIndex` raises AttributeError and the
# whole app fails at startup. Fall back to the REST client, which also
# provides upsert().
index = getattr(pinecone, "GRPCIndex", pinecone.Index)(index_name)
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| from langchain.embeddings import CacheBackedEmbeddings | |
| from langchain.storage import LocalFileStore | |
# Embedding model wrapped in a byte-store cache persisted under ./cache/.
# Cache entries are namespaced by the embedding model's name.
store = LocalFileStore("./cache/")
core_embeddings_model = OpenAIEmbeddings()
_cache_namespace = core_embeddings_model.model
embedder = CacheBackedEmbeddings.from_bytes_store(
    core_embeddings_model,
    store,
    namespace=_cache_namespace,
)
| from tqdm.auto import tqdm | |
| from uuid import uuid4 | |
BATCH_LIMIT = 100
texts = []
metadatas = []


def _upsert_chunks(chunk_texts, chunk_metadatas):
    """Embed ``chunk_texts`` and upsert them into ``index`` with their metadata.

    Each vector gets a fresh random uuid4 id, so running ingestion twice adds
    duplicate vectors instead of overwriting the earlier ones.
    """
    ids = [str(uuid4()) for _ in chunk_texts]
    embeds = embedder.embed_documents(chunk_texts)
    # Materialize the (id, vector, metadata) tuples rather than passing a
    # one-shot zip iterator to the client.
    index.upsert(vectors=list(zip(ids, embeds, chunk_metadatas)))


# Chunk, embed and upload the loaded papers. Opt-in (INGEST_PAPERS=1 / "true" /
# "yes") because it calls the embeddings API and writes to Pinecone; by default
# the app only reads from the existing index.
if os.environ.get("INGEST_PAPERS", "").strip().lower() in {"1", "true", "yes"}:
    for paper_pages in tqdm(docs):
        for doc in paper_pages:
            page_metadata = {
                'source_document': doc.metadata["source"],
                'page_number': doc.metadata["page"],
            }
            record_texts = text_splitter.split_text(doc.page_content)
            texts.extend(record_texts)
            # The chunk's text is stored in metadata under "text", the same key
            # the query-side vectorstore is configured to read back.
            metadatas.extend(
                {"chunk": j, "text": text, **page_metadata}
                for j, text in enumerate(record_texts)
            )
            # Flush in batches of BATCH_LIMIT chunks to bound request size.
            if len(texts) >= BATCH_LIMIT:
                _upsert_chunks(texts, metadatas)
                texts.clear()
                metadatas.clear()
    if texts:
        _upsert_chunks(texts, metadatas)
        texts.clear()
        metadatas.clear()
| from langchain.vectorstores import Pinecone | |
text_field = "text"  # metadata key that holds each chunk's raw text
# Rebind `index` to the REST client for the LangChain wrapper.
# NOTE(review): presumably because the wrapper expects pinecone.Index rather
# than GRPCIndex; confirm before changing.
index = pinecone.Index(index_name)
_embed_query = embedder.embed_query
vectorstore = Pinecone(index, _embed_query, text_field)
# Manual retriever smoke test (not executed at import time):
#     query = "What is dense vector retrieval?"
#     vectorstore.similarity_search(query, k=3)
| from langchain.chat_models import ChatOpenAI | |
# Chat model that writes the final answer; temperature=0 keeps answers as
# repeatable as the API allows.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
| from langchain.prompts import ChatPromptTemplate | |
# Answering prompt with two slots: {context} (retrieved passages) and {question}.
# NOTE(review): despite its name, ChatPromptTemplate.from_template turns this
# into a single human message, not a system message.
system_template = """Answer the following question with the provided context only. If you aren't able to get the answer from the provided context only, then please don't answer the question.
### CONTEXT
{context}
###QUESTION
{question}
"""
from langchain.prompts import ChatPromptTemplate  # duplicates an earlier import; harmless
prompt = ChatPromptTemplate.from_template(system_template)
retriever = vectorstore.as_retriever()
| from operator import itemgetter | |
| from langchain.schema.runnable import RunnableLambda, RunnablePassthrough | |
| from langchain.schema import format_document | |
| from langchain.schema.output_parser import StrOutputParser | |
| from langchain.prompts.prompt import PromptTemplate | |
# Render each retrieved Document as just its page text for the prompt.
_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")


def _combine_documents(docs, document_prompt=_DOCUMENT_PROMPT, document_separator="\n\n"):
    """Join retrieved Documents into one string for the prompt's {context} slot.

    Feeding the Document list straight into the template would render Python
    reprs (metadata included) instead of readable passages.
    """
    return document_separator.join(
        format_document(doc, document_prompt) for doc in docs
    )


# question -> retrieve -> answer. The output is
# {"response": <message returned by llm>, "context": <retrieved Documents>}
# so callers can show both the answer and its sources.
retrieval_augmented_qa_chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    # Add a text rendering of the documents; the raw Documents stay under
    # "context" so they can be returned alongside the answer.
    | RunnablePassthrough.assign(
        formatted_context=itemgetter("context") | RunnableLambda(_combine_documents)
    )
    | {
        "response": {
            "context": itemgetter("formatted_context"),
            "question": itemgetter("question"),
        }
        | prompt
        | llm,
        "context": itemgetter("context"),
    }
)
| import langchain | |
| from langchain.cache import InMemoryCache | |
| from langchain.globals import set_llm_cache | |
# Register an in-memory LLM cache globally, so repeated identical LLM calls in
# this process can be served without calling the API again.
_llm_cache = InMemoryCache()
set_llm_cache(_llm_cache)
@cl.on_chat_start
async def on_chat_start():
    """Chainlit hook for the start of a chat session; currently only logs.

    Without the decorator Chainlit never registers this coroutine.
    """
    print("starting up")
@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming chat message with the retrieval-augmented QA chain.

    Registered via ``@cl.on_message``; without the decorator Chainlit never
    routes messages here.

    The chain returns ``{"response": <model message>, "context": <Documents>}``.
    Only the reply text is sent back, not the whole dict of message/Document
    objects.
    """
    # ainvoke keeps the network-bound retrieval + LLM calls from blocking
    # Chainlit's event loop, unlike the synchronous invoke().
    result = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})
    await cl.Message(content=result["response"].content).send()