cassandrasestier commited on
Commit
24a867b
Β·
verified Β·
1 Parent(s): 19a82d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -14
app.py CHANGED
@@ -1,7 +1,8 @@
1
  # ================================
2
  # πŸͺž MoodMirror+ β€” Conversational Emotional Self-Care
3
- # Uses ONLY the GoEmotions dataset (no pretrained model)
4
- # Trains TF-IDF + OneVsRest Logistic Regression on first run, caches to /data
 
5
  # ================================
6
  import os
7
  import re
@@ -36,8 +37,14 @@ print(f"[MM] SQLite path: {DB_PATH}")
36
  print(f"[MM] Model path: {MODEL_PATH}")
37
 
38
  # ---------------- Crisis & regex ----------------
39
- CRISIS_RE = re.compile(r"\b(self[- ]?harm|suicid|kill myself|end my life|overdose|cutting|i don.?t want to live|can.?t go on)\b", re.I)
40
- CLOSING_RE = re.compile(r"\b(thanks?|thank you|that'?s all|bye|goodbye|see you|take care|ok bye|no thanks?)\b", re.I)
 
 
 
 
 
 
41
 
42
  CRISIS_NUMBERS = {
43
  "United States": "Call or text **988** (24/7 Suicide & Crisis Lifeline). If in immediate danger, call **911**.",
@@ -284,16 +291,22 @@ def train_or_load_model():
284
  Y_val = mlb.transform(y_val)
285
 
286
  clf = Pipeline([
287
- ("tfidf", TfidfVectorizer(lowercase=True, ngram_range=(1, 2), min_df=2, max_df=0.9, strip_accents="unicode")),
288
- ("ovr", OneVsRestClassifier(LogisticRegression(solver="saga", max_iter=1000, n_jobs=-1, class_weight="balanced"), n_jobs=-1))
 
 
 
 
 
289
  ])
290
 
291
  print("[MM] Training classifier...")
292
  clf.fit(X_train, Y_train)
293
-
294
- print(f"[MM] Validation macro F1: {f1_score(Y_val, clf.predict(X_val), average='macro', zero_division=0):.3f}")
295
 
296
  joblib.dump({"version": MODEL_VERSION, "pipeline": clf, "mlb": mlb, "label_names": label_names}, MODEL_PATH)
 
297
  return clf, mlb, label_names
298
 
299
  try:
@@ -304,7 +317,8 @@ except Exception as e:
304
 
305
  # ---------------- Emotion detection ----------------
306
  def classify_text(text: str):
307
- if not CLASSIFIER: return []
 
308
  try:
309
  proba = CLASSIFIER.predict_proba([text])[0]
310
  except AttributeError:
@@ -316,7 +330,8 @@ def classify_text(text: str):
316
 
317
  def detect_emotions(text: str):
318
  chosen = classify_text(text)
319
- if not chosen: return "neutral"
 
320
  bucket = {}
321
  for label, p in chosen:
322
  app = GOEMO_TO_APP.get(label.lower(), "neutral")
@@ -324,9 +339,16 @@ def detect_emotions(text: str):
324
  return max(bucket, key=bucket.get)
325
 
326
  # ---------------- Reply composer ----------------
327
- def compose_support_legacy(main_emotion: str, is_first_msg: bool) -> str:
328
- tip = random.choice(SUGGESTIONS.get(main_emotion, ["Take a slow breath. One small act of kindness can shift your day."]))
329
- quote = random.choice(QUOTES.get(main_emotion, ["β€œNo matter what you feel right now, this moment will pass.”"]))
 
 
 
 
 
 
 
330
  include_quote = random.random() < 0.5
331
 
332
  reply = tip
@@ -351,4 +373,76 @@ def chat_step(message, history, country, save_session):
351
  if CLOSING_RE.search(message):
352
  return ("Thank you πŸ’› Take care of yourself. Small steps matter. 🌿", "#FFFFFF")
353
 
354
- emotion = detect_emotions(" ".join
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ================================
2
  # πŸͺž MoodMirror+ β€” Conversational Emotional Self-Care
3
+ # Dataset-only: trains a TF-IDF + OneVsRest Logistic Regression on GoEmotions
4
+ # Cache du modèle + DB dans /data quand dispo (HF Spaces: activer Persistent storage)
5
+ # Toujours donner au moins 1 conseil + parfois une citation
6
  # ================================
7
  import os
8
  import re
 
37
  print(f"[MM] Model path: {MODEL_PATH}")
38
 
39
  # ---------------- Crisis & regex ----------------
40
+ CRISIS_RE = re.compile(
41
+ r"\b(self[- ]?harm|suicid|kill myself|end my life|overdose|cutting|i don.?t want to live|can.?t go on)\b",
42
+ re.I,
43
+ )
44
+ CLOSING_RE = re.compile(
45
+ r"\b(thanks?|thank you|that'?s all|bye|goodbye|see you|take care|ok bye|no thanks?)\b",
46
+ re.I,
47
+ )
48
 
49
  CRISIS_NUMBERS = {
50
  "United States": "Call or text **988** (24/7 Suicide & Crisis Lifeline). If in immediate danger, call **911**.",
 
291
  Y_val = mlb.transform(y_val)
292
 
293
  clf = Pipeline([
294
+ ("tfidf", TfidfVectorizer(
295
+ lowercase=True, ngram_range=(1, 2), min_df=2, max_df=0.9, strip_accents="unicode"
296
+ )),
297
+ ("ovr", OneVsRestClassifier(
298
+ LogisticRegression(solver="saga", max_iter=1000, n_jobs=-1, class_weight="balanced"),
299
+ n_jobs=-1
300
+ ))
301
  ])
302
 
303
  print("[MM] Training classifier...")
304
  clf.fit(X_train, Y_train)
305
+ macro_f1 = f1_score(Y_val, clf.predict(X_val), average="macro", zero_division=0)
306
+ print(f"[MM] Validation macro F1: {macro_f1:.3f}")
307
 
308
  joblib.dump({"version": MODEL_VERSION, "pipeline": clf, "mlb": mlb, "label_names": label_names}, MODEL_PATH)
309
+ print(f"[MM] Saved model -> {MODEL_PATH}")
310
  return clf, mlb, label_names
311
 
312
  try:
 
317
 
318
  # ---------------- Emotion detection ----------------
319
  def classify_text(text: str):
320
+ if not CLASSIFIER:
321
+ return []
322
  try:
323
  proba = CLASSIFIER.predict_proba([text])[0]
324
  except AttributeError:
 
330
 
331
  def detect_emotions(text: str):
332
  chosen = classify_text(text)
333
+ if not chosen:
334
+ return "neutral"
335
  bucket = {}
336
  for label, p in chosen:
337
  app = GOEMO_TO_APP.get(label.lower(), "neutral")
 
339
  return max(bucket, key=bucket.get)
340
 
341
  # ---------------- Reply composer ----------------
342
+ def compose_support(main_emotion: str, is_first_msg: bool) -> str:
343
+ # Always include an advice tip; 50% chance to add a quote
344
+ tip = random.choice(SUGGESTIONS.get(
345
+ main_emotion,
346
+ ["Take a slow breath. One small act of kindness can shift your day."]
347
+ ))
348
+ quote = random.choice(QUOTES.get(
349
+ main_emotion,
350
+ ["β€œNo matter what you feel right now, this moment will pass.”"]
351
+ ))
352
  include_quote = random.random() < 0.5
353
 
354
  reply = tip
 
373
  if CLOSING_RE.search(message):
374
  return ("Thank you πŸ’› Take care of yourself. Small steps matter. 🌿", "#FFFFFF")
375
 
376
+ recent = " ".join(message.split()[-100:])
377
+ emotion = detect_emotions(recent)
378
+ color = COLOR_MAP.get(emotion, "#FFFFFF")
379
+
380
+ if save_session:
381
+ log_session(country, message, emotion)
382
+
383
+ reply = compose_support(emotion, is_first_msg=not bool(history))
384
+ return reply, color
385
+
386
+ # ---------------- Gradio UI ----------------
387
+ init_db()
388
+
389
+ custom_css = """
390
+ :root, body, .gradio-container { transition: background-color 0.8s ease !important; }
391
+ .typing { font-style: italic; opacity: 0.8; animation: blink 1s infinite; }
392
+ @keyframes blink { 50% {opacity: 0.4;} }
393
+ """
394
+
395
+ with gr.Blocks(css=custom_css, title="πŸͺž MoodMirror+ (Dataset-only Edition)") as demo:
396
+ style_injector = gr.HTML("")
397
+ gr.Markdown(
398
+ "### πŸͺž MoodMirror+ β€” Emotional Support & Inspiration 🌸\n"
399
+ "Powered only by the **GoEmotions dataset** (trained locally on startup).\n\n"
400
+ "_Not medical advice. If you feel unsafe, please reach out for help immediately._"
401
+ )
402
+
403
+ with gr.Row():
404
+ country = gr.Dropdown(choices=list(CRISIS_NUMBERS.keys()),
405
+ value="Other / Not listed", label="Country")
406
+ save_ok = gr.Checkbox(value=False, label="Save anonymized session (no personal data)")
407
+
408
+ chat = gr.Chatbot(height=360)
409
+ msg = gr.Textbox(placeholder="Type how you feel...", label="Your message")
410
+ send = gr.Button("Send")
411
+ typing = gr.Markdown("", elem_classes="typing")
412
+
413
+ # Optional: dataset preview
414
+ with gr.Accordion("πŸ”Ž Preview GoEmotions samples", open=False):
415
+ with gr.Row():
416
+ n_examples = gr.Slider(1, 10, value=5, step=1, label="Number of examples")
417
+ split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
418
+ refresh = gr.Button("Show samples")
419
+ table = gr.Dataframe(headers=["text", "labels"], row_count=5, wrap=True)
420
+
421
+ def refresh_samples(n, split_name):
422
+ try:
423
+ ds = load_dataset("google-research-datasets/go_emotions", "simplified")
424
+ names = ds["train"].features["labels"].feature.names
425
+ rows = ds[split_name].shuffle(seed=42).select(range(min(int(n), len(ds[split_name]))))
426
+ return [[t, ", ".join([names[i] for i in labs])] for t, labs in zip(rows["text"], rows["labels"])]
427
+ except Exception as e:
428
+ return [[f"Dataset load error: {e}", ""]]
429
+
430
+ refresh.click(refresh_samples, inputs=[n_examples, split], outputs=[table])
431
+
432
+ def respond(user_msg, chat_hist, country_choice, save_flag):
433
+ if not user_msg or not user_msg.strip():
434
+ yield chat_hist + [[user_msg, "Please share a short sentence about how you feel πŸ™‚"]], "", "", ""
435
+ return
436
+ yield chat_hist, "πŸ’­ MoodMirror is thinking...", "", ""
437
+ reply, color = chat_step(user_msg, chat_hist, country_choice, bool(save_flag))
438
+ style_tag = f"<style>:root,body,.gradio-container{{background:{color}!important;}}</style>"
439
+ yield chat_hist + [[user_msg, reply]], "", style_tag, ""
440
+
441
+ send.click(respond, inputs=[msg, chat, country, save_ok],
442
+ outputs=[chat, typing, style_injector, msg], queue=True)
443
+ msg.submit(respond, inputs=[msg, chat, country, save_ok],
444
+ outputs=[chat, typing, style_injector, msg], queue=True)
445
+
446
+ if __name__ == "__main__":
447
+ demo.queue()
448
+ demo.launch()