|
|
""" |
|
|
Hybrid Mental Health Chatbot
============================

Background classifier (DistilBERT) + streaming LLM (Llama 3.2 3B Instruct,
served through the Hugging Face Inference API)
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification |
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub model IDs.
# DistilBERT sequence classifier used to label each user message.
CLASSIFIER_MODEL = "YureiYuri/empathy"
# Chat model queried through huggingface_hub.InferenceClient.
LLM_MODEL = "meta-llama/Llama-3.2-3B-Instruct"


# Load the classifier once at import time so every request reuses it.
print("🤖 Loading emotion classifier...")
tokenizer = DistilBertTokenizerFast.from_pretrained(CLASSIFIER_MODEL)
classifier = DistilBertForSequenceClassification.from_pretrained(CLASSIFIER_MODEL)
# The success message was previously split across two lines by a mis-decoded
# emoji byte, leaving an unterminated string literal (a SyntaxError).
print("✅ Classifier loaded!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# How strongly each classifier intent pushes the three running metrics.
# update_state() multiplies these by the classifier confidence; labels that
# are missing here fall back to 0.5 for every metric.
INTENT_WEIGHTS = {
    "depression": dict(dejection=3.0, mood=2.0, calmness=1.5),
    "suicide": dict(dejection=5.0, mood=3.0, calmness=3.5),
    "trauma": dict(dejection=3.5, mood=1.5, calmness=2.5),
    "grief": dict(dejection=2.5, mood=1.5, calmness=1.0),
    "self_esteem": dict(dejection=1.0, mood=3.5, calmness=0.5),
    "anxiety": dict(dejection=0.5, mood=1.0, calmness=3.5),
    "sleep_issues": dict(dejection=1.0, mood=1.0, calmness=2.5),
    "anger": dict(dejection=1.0, mood=1.5, calmness=3.5),
    "relationship": dict(dejection=1.5, mood=2.0, calmness=2.0),
    "family": dict(dejection=1.5, mood=2.0, calmness=1.5),
}


# Running emotional metrics plus the derived support mode.
# NOTE(review): this is a single module-level dict, so every visitor of the
# app reads and writes the same state; per-user tracking would need
# per-session storage (e.g. gr.State).
SESSION_STATE = dict(
    dejection=0.0,
    calmness=0.0,
    mood=0.0,
    severity=0.0,
    mode="supportive",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_intent(text):
    """Label a single user message with the DistilBERT classifier.

    Args:
        text: The raw user message.

    Returns:
        A ``(label, probability)`` pair: the top-scoring label looked up in
        ``classifier.config.id2label`` and its softmax probability.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = classifier(**encoded).logits

    distribution = torch.softmax(logits, dim=1)
    best = torch.argmax(distribution).item()
    return classifier.config.id2label[best], distribution[0][best].item()
|
|
|
|
|
|
|
|
def update_state(intent, confidence, state=None, intent_weights=None):
    """Fold one classified message into the running emotional metrics.

    Every metric first decays by a factor of 0.88, then grows by the
    intent's weight multiplied by the classifier confidence. A weighted
    severity score is recomputed from the three metrics and mapped to a
    support mode ("crisis", "urgent", "concerned" or "supportive").

    Args:
        intent: Label produced by ``classify_intent``.
        confidence: Probability the classifier assigned to ``intent``.
        state: Metrics dict to update in place (keys ``dejection``,
            ``calmness``, ``mood``, ``severity``, ``mode``). Defaults to the
            module-level ``SESSION_STATE``, a single dict shared by every
            caller; pass a per-session dict to keep users separate.
        intent_weights: Mapping of intent -> ``{"dejection", "mood",
            "calmness"}`` weights. Defaults to ``INTENT_WEIGHTS``.

    Returns:
        The same ``state`` dict that was mutated (callers may ignore it).
    """
    if state is None:
        state = SESSION_STATE
    if intent_weights is None:
        intent_weights = INTENT_WEIGHTS

    # Exponential decay: older messages contribute progressively less.
    for metric in ("dejection", "calmness", "mood"):
        state[metric] *= 0.88

    # Intents without an entry contribute a small, uniform amount.
    weights = intent_weights.get(
        intent, {"dejection": 0.5, "mood": 0.5, "calmness": 0.5}
    )
    for metric in ("dejection", "mood", "calmness"):
        state[metric] += weights[metric] * confidence

    state["severity"] = (
        state["dejection"] * 0.5
        + state["mood"] * 0.25
        + state["calmness"] * 0.25
    )

    # NOTE(review): severity evolves as s_n = 0.88 * s_{n-1} + increment. With
    # the current INTENT_WEIGHTS the largest per-message increment is 4.125
    # ("suicide" at confidence 1.0), so starting from zero severity stays
    # below ~34.4. The "> 35" test can then only fire from a pre-seeded state,
    # and crisis mode is in practice driven by the "suicide" label.
    # Thresholds may need recalibrating.
    if intent == "suicide" or state["severity"] > 35:
        state["mode"] = "crisis"
    elif state["severity"] > 20:
        state["mode"] = "urgent"
    elif state["severity"] > 10:
        state["mode"] = "concerned"
    else:
        state["mode"] = "supportive"

    print(
        f"📊 {intent} ({confidence:.2f}) | Mode: {state['mode']} "
        f"| Severity: {state['severity']:.1f}"
    )
    return state
|
|
|
|
|
|
|
|
def get_system_prompt(state=None):
    """Build the LLM system prompt for the current support mode.

    Args:
        state: Metrics dict with at least ``mode`` and ``severity`` keys, as
            maintained by ``update_state``. Defaults to the module-level
            ``SESSION_STATE``; pass a per-session dict to avoid the shared
            global.

    Returns:
        The system prompt string. Unknown modes fall back to the
        "supportive" prompt.
    """
    if state is None:
        state = SESSION_STATE
    mode = state["mode"]

    base = "You are a warm, empathetic mental health support assistant."

    if mode == "crisis":
        return f"{base} The user is in crisis. Show genuine concern and guide them to crisis resources."
    elif mode == "urgent":
        # NOTE(review): severity is reported "/100", but with the current
        # weights and 0.88 decay it stays well below 100 (roughly <35).
        return f"{base} The user shows significant distress (severity {state['severity']:.0f}/100). Be extra empathetic and supportive."
    elif mode == "concerned":
        return f"{base} The user is experiencing moderate distress. Show increased warmth and validation."
    else:
        return f"{base} Have a natural, supportive conversation. Keep responses concise (2-3 sentences). Listen actively."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def respond(
    message,
    history: list[dict[str, str]],
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """Handle one chat turn: track emotion silently, then stream a reply.

    Args:
        message: The user's latest message.
        history: Prior turns in Gradio ``type="messages"`` format, forwarded
            as-is to the LLM after the system prompt.
        max_tokens: Generation cap passed to ``chat_completion``.
        temperature: Sampling temperature passed to ``chat_completion``.
        top_p: Nucleus-sampling cutoff passed to ``chat_completion``.
        hf_token: OAuth token from the Hugging Face login button; treated as
            signed-out when falsy or when its ``token`` is empty.

    Yields:
        The accumulated reply text after each streamed chunk, or a single
        fixed message (crisis resources, sign-in prompt, or error notice).
    """
    # Because classification runs before the sign-in check, the crisis
    # override below also applies to signed-out users.
    intent, confidence = classify_intent(message)
    update_state(intent, confidence)

    # Crisis override: skip the LLM entirely and show hotline resources.
    if SESSION_STATE["mode"] == "crisis":
        crisis_msg = (
            "I'm really concerned about your safety right now.\n\n"
            "**Please reach out immediately:**\n"
            "• Call/text **988** (Suicide & Crisis Lifeline)\n"
            "• Text **HOME to 741741** (Crisis Text Line)\n"
            "• Call **911** or go to nearest ER\n\n"
            "You don't have to face this alone. Help is available 24/7."
        )
        yield crisis_msg
        return

    if not hf_token or not hf_token.token:
        yield "⚠️ Please sign in with your Hugging Face account (click the button in the sidebar) to start chatting."
        return

    try:
        client = InferenceClient(token=hf_token.token, model=LLM_MODEL)

        messages = [{"role": "system", "content": get_system_prompt()}]
        messages.extend(history)
        messages.append({"role": "user", "content": message})

        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Stream chunks may have no choices or an empty delta; skip those.
            if chunk.choices and chunk.choices[0].delta.content:
                response += chunk.choices[0].delta.content
                yield response

    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        print(f"❌ LLM Error: {error_msg}")

        if "401" in error_msg or "Unauthorized" in error_msg:
            yield "⚠️ Authentication error. Please sign in with your Hugging Face account using the button in the sidebar."
        # Match explicit quota/rate-limit signals only: a bare "rate" substring
        # would also match unrelated words such as "generate" or "moderate".
        elif "429" in error_msg or any(
            marker in lowered
            for marker in ("quota", "rate limit", "rate-limit", "too many requests")
        ):
            yield "⚠️ API rate limit reached. Please try again in a moment."
        else:
            yield "I'm here to listen and support you. Could you tell me more about what you're experiencing?"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Mental Health Support Chatbot")
    gr.Markdown("💙 A safe space to talk. Your emotional state is tracked privately to provide better support.")

    with gr.Row():
        # Main column: the chat interface driven by respond().
        with gr.Column(scale=3):
            chatbot = gr.ChatInterface(
                respond,
                type="messages",
                chatbot=gr.Chatbot(
                    height=500,
                    type="messages",
                    placeholder="Hi there! I'm here to listen. What's on your mind?",
                ),
                textbox=gr.Textbox(placeholder="Type your message...", lines=2, show_label=False),
                # Hidden sliders supply respond()'s max_tokens, temperature
                # and top_p arguments, in that order.
                additional_inputs=[
                    gr.Slider(minimum=128, maximum=1024, value=300, step=32, label="Max tokens", visible=False),
                    gr.Slider(minimum=0.1, maximum=1.5, value=0.8, step=0.1, label="Temperature", visible=False),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.92, step=0.05, label="Top-p", visible=False),
                ],
            )

        # Sidebar: sign-in, mode legend, and crisis resources.
        with gr.Column(scale=1):
            gr.Markdown("### 🔐 Authentication")
            gr.LoginButton(size="lg")
            gr.Markdown("*Sign in with Hugging Face to chat*")

            gr.Markdown("---")

            gr.Markdown("### 📊 Emotional Tracking")
            gr.Markdown(
                "Your emotional state is monitored in the background to adjust support levels:\n\n"
                "🟢 **Supportive** - General listening\n"
                "🟡 **Concerned** - Increased empathy\n"
                "🔴 **Urgent** - Active support\n"
                "🚨 **Crisis** - Immediate resources"
            )

            # NOTE(review): this panel is static — nothing in this block ever
            # updates metrics_display. Wiring it to the module-level
            # SESSION_STATE would show every visitor the latest user's
            # metrics; per-session state (gr.State) is needed first.
            metrics_display = gr.Markdown(
                "**Current Mode:** 🟢 SUPPORTIVE\n\n"
                "**Severity:** 0.0/100\n"
                "- Dejection: 0.0\n"
                "- Mood: 0.0\n"
                "- Calmness: 0.0\n\n"
                "*Metrics update as you chat*"
            )

            gr.Markdown(
                "---\n"
                "⚠️ **In Crisis?** Call **988** (US) or your local emergency number.\n\n"
                "*This is a support tool, not a replacement for professional mental health care.*"
            )


if __name__ == "__main__":
    demo.launch()