OliverPerrin committed on
Commit
7977c7d
·
1 Parent(s): 18fc263

Clean up demo_gradio.py with consistent commenting style

Browse files
Files changed (1) hide show
  1. scripts/demo_gradio.py +76 -32
scripts/demo_gradio.py CHANGED
@@ -1,4 +1,14 @@
1
- """Minimal Gradio demo for LexiMind multitask model."""
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
@@ -8,8 +18,12 @@ from pathlib import Path
8
 
9
  import gradio as gr
10
 
 
 
 
11
  SCRIPT_DIR = Path(__file__).resolve().parent
12
  PROJECT_ROOT = SCRIPT_DIR.parent
 
13
  if str(PROJECT_ROOT) not in sys.path:
14
  sys.path.insert(0, str(PROJECT_ROOT))
15
 
@@ -21,45 +35,71 @@ from src.utils.logging import configure_logging, get_logger
21
  configure_logging()
22
  logger = get_logger(__name__)
23
 
 
 
24
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
25
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
26
 
 
 
 
 
 
 
 
 
 
 
27
  _pipeline = None
28
 
29
 
30
  def get_pipeline():
 
31
  global _pipeline
32
- if _pipeline is None:
33
- checkpoint_path = Path("checkpoints/best.pt")
34
- if not checkpoint_path.exists():
35
- checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
36
- hf_hub_download(
37
- repo_id="OliverPerrin/LexiMind-Model",
38
- filename="best.pt",
39
- local_dir="checkpoints",
40
- local_dir_use_symlinks=False,
41
- )
42
- _pipeline, _ = create_inference_pipeline(
43
- tokenizer_dir="artifacts/hf_tokenizer/",
44
- checkpoint_path="checkpoints/best.pt",
45
- labels_path="artifacts/labels.json",
46
  )
 
 
 
 
 
 
47
  return _pipeline
48
 
49
 
 
 
 
50
  def analyze(text: str) -> str:
51
- """Run all three tasks and return results as formatted text."""
 
 
 
 
52
  if not text or not text.strip():
53
  return "Please enter some text to analyze."
54
 
55
  try:
56
  pipe = get_pipeline()
57
 
58
- # Summarization
59
  summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
60
-
61
- # Emotion detection
62
  emotions = pipe.predict_emotions([text], threshold=0.5)[0]
 
 
 
63
  if emotions.labels:
64
  emotion_str = ", ".join(
65
  f"{lbl} ({score:.1%})"
@@ -68,10 +108,6 @@ def analyze(text: str) -> str:
68
  else:
69
  emotion_str = "No strong emotions detected"
70
 
71
- # Topic classification
72
- topic = pipe.predict_topics([text])[0]
73
- topic_str = f"{topic.label} ({topic.confidence:.1%})"
74
-
75
  return f"""## Summary
76
  {summary}
77
 
@@ -79,7 +115,7 @@ def analyze(text: str) -> str:
79
  {emotion_str}
80
 
81
  ## Topic
82
- {topic_str}
83
  """
84
  except Exception as e:
85
  logger.error("Analysis failed: %s", e, exc_info=True)
@@ -87,7 +123,7 @@ def analyze(text: str) -> str:
87
 
88
 
89
  def get_metrics() -> str:
90
- """Load evaluation metrics as markdown."""
91
  if not EVAL_REPORT_PATH.exists():
92
  return "No evaluation report found. Run `scripts/evaluate.py` first."
93
 
@@ -95,6 +131,7 @@ def get_metrics() -> str:
95
  with open(EVAL_REPORT_PATH) as f:
96
  r = json.load(f)
97
 
 
98
  lines = [
99
  "## Model Performance\n",
100
  "| Task | Metric | Score |",
@@ -108,10 +145,13 @@ def get_metrics() -> str:
108
  "| Label | Precision | Recall | F1 |",
109
  "|-------|-----------|--------|-----|",
110
  ]
111
- for k, v in r["topic"]["classification_report"].items():
112
- if isinstance(v, dict) and "precision" in v:
 
 
113
  lines.append(
114
- f"| {k} | {v['precision']:.3f} | {v['recall']:.3f} | {v['f1-score']:.3f} |"
 
115
  )
116
 
117
  return "\n".join(lines)
@@ -119,15 +159,16 @@ def get_metrics() -> str:
119
  return f"Error loading metrics: {e}"
120
 
121
 
122
- SAMPLE = """Artificial intelligence is rapidly transforming technology. Machine learning algorithms process vast amounts of data, identifying patterns with unprecedented accuracy. From healthcare to finance, AI is revolutionizing industries worldwide. However, ethical considerations around privacy and bias remain critical challenges."""
123
 
124
  with gr.Blocks(title="LexiMind Demo") as demo:
125
  gr.Markdown(
126
- "# LexiMind NLP Demo\nMulti-task model: summarization, emotion detection, topic classification."
 
127
  )
128
 
129
  with gr.Tab("Analyze"):
130
- text_input = gr.Textbox(label="Input Text", lines=6, value=SAMPLE)
131
  analyze_btn = gr.Button("Analyze", variant="primary")
132
  output = gr.Markdown(label="Results")
133
  analyze_btn.click(fn=analyze, inputs=text_input, outputs=output)
@@ -135,6 +176,9 @@ with gr.Blocks(title="LexiMind Demo") as demo:
135
  with gr.Tab("Metrics"):
136
  gr.Markdown(get_metrics())
137
 
 
 
 
138
  if __name__ == "__main__":
139
- get_pipeline() # Pre-load
140
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ """
2
+ Gradio demo for LexiMind multi-task NLP model.
3
+
4
+ Provides a simple web interface for the three core tasks:
5
+ - Summarization: Generates concise summaries of input text
6
+ - Emotion Detection: Identifies emotional content with confidence scores
7
+ - Topic Classification: Categorizes text into predefined topics
8
+
9
+ Author: Oliver Perrin
10
+ Date: 2025-12-04
11
+ """
12
 
13
  from __future__ import annotations
14
 
 
18
 
19
  import gradio as gr
20
 
21
+ # --------------- Path Setup ---------------
22
+ # Ensure local src package is importable when running script directly
23
+
24
  SCRIPT_DIR = Path(__file__).resolve().parent
25
  PROJECT_ROOT = SCRIPT_DIR.parent
26
+
27
  if str(PROJECT_ROOT) not in sys.path:
28
  sys.path.insert(0, str(PROJECT_ROOT))
29
 
 
35
  configure_logging()
36
  logger = get_logger(__name__)
37
 
38
+ # --------------- Constants ---------------
39
+
40
  OUTPUTS_DIR = PROJECT_ROOT / "outputs"
41
  EVAL_REPORT_PATH = OUTPUTS_DIR / "evaluation_report.json"
42
 
43
+ SAMPLE_TEXT = (
44
+ "Artificial intelligence is rapidly transforming technology. "
45
+ "Machine learning algorithms process vast amounts of data, identifying "
46
+ "patterns with unprecedented accuracy. From healthcare to finance, AI is "
47
+ "revolutionizing industries worldwide. However, ethical considerations "
48
+ "around privacy and bias remain critical challenges."
49
+ )
50
+
51
+ # --------------- Pipeline Management ---------------
52
+
53
  _pipeline = None
54
 
55
 
56
def get_pipeline():
    """Lazy-load the inference pipeline, downloading the checkpoint if needed.

    Returns:
        The module-level cached pipeline. Created on first call; subsequent
        calls return the same object.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    checkpoint_path = Path("checkpoints/best.pt")

    # Download from HuggingFace Hub if checkpoint doesn't exist locally.
    if not checkpoint_path.exists():
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        # local_dir_use_symlinks is deprecated (a no-op since
        # huggingface_hub 0.23) and has been removed from this call.
        hf_hub_download(
            repo_id="OliverPerrin/LexiMind-Model",
            filename="best.pt",
            local_dir="checkpoints",
        )

    # NOTE(review): paths are relative to the current working directory,
    # not PROJECT_ROOT — the script assumes it is launched from the repo
    # root; confirm before anchoring these to PROJECT_ROOT.
    _pipeline, _ = create_inference_pipeline(
        tokenizer_dir="artifacts/hf_tokenizer/",
        checkpoint_path=str(checkpoint_path),
        labels_path="artifacts/labels.json",
    )
    return _pipeline
80
 
81
 
82
+ # --------------- Core Functions ---------------
83
+
84
+
85
  def analyze(text: str) -> str:
86
+ """
87
+ Run all three tasks on input text.
88
+
89
+ Returns markdown-formatted results for display in Gradio.
90
+ """
91
  if not text or not text.strip():
92
  return "Please enter some text to analyze."
93
 
94
  try:
95
  pipe = get_pipeline()
96
 
97
+ # Run each task
98
  summary = pipe.summarize([text], max_length=128)[0].strip() or "(empty)"
 
 
99
  emotions = pipe.predict_emotions([text], threshold=0.5)[0]
100
+ topic = pipe.predict_topics([text])[0]
101
+
102
+ # Format emotion results
103
  if emotions.labels:
104
  emotion_str = ", ".join(
105
  f"{lbl} ({score:.1%})"
 
108
  else:
109
  emotion_str = "No strong emotions detected"
110
 
 
 
 
 
111
  return f"""## Summary
112
  {summary}
113
 
 
115
  {emotion_str}
116
 
117
  ## Topic
118
+ {topic.label} ({topic.confidence:.1%})
119
  """
120
  except Exception as e:
121
  logger.error("Analysis failed: %s", e, exc_info=True)
 
123
 
124
 
125
  def get_metrics() -> str:
126
+ """Load evaluation metrics from JSON and format as markdown tables."""
127
  if not EVAL_REPORT_PATH.exists():
128
  return "No evaluation report found. Run `scripts/evaluate.py` first."
129
 
 
131
  with open(EVAL_REPORT_PATH) as f:
132
  r = json.load(f)
133
 
134
+ # Build overall metrics table
135
  lines = [
136
  "## Model Performance\n",
137
  "| Task | Metric | Score |",
 
145
  "| Label | Precision | Recall | F1 |",
146
  "|-------|-----------|--------|-----|",
147
  ]
148
+
149
+ # Add per-class metrics
150
+ for label, metrics in r["topic"]["classification_report"].items():
151
+ if isinstance(metrics, dict) and "precision" in metrics:
152
  lines.append(
153
+ f"| {label} | {metrics['precision']:.3f} | "
154
+ f"{metrics['recall']:.3f} | {metrics['f1-score']:.3f} |"
155
  )
156
 
157
  return "\n".join(lines)
 
159
  return f"Error loading metrics: {e}"
160
 
161
 
162
+ # --------------- Gradio Interface ---------------
163
 
164
  with gr.Blocks(title="LexiMind Demo") as demo:
165
  gr.Markdown(
166
+ "# LexiMind NLP Demo\n"
167
+ "Multi-task model: summarization, emotion detection, topic classification."
168
  )
169
 
170
  with gr.Tab("Analyze"):
171
+ text_input = gr.Textbox(label="Input Text", lines=6, value=SAMPLE_TEXT)
172
  analyze_btn = gr.Button("Analyze", variant="primary")
173
  output = gr.Markdown(label="Results")
174
  analyze_btn.click(fn=analyze, inputs=text_input, outputs=output)
 
176
  with gr.Tab("Metrics"):
177
  gr.Markdown(get_metrics())
178
 
179
+
180
+ # --------------- Entry Point ---------------
181
+
182
  if __name__ == "__main__":
183
+ get_pipeline() # Pre-load to fail fast if checkpoint missing
184
  demo.launch(server_name="0.0.0.0", server_port=7860)