cbensimon HF Staff commited on
Commit
56345c1
·
verified ·
1 Parent(s): a2009bc

Simpler initial CUDA loading

Browse files
Files changed (1) hide show
  1. app.py +2 -20
app.py CHANGED
@@ -6,8 +6,8 @@ import gradio as gr
6
  import spaces
7
  from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
- MODEL = None
10
- FIRST_RUN = True
11
 
12
  EVENT_TAGS = [
13
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
@@ -72,12 +72,6 @@ def set_seed(seed: int):
72
  random.seed(seed)
73
  np.random.seed(seed)
74
 
75
-
76
- def load_model():
77
- global MODEL
78
- MODEL = ChatterboxTurboTTS.from_pretrained("cpu")
79
- return MODEL
80
-
81
  @spaces.GPU
82
  def generate(
83
  text,
@@ -90,15 +84,6 @@ def generate(
90
  repetition_penalty,
91
  norm_loudness
92
  ):
93
- global MODEL, FIRST_RUN
94
- # Reload if the worker lost the global state
95
- if MODEL is None:
96
- MODEL = load_model()
97
- MODEL.to("cuda")
98
- if FIRST_RUN:
99
- FIRST_RUN = False
100
- MODEL.to("cuda")
101
-
102
  if seed_num != 0:
103
  set_seed(int(seed_num))
104
 
@@ -159,9 +144,6 @@ with gr.Blocks(title="Chatterbox Turbo") as demo:
159
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
160
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
161
 
162
- # Load on startup (CPU)
163
- demo.load(fn=load_model, inputs=[], outputs=[])
164
-
165
  run_btn.click(
166
  fn=generate,
167
  inputs=[
 
6
  import spaces
7
  from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
+
10
+ MODEL = ChatterboxTurboTTS.from_pretrained("cuda" )
11
 
12
  EVENT_TAGS = [
13
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
 
72
  random.seed(seed)
73
  np.random.seed(seed)
74
 
 
 
 
 
 
 
75
  @spaces.GPU
76
  def generate(
77
  text,
 
84
  repetition_penalty,
85
  norm_loudness
86
  ):
 
 
 
 
 
 
 
 
 
87
  if seed_num != 0:
88
  set_seed(int(seed_num))
89
 
 
144
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
145
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
146
 
 
 
 
147
  run_btn.click(
148
  fn=generate,
149
  inputs=[