cassandrasestier committed on
Commit
3e7d5d4
·
verified ·
1 Parent(s): a220779

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -26
app.py CHANGED
@@ -3,6 +3,7 @@
3
  # Advice + Inspirational quotes + Emotion-based color + SQLite DB
4
  # GoEmotions model + loads GoEmotions dataset ("simplified" config)
5
  # ================================
 
6
  import re
7
  import random
8
  import sqlite3
@@ -13,15 +14,27 @@ import torch
13
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
14
  from datasets import load_dataset
15
 
16
- # --- Paths (persist across Space restarts when Persistent storage is ON) ---
17
- DB_PATH = "/data/moodmirror.db"
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # --- Load GoEmotions dataset ("simplified") ---
20
  # This pulls from: google-research-datasets/go_emotions
21
- # The "simplified" config uses standard splits and provides label indices + names.
22
  try:
23
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
24
  LABEL_NAMES = ds["train"].features["labels"].feature.names # e.g. ['admiration', ..., 'neutral']
 
25
  except Exception as e:
26
  ds = None
27
  LABEL_NAMES = None
@@ -156,32 +169,47 @@ THRESHOLD = 0.35 # tune to be more/less sensitive
156
 
157
  # --- SQLite setup ---
158
  def get_conn():
159
- return sqlite3.connect(DB_PATH, check_same_thread=False)
 
160
 
161
  def init_db():
162
- conn = get_conn()
163
- c = conn.cursor()
164
- c.execute("""
165
- CREATE TABLE IF NOT EXISTS sessions(
166
- id INTEGER PRIMARY KEY AUTOINCREMENT,
167
- ts TEXT,
168
- country TEXT,
169
- user_text TEXT,
170
- main_emotion TEXT
171
- )
172
- """)
173
- conn.commit()
174
- conn.close()
 
 
 
 
 
 
 
175
 
176
  def log_session(country, msg, emotion):
177
- conn = get_conn()
178
- c = conn.cursor()
179
- c.execute(
180
- "INSERT INTO sessions(ts, country, user_text, main_emotion) VALUES(?,?,?,?)",
181
- (datetime.utcnow().isoformat(timespec="seconds"), country, msg[:500], emotion),
182
- )
183
- conn.commit()
184
- conn.close()
 
 
 
 
 
 
 
185
 
186
  # --- Emotion detection (multi-label via model) ---
187
  def detect_emotions(text: str):
@@ -191,7 +219,7 @@ def detect_emotions(text: str):
191
  - main_app: top mapped category for UI/tips/colors
192
  """
193
  try:
194
- preds = pipe(text)[0] # list of {'label': 'joy', 'score': 0.82}, for all labels
195
  chosen = [p for p in preds if p["score"] >= THRESHOLD]
196
  chosen.sort(key=lambda x: x["score"], reverse=True)
197
 
 
3
  # Advice + Inspirational quotes + Emotion-based color + SQLite DB
4
  # GoEmotions model + loads GoEmotions dataset ("simplified" config)
5
  # ================================
6
+ import os
7
  import re
8
  import random
9
  import sqlite3
 
14
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
15
  from datasets import load_dataset
16
 
17
+ # --- Storage paths (robust across local dev vs. HF Spaces) ---
18
+ def _pick_data_dir():
19
+ # Prefer /data if it exists AND is writable (Spaces with persistent storage).
20
+ if os.path.isdir("/data") and os.access("/data", os.W_OK):
21
+ return "/data"
22
+ # Otherwise, fall back to the repo working directory.
23
+ return os.getcwd()
24
+
25
+ DATA_DIR = os.getenv("MM_DATA_DIR", _pick_data_dir())
26
+ os.makedirs(DATA_DIR, exist_ok=True)
27
+ DB_PATH = os.path.join(DATA_DIR, "moodmirror.db")
28
+ print(f"[MM] Using data dir: {DATA_DIR}")
29
+ print(f"[MM] SQLite path: {DB_PATH}")
30
 
31
  # --- Load GoEmotions dataset ("simplified") ---
32
  # This pulls from: google-research-datasets/go_emotions
33
+ # The "simplified" config uses train/validation/test splits and label indices.
34
  try:
35
  ds = load_dataset("google-research-datasets/go_emotions", "simplified")
36
  LABEL_NAMES = ds["train"].features["labels"].feature.names # e.g. ['admiration', ..., 'neutral']
37
+ print("[MM] GoEmotions dataset loaded.")
38
  except Exception as e:
39
  ds = None
40
  LABEL_NAMES = None
 
169
 
170
  # --- SQLite setup ---
171
  def get_conn():
172
+ # timeout helps if multiple requests hit the DB at once
173
+ return sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10)
174
 
175
  def init_db():
176
+ conn = None
177
+ try:
178
+ conn = get_conn()
179
+ c = conn.cursor()
180
+ c.execute("""
181
+ CREATE TABLE IF NOT EXISTS sessions(
182
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
183
+ ts TEXT,
184
+ country TEXT,
185
+ user_text TEXT,
186
+ main_emotion TEXT
187
+ )
188
+ """)
189
+ conn.commit()
190
+ finally:
191
+ try:
192
+ if conn is not None:
193
+ conn.close()
194
+ except Exception:
195
+ pass
196
 
197
  def log_session(country, msg, emotion):
198
+ conn = None
199
+ try:
200
+ conn = get_conn()
201
+ c = conn.cursor()
202
+ c.execute(
203
+ "INSERT INTO sessions(ts, country, user_text, main_emotion) VALUES(?,?,?,?)",
204
+ (datetime.utcnow().isoformat(timespec="seconds"), country, msg[:500], emotion),
205
+ )
206
+ conn.commit()
207
+ finally:
208
+ try:
209
+ if conn is not None:
210
+ conn.close()
211
+ except Exception:
212
+ pass
213
 
214
  # --- Emotion detection (multi-label via model) ---
215
  def detect_emotions(text: str):
 
219
  - main_app: top mapped category for UI/tips/colors
220
  """
221
  try:
222
+ preds = pipe(text)[0] # list of {'label': 'joy', 'score': 0.82} for all labels
223
  chosen = [p for p in preds if p["score"] >= THRESHOLD]
224
  chosen.sort(key=lambda x: x["score"], reverse=True)
225