mwitiderrick commited on
Commit
3d2faa5
·
verified ·
1 Parent(s): 57b2d01

Update rag_dspy.py

Browse files
Files changed (1) hide show
  1. rag_dspy.py +10 -8
rag_dspy.py CHANGED
@@ -12,7 +12,7 @@ load_dotenv()
12
  # DSPy setup
13
  lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
14
  client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
15
- collection_name = "indexed_medical_chat_bot"
16
  rm = QdrantRM(
17
  qdrant_collection_name=collection_name,
18
  qdrant_client=client,
@@ -24,7 +24,7 @@ dspy.settings.configure(lm=lm, rm=rm)
24
 
25
  # Manual reranker using ColBERT multivector field
26
  # Manual reranker using Qdrant’s native prefetch + ColBERT query
27
- def rerank_with_colbert(query_text, year, specialty):
28
  from fastembed import TextEmbedding, LateInteractionTextEmbedding
29
 
30
  # Encode query once with both models
@@ -48,8 +48,8 @@ def rerank_with_colbert(query_text, year, specialty):
48
  query_filter=Filter(
49
  must=[
50
  FieldCondition(key="specialty", match=MatchValue(value=specialty)),
51
- FieldCondition(key="year", match=MatchValue(value=year))
52
- ]
53
 
54
  )
55
  )
@@ -66,7 +66,8 @@ def rerank_with_colbert(query_text, year, specialty):
66
  class MedicalAnswer(dspy.Signature):
67
  question = dspy.InputField(desc="The medical question to answer")
68
  is_medical = dspy.OutputField(desc="Answer 'Yes' if the question is medical, otherwise 'No'")
69
- year = dspy.InputField(desc="The year of the medical paper")
 
70
  specialty = dspy.InputField(desc="The specialty of the medical paper")
71
  context = dspy.OutputField(desc="The answer to the medical question")
72
  final_answer = dspy.OutputField(desc="The answer to the medical question")
@@ -87,16 +88,17 @@ class MedicalRAG(dspy.Module):
87
  super().__init__()
88
  self.guardrail = MedicalGuardrail()
89
 
90
- def forward(self, question, year, specialty):
91
  if not self.guardrail.forward(question):
92
  class DummyResult:
93
  final_answer = "Sorry, I can only answer medical questions. Please ask a question related to medicine or healthcare."
94
  return DummyResult()
95
- reranked_docs = rerank_with_colbert(question, year, specialty)
96
  context_str = "\n".join(reranked_docs)
97
  return dspy.ChainOfThought(MedicalAnswer)(
98
  question=question,
99
- year=year,
 
100
  specialty=specialty,
101
  context=context_str
102
  )
 
12
  # DSPy setup
13
  lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
14
  client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
15
+ collection_name = "medical_chat_bot"
16
  rm = QdrantRM(
17
  qdrant_collection_name=collection_name,
18
  qdrant_client=client,
 
24
 
25
  # Manual reranker using ColBERT multivector field
26
  # Manual reranker using Qdrant’s native prefetch + ColBERT query
27
+ def rerank_with_colbert(query_text, min_year, max_year, specialty):
28
  from fastembed import TextEmbedding, LateInteractionTextEmbedding
29
 
30
  # Encode query once with both models
 
48
  query_filter=Filter(
49
  must=[
50
  FieldCondition(key="specialty", match=MatchValue(value=specialty)),
51
+ FieldCondition(key="year",range=models.Range(gt=None,gte=min_year,lt=None,lte=max_year))
52
+ ]
53
 
54
  )
55
  )
 
66
  class MedicalAnswer(dspy.Signature):
67
  question = dspy.InputField(desc="The medical question to answer")
68
  is_medical = dspy.OutputField(desc="Answer 'Yes' if the question is medical, otherwise 'No'")
69
+ min_year = dspy.InputField(desc="The minimum year of the medical paper")
70
+ max_year = dspy.InputField(desc="The maximum year of the medical paper")
71
  specialty = dspy.InputField(desc="The specialty of the medical paper")
72
  context = dspy.OutputField(desc="The answer to the medical question")
73
  final_answer = dspy.OutputField(desc="The answer to the medical question")
 
88
  super().__init__()
89
  self.guardrail = MedicalGuardrail()
90
 
91
+ def forward(self, question, min_year, max_year, specialty):
92
  if not self.guardrail.forward(question):
93
  class DummyResult:
94
  final_answer = "Sorry, I can only answer medical questions. Please ask a question related to medicine or healthcare."
95
  return DummyResult()
96
+ reranked_docs = rerank_with_colbert(question, min_year, max_year, specialty)
97
  context_str = "\n".join(reranked_docs)
98
  return dspy.ChainOfThought(MedicalAnswer)(
99
  question=question,
100
+ min_year=min_year,
101
+ max_year=max_year,
102
  specialty=specialty,
103
  context=context_str
104
  )