Spaces:
Runtime error
Runtime error
LOUIS SANNA
committed on
Commit
·
780c913
1
Parent(s):
3a575de
feat(domains)
Browse files- anyqa/config.py +10 -0
- anyqa/retriever.py +7 -8
- app.py +11 -12
anyqa/config.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os


def get_domains(data_dir: str = "data") -> list[str]:
    """Return the names of every subdirectory found under *data_dir*.

    Walks the tree recursively, so nested subdirectories are included as
    well (this matches the original os.walk-based behavior).

    Args:
        data_dir: Root directory to scan. Defaults to ``"data"`` for
            backward compatibility with the original hard-coded path.

    Returns:
        A list of directory names (base names only, not full paths), in
        os.walk traversal order. Empty if *data_dir* does not exist or
        contains no subdirectories.
    """
    # Comprehension replaces the manual append loop; avoid naming a
    # variable `dir`, which shadows the builtin.
    return [name for _root, dirs, _files in os.walk(data_dir) for name in dirs]
|
anyqa/retriever.py
CHANGED
|
@@ -13,25 +13,24 @@ SUMMARY_TYPES = []
|
|
| 13 |
|
| 14 |
class QARetriever(BaseRetriever):
|
| 15 |
vectorstore: VectorStore
|
| 16 |
-
|
| 17 |
threshold: float = 22
|
| 18 |
k_summary: int = 0
|
| 19 |
k_total: int = 10
|
| 20 |
namespace: str = "vectors"
|
| 21 |
|
| 22 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
| 23 |
-
|
| 24 |
-
assert isinstance(self.sources, list)
|
| 25 |
assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
|
| 26 |
|
| 27 |
-
query = "He who can bear the misfortune of a nation is called the ruler of the world."
|
| 28 |
# Prepare base search kwargs
|
| 29 |
filters = {}
|
| 30 |
-
if len(self.
|
| 31 |
-
filters["
|
| 32 |
|
| 33 |
if self.k_summary > 0:
|
| 34 |
# Search for k_summary documents in the summaries dataset
|
|
|
|
| 35 |
if len(SUMMARY_TYPES):
|
| 36 |
filters_summaries = {
|
| 37 |
**filters_summaries,
|
|
@@ -48,7 +47,8 @@ class QARetriever(BaseRetriever):
|
|
| 48 |
docs_summaries = []
|
| 49 |
|
| 50 |
# Search for k_total - k_summary documents in the full reports dataset
|
| 51 |
-
filters_full = {}
|
|
|
|
| 52 |
if len(SUMMARY_TYPES):
|
| 53 |
filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
|
| 54 |
|
|
@@ -59,7 +59,6 @@ class QARetriever(BaseRetriever):
|
|
| 59 |
filter=self.format_filter(filters_full),
|
| 60 |
k=k_full,
|
| 61 |
)
|
| 62 |
-
print("docs_full", docs_full)
|
| 63 |
|
| 64 |
# Concatenate documents
|
| 65 |
docs = docs_summaries + docs_full
|
|
|
|
| 13 |
|
| 14 |
class QARetriever(BaseRetriever):
|
| 15 |
vectorstore: VectorStore
|
| 16 |
+
domains: list = []
|
| 17 |
threshold: float = 22
|
| 18 |
k_summary: int = 0
|
| 19 |
k_total: int = 10
|
| 20 |
namespace: str = "vectors"
|
| 21 |
|
| 22 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
| 23 |
+
assert isinstance(self.domains, list)
|
|
|
|
| 24 |
assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
|
| 25 |
|
|
|
|
| 26 |
# Prepare base search kwargs
|
| 27 |
filters = {}
|
| 28 |
+
if len(self.domains):
|
| 29 |
+
filters["domain"] = {"$in": self.domains}
|
| 30 |
|
| 31 |
if self.k_summary > 0:
|
| 32 |
# Search for k_summary documents in the summaries dataset
|
| 33 |
+
filters_summaries = {**filters}
|
| 34 |
if len(SUMMARY_TYPES):
|
| 35 |
filters_summaries = {
|
| 36 |
**filters_summaries,
|
|
|
|
| 47 |
docs_summaries = []
|
| 48 |
|
| 49 |
# Search for k_total - k_summary documents in the full reports dataset
|
| 50 |
+
filters_full = {**filters}
|
| 51 |
+
print("filters", filters)
|
| 52 |
if len(SUMMARY_TYPES):
|
| 53 |
filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}
|
| 54 |
|
|
|
|
| 59 |
filter=self.format_filter(filters_full),
|
| 60 |
k=k_full,
|
| 61 |
)
|
|
|
|
| 62 |
|
| 63 |
# Concatenate documents
|
| 64 |
docs = docs_summaries + docs_full
|
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
|
| 7 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 8 |
|
| 9 |
# ClimateQ&A imports
|
|
|
|
| 10 |
from anyqa.embeddings import EMBEDDING_MODEL_NAME
|
| 11 |
from anyqa.llm import get_llm
|
| 12 |
from anyqa.qa_logging import log
|
|
@@ -136,16 +137,14 @@ def answer_user_example(query, query_example, history):
|
|
| 136 |
return query_example, history + [[query_example, ". . ."]]
|
| 137 |
|
| 138 |
|
| 139 |
-
def fetch_sources(query,
|
| 140 |
-
# Prepare default values
|
| 141 |
-
if len(sources) == 0:
|
| 142 |
-
sources = ["IPCC"]
|
| 143 |
|
| 144 |
llm_reformulation = get_llm(
|
| 145 |
max_tokens=512, temperature=0.0, verbose=True, streaming=False
|
| 146 |
)
|
|
|
|
| 147 |
retriever = QARetriever(
|
| 148 |
-
vectorstore=vectorstore,
|
| 149 |
)
|
| 150 |
reformulation_chain = load_reformulation_chain(llm_reformulation)
|
| 151 |
|
|
@@ -379,11 +378,11 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
|
|
| 379 |
gr.Markdown(
|
| 380 |
"Reminder: You can talk in any language, this tool is multi-lingual!"
|
| 381 |
)
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
label="Select
|
| 386 |
-
value=[
|
| 387 |
interactive=True,
|
| 388 |
)
|
| 389 |
|
|
@@ -419,7 +418,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
|
|
| 419 |
.success(change_tab, None, tabs)
|
| 420 |
.success(
|
| 421 |
fetch_sources,
|
| 422 |
-
[textbox,
|
| 423 |
[
|
| 424 |
textbox,
|
| 425 |
sources_textbox,
|
|
@@ -454,7 +453,7 @@ with gr.Blocks(title="❓ Q&A", css="style.css", theme=theme) as demo:
|
|
| 454 |
.success(change_tab, None, tabs)
|
| 455 |
.success(
|
| 456 |
fetch_sources,
|
| 457 |
-
[textbox,
|
| 458 |
[
|
| 459 |
textbox,
|
| 460 |
sources_textbox,
|
|
|
|
| 7 |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 8 |
|
| 9 |
# ClimateQ&A imports
|
| 10 |
+
from anyqa.config import get_domains
|
| 11 |
from anyqa.embeddings import EMBEDDING_MODEL_NAME
|
| 12 |
from anyqa.llm import get_llm
|
| 13 |
from anyqa.qa_logging import log
|
|
|
|
| 137 |
return query_example, history + [[query_example, ". . ."]]
|
| 138 |
|
| 139 |
|
| 140 |
+
def fetch_sources(query, domains):
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
llm_reformulation = get_llm(
|
| 143 |
max_tokens=512, temperature=0.0, verbose=True, streaming=False
|
| 144 |
)
|
| 145 |
+
print("domains", domains)
|
| 146 |
retriever = QARetriever(
|
| 147 |
+
vectorstore=vectorstore, domains=domains, k_summary=0, k_total=10
|
| 148 |
)
|
| 149 |
reformulation_chain = load_reformulation_chain(llm_reformulation)
|
| 150 |
|
|
|
|
| 378 |
gr.Markdown(
|
| 379 |
"Reminder: You can talk in any language, this tool is multi-lingual!"
|
| 380 |
)
|
| 381 |
+
domains = get_domains()
|
| 382 |
+
dropdown_domains = gr.CheckboxGroup(
|
| 383 |
+
domains,
|
| 384 |
+
label="Select source types",
|
| 385 |
+
value=[],
|
| 386 |
interactive=True,
|
| 387 |
)
|
| 388 |
|
|
|
|
| 418 |
.success(change_tab, None, tabs)
|
| 419 |
.success(
|
| 420 |
fetch_sources,
|
| 421 |
+
[textbox, dropdown_domains],
|
| 422 |
[
|
| 423 |
textbox,
|
| 424 |
sources_textbox,
|
|
|
|
| 453 |
.success(change_tab, None, tabs)
|
| 454 |
.success(
|
| 455 |
fetch_sources,
|
| 456 |
+
[textbox, dropdown_domains],
|
| 457 |
[
|
| 458 |
textbox,
|
| 459 |
sources_textbox,
|