Spaces:
Runtime error
Runtime error
| # https://python.langchain.com/docs/modules/chains/how_to/custom_chain | |
| # Including reformulation of the question in the chain | |
| import json | |
| from langchain import PromptTemplate, LLMChain | |
| from langchain.chains import QAWithSourcesChain | |
| from langchain.chains import TransformChain, SequentialChain | |
| from langchain.chains.qa_with_sources import load_qa_with_sources_chain | |
| from anyqa.prompts import answer_prompt, reformulation_prompt | |
| from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain | |
def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain for documents the caller already has.

    Use this when retrieval has happened elsewhere and the documents are
    handed in directly under the ``"docs"`` key. Expected call shape::

        output = chain({
            "question": query,
            "audience": "experts scientists",
            "docs": docs,
            "language": "English",
        })

    Args:
        llm: Language model driving the underlying "stuff" combine chain.

    Returns:
        A ``QAWithSourcesChain`` that reads documents from ``"docs"`` and is
        configured with ``return_source_documents=True``.
    """
    return QAWithSourcesChain(
        input_docs_key="docs",
        combine_documents_chain=load_combine_documents_chain(llm),
        return_source_documents=True,
    )
def load_combine_documents_chain(llm):
    """Create the "stuff" QA-with-sources chain driven by ``answer_prompt``.

    The prompt is declared with the input variables ``summaries``,
    ``question``, ``audience`` and ``language``.

    Args:
        llm: Language model used to write the answer.

    Returns:
        The chain produced by ``load_qa_with_sources_chain`` with
        ``chain_type="stuff"``.
    """
    answer_template = PromptTemplate(
        input_variables=["summaries", "question", "audience", "language"],
        template=answer_prompt,
    )
    return load_qa_with_sources_chain(llm, chain_type="stuff", prompt=answer_template)
def load_qa_chain_with_text(llm):
    """Create a plain LLM chain that answers from caller-supplied text.

    No documents are combined here: the caller passes ``summaries`` as text
    along with ``question``, ``audience`` and ``language``, which fill
    ``answer_prompt`` directly.

    Args:
        llm: Language model used to write the answer.

    Returns:
        An ``LLMChain`` over ``answer_prompt``.
    """
    text_prompt = PromptTemplate(
        input_variables=["question", "audience", "language", "summaries"],
        template=answer_prompt,
    )
    return LLMChain(prompt=text_prompt, llm=llm)
def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Assemble the full pipeline: reformulate the query, then retrieve and answer.

    Inputs:  ``query`` and ``audience``.
    Outputs: ``answer``, ``question`` (the reformulated query), ``language``
    and ``source_documents``.

    Args:
        retriever: Retriever handed to the answer step.
        llm_reformulation: Model that rewrites the raw query and extracts its
            language.
        llm_answer: Model that writes the final answer.

    Returns:
        A verbose ``SequentialChain`` running the two steps in order.
    """
    steps = [
        load_reformulation_chain(llm_reformulation),
        load_qa_chain_with_retriever(retriever, llm_answer),
    ]
    return SequentialChain(
        chains=steps,
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        # NOTE(review): with output_variables given explicitly, return_all
        # likely has no effect — confirm against the installed langchain.
        return_all=True,
        verbose=True,
    )
def load_reformulation_chain(llm):
    """Build a chain that rewrites the raw user query and extracts its language.

    The LLM is prompted with ``reformulation_prompt`` and its reply is expected
    to be a JSON object such as ``{"question": ..., "language": ...}``. That
    reply is parsed into two outputs:

    - ``question``: the reformulated question; falls back to the original
      ``query`` when the key is missing.
    - ``language``: falls back to ``"English"`` when the key is missing.

    If the reply is not valid JSON, or is JSON but not an object, both
    fallbacks are used and a warning is logged instead of raising, so a
    malformed LLM answer does not abort the whole pipeline.

    Args:
        llm: Language model used for the reformulation step.

    Returns:
        A ``SequentialChain`` taking ``query`` and producing ``question`` and
        ``language``.
    """
    import logging

    logger = logging.getLogger(__name__)

    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    # The raw LLM reply is stored under "json" for the parsing step below.
    reformulation_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        """Turn the LLM's JSON reply into ``question`` / ``language`` outputs."""
        query = output["query"]
        raw = output["json"]
        logger.debug("Reformulation output: %s", output)
        try:
            json_output = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            logger.warning("Reformulation reply is not valid JSON: %r", raw)
            json_output = {}
        if not isinstance(json_output, dict):
            # e.g. the model replied with a bare string or a list.
            logger.warning("Reformulation reply is not a JSON object: %r", raw)
            json_output = {}
        return {
            "question": json_output.get("question", query),
            "language": json_output.get("language", "English"),
        }

    transform_chain = TransformChain(
        # parse_output reads "query" as well as "json", so declare both so the
        # dependency is visible to SequentialChain's input validation.
        input_variables=["query", "json"],
        output_variables=["question", "language"],
        transform=parse_output,
    )
    return SequentialChain(
        chains=[reformulation_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )
def load_qa_chain_with_retriever(retriever, llm):
    """Build the retrieve-then-answer step of the pipeline.

    Documents come from ``retriever`` and are combined by the "stuff"
    QA-with-sources chain built in ``load_combine_documents_chain``.

    Args:
        retriever: Retriever that supplies candidate passages.
        llm: Language model used to write the answer.

    Returns:
        A verbose ``CustomRetrievalQAWithSourcesChain`` configured with
        ``return_source_documents=True`` and a fallback answer (presumably
        used when no relevant passages are retrieved — confirm in
        ``anyqa.custom_retrieval_chain``).
    """
    # Supplying a document prompt to the combine chain would avoid having to
    # modify page_content in the docs; see
    # https://github.com/langchain-ai/langchain/issues/3523
    no_passages_message = "**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**"
    return CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        fallback_answer=no_passages_message,
    )