ollieollie committed on
Commit
3ce1983
·
verified ·
1 Parent(s): 831dd4e

Update chatterbox/tts_turbo.py

Browse files
Files changed (1) hide show
  1. chatterbox/tts_turbo.py +153 -296
chatterbox/tts_turbo.py CHANGED
@@ -1,305 +1,162 @@
 
1
  import os
2
- import math
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
-
6
- import librosa
7
  import torch
8
- import perth
9
- import pyloudnorm as ln
10
-
11
- from safetensors.torch import load_file
12
- from huggingface_hub import snapshot_download
13
- from transformers import AutoTokenizer
14
-
15
- from .models.t3 import T3
16
- from .models.s3tokenizer import S3_SR
17
- from .models.s3gen import S3GEN_SR, S3Gen
18
- from .models.tokenizers import EnTokenizer
19
- from .models.voice_encoder import VoiceEncoder
20
- from .models.t3.modules.cond_enc import T3Cond
21
- from .models.t3.modules.t3_config import T3Config
22
- from .models.s3gen.const import S3GEN_SIL
23
- import logging
24
- logger = logging.getLogger(__name__)
25
-
26
- REPO_ID = "ResembleAI/chatterbox-turbo"
27
-
28
-
29
- def punc_norm(text: str) -> str:
30
- """
31
- Quick cleanup func for punctuation from LLMs or
32
- containing chars not seen often in the dataset
33
- """
34
- if len(text) == 0:
35
- return "You need to add some text for me to talk."
36
-
37
- # Capitalise first letter
38
- if text[0].islower():
39
- text = text[0].upper() + text[1:]
40
-
41
- # Remove multiple space chars
42
- text = " ".join(text.split())
43
-
44
- # Replace uncommon/llm punc
45
- punc_to_replace = [
46
- ("…", ", "),
47
- (":", ","),
48
- ("—", "-"),
49
- ("–", "-"),
50
- (" ,", ","),
51
- ("“", "\""),
52
- ("”", "\""),
53
- ("‘", "'"),
54
- ("’", "'"),
55
- ]
56
- for old_char_sequence, new_char in punc_to_replace:
57
- text = text.replace(old_char_sequence, new_char)
58
-
59
- # Add full stop if no ending punc
60
- text = text.rstrip(" ")
61
- sentence_enders = {".", "!", "?", "-", ","}
62
- if not any(text.endswith(p) for p in sentence_enders):
63
- text += "."
64
-
65
- return text
66
-
67
-
68
- @dataclass
69
- class Conditionals:
70
- """
71
- Conditionals for T3 and S3Gen
72
- - T3 conditionals:
73
- - speaker_emb
74
- - clap_emb
75
- - cond_prompt_speech_tokens
76
- - cond_prompt_speech_emb
77
- - emotion_adv
78
- - S3Gen conditionals:
79
- - prompt_token
80
- - prompt_token_len
81
- - prompt_feat
82
- - prompt_feat_len
83
- - embedding
84
- """
85
- t3: T3Cond
86
- gen: dict
87
-
88
- def to(self, device):
89
- self.t3 = self.t3.to(device=device)
90
- for k, v in self.gen.items():
91
- if torch.is_tensor(v):
92
- self.gen[k] = v.to(device=device)
93
- return self
94
-
95
- def save(self, fpath: Path):
96
- arg_dict = dict(
97
- t3=self.t3.__dict__,
98
- gen=self.gen
99
- )
100
- torch.save(arg_dict, fpath)
101
-
102
- @classmethod
103
- def load(cls, fpath, map_location="cpu"):
104
- if isinstance(map_location, str):
105
- map_location = torch.device(map_location)
106
- kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
107
- return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
108
-
109
-
110
- class ChatterboxTurboTTS:
111
- ENC_COND_LEN = 15 * S3_SR
112
- DEC_COND_LEN = 10 * S3GEN_SR
113
-
114
- def __init__(
115
- self,
116
- t3: T3,
117
- s3gen: S3Gen,
118
- ve: VoiceEncoder,
119
- tokenizer: EnTokenizer,
120
- device: str,
121
- conds: Conditionals = None,
122
- ):
123
- self.sr = S3GEN_SR # sample rate of synthesized audio
124
- self.t3 = t3
125
- self.s3gen = s3gen
126
- self.ve = ve
127
- self.tokenizer = tokenizer
128
- self.device = device
129
- self.conds = conds
130
- self.watermarker = perth.PerthImplicitWatermarker()
131
-
132
- def to(self, device):
133
- self.device = device
134
- self.t3 = self.t3.to(device)
135
- self.s3gen = self.s3gen.to(device)
136
- self.ve = self.ve.to(device)
137
- if self.conds is not None:
138
- self.conds = self.conds.to(device)
139
- return self
140
 
141
- @classmethod
142
- def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
143
- ckpt_dir = Path(ckpt_dir)
144
-
145
- # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
146
- if device in ["cpu", "mps"]:
147
- map_location = torch.device('cpu')
148
- else:
149
- map_location = None
150
-
151
- ve = VoiceEncoder()
152
- ve.load_state_dict(
153
- load_file(ckpt_dir / "ve.safetensors")
154
- )
155
- ve.to(device).eval()
156
-
157
- # Turbo specific hp
158
- hp = T3Config(text_tokens_dict_size=50276)
159
- hp.llama_config_name = "GPT2_medium"
160
- hp.speech_tokens_dict_size = 6563
161
- hp.input_pos_emb = None
162
- hp.speech_cond_prompt_len = 375
163
- hp.use_perceiver_resampler = False
164
- hp.emotion_adv = False
165
-
166
- t3 = T3(hp)
167
- t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
168
- if "model" in t3_state.keys():
169
- t3_state = t3_state["model"][0]
170
- t3.load_state_dict(t3_state)
171
- del t3.tfmr.wte
172
- t3.to(device).eval()
173
-
174
- s3gen = S3Gen(meanflow=True)
175
- weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
176
- s3gen.load_state_dict(
177
- weights, strict=True
178
- )
179
- s3gen.to(device).eval()
180
-
181
- tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
182
- if tokenizer.pad_token is None:
183
- tokenizer.pad_token = tokenizer.eos_token
184
- if len(tokenizer) != 50276:
185
- print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")
186
-
187
- conds = None
188
- builtin_voice = ckpt_dir / "conds.pt"
189
- if builtin_voice.exists():
190
- conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
191
-
192
- return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
193
-
194
- @classmethod
195
- def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
196
- # Check if MPS is available on macOS
197
- if device == "mps" and not torch.backends.mps.is_available():
198
- if not torch.backends.mps.is_built():
199
- print("MPS not available because the current PyTorch install was not built with MPS enabled.")
200
- else:
201
- print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
202
- device = "cpu"
203
-
204
- local_path = snapshot_download(
205
- repo_id=REPO_ID,
206
- token=os.getenv("HF_TOKEN") or True,
207
- # Optional: Filter to download only what you need
208
- allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
209
- )
210
-
211
- return cls.from_local(local_path, device)
212
-
213
- def norm_loudness(self, wav, sr, target_lufs=-27):
214
- try:
215
- meter = ln.Meter(sr)
216
- loudness = meter.integrated_loudness(wav)
217
- gain_db = target_lufs - loudness
218
- gain_linear = 10.0 ** (gain_db / 20.0)
219
- if math.isfinite(gain_linear) and gain_linear > 0.0:
220
- wav = wav * gain_linear
221
- except Exception as e:
222
- print(f"Warning: Error in norm_loudness, skipping: {e}")
223
-
224
- return wav
225
-
226
- def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
227
- ## Load and norm reference wav
228
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
229
-
230
- assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"
231
-
232
- if norm_loudness:
233
- s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)
234
-
235
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
236
-
237
- s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
238
- s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
239
-
240
- # Speech cond prompt tokens
241
- if plen := self.t3.hp.speech_cond_prompt_len:
242
- s3_tokzr = self.s3gen.tokenizer
243
- t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
244
- t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
245
-
246
- # Voice-encoder speaker embedding
247
- ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
248
- ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
249
-
250
- t3_cond = T3Cond(
251
- speaker_emb=ve_embed,
252
- cond_prompt_speech_tokens=t3_cond_prompt_tokens,
253
- emotion_adv=exaggeration * torch.ones(1, 1, 1),
254
- ).to(device=self.device)
255
- self.conds = Conditionals(t3_cond, s3gen_ref_dict)
256
 
257
- def generate(
258
- self,
259
- text,
260
- repetition_penalty=1.2,
261
- min_p=0.00,
262
- top_p=0.95,
263
- audio_prompt_path=None,
264
- exaggeration=0.0,
265
- cfg_weight=0.0,
266
- temperature=0.8,
267
- top_k=1000,
268
- norm_loudness=True,
269
- ):
270
- if audio_prompt_path:
271
- self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
272
- else:
273
- assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
274
 
275
- if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
276
- logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")
 
277
 
278
- # Norm and tokenize text
279
- text = punc_norm(text)
280
- text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
281
- text_tokens = text_tokens.input_ids.to(self.device)
 
 
282
 
283
- speech_tokens = self.t3.inference_turbo(
284
- t3_cond=self.conds.t3,
285
- text_tokens=text_tokens,
286
- temperature=temperature,
287
- top_k=top_k,
288
- top_p=top_p,
289
- repetition_penalty=repetition_penalty,
290
- )
291
 
292
- # Remove OOV tokens and add silence to end
293
- speech_tokens = speech_tokens[speech_tokens < 6561]
294
- speech_tokens = speech_tokens.to(self.device)
295
- silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
296
- speech_tokens = torch.cat([speech_tokens, silence])
297
 
298
- wav, _ = self.s3gen.inference(
299
- speech_tokens=speech_tokens,
300
- ref_dict=self.conds.gen,
301
- n_cfm_timesteps=2,
302
- )
303
- wav = wav.squeeze(0).detach().cpu().numpy()
304
- watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
305
- return torch.from_numpy(watermarked_wav).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
  import os
3
+ import numpy as np
 
 
 
 
4
  import torch
5
+ import gradio as gr
6
+ import spaces
7
+ from chatterbox.tts_turbo import ChatterboxTurboTTS
8
+
9
+ # --- 1. FORCE CPU FOR GLOBAL LOADING ---
10
+ # ZeroGPU forbids CUDA during startup. We only move to CUDA inside the decorated function.
11
+ DEVICE = "cpu"
12
+
13
+ MODEL = None
14
+
15
+ EVENT_TAGS = [
16
+ "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
17
+ "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
18
+ ]
19
+
20
+ CUSTOM_CSS = """
21
+ .tag-container {
22
+ display: flex !important;
23
+ flex-wrap: wrap !important;
24
+ gap: 8px !important;
25
+ margin-top: 5px !important;
26
+ margin-bottom: 10px !important;
27
+ border: none !important;
28
+ background: transparent !important;
29
+ }
30
+
31
+ .tag-btn {
32
+ min-width: fit-content !important;
33
+ width: auto !important;
34
+ height: 32px !important;
35
+ font-size: 13px !important;
36
+ background: #eef2ff !important;
37
+ border: 1px solid #c7d2fe !important;
38
+ color: #3730a3 !important;
39
+ border-radius: 6px !important;
40
+ padding: 0 10px !important;
41
+ margin: 0 !important;
42
+ box-shadow: none !important;
43
+ }
44
+
45
+ .tag-btn:hover {
46
+ background: #c7d2fe !important;
47
+ transform: translateY(-1px);
48
+ }
49
+ """
50
+
51
+ INSERT_TAG_JS = """
52
+ (tag_val, current_text) => {
53
+ const textarea = document.querySelector('#main_textbox textarea');
54
+ if (!textarea) return current_text + " " + tag_val;
55
+
56
+ const start = textarea.selectionStart;
57
+ const end = textarea.selectionEnd;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ let prefix = " ";
60
+ let suffix = " ";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ if (start === 0) prefix = "";
63
+ else if (current_text[start - 1] === ' ') prefix = "";
64
+
65
+ if (end < current_text.length && current_text[end] === ' ') suffix = "";
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
68
+ }
69
+ """
70
 
71
+ def set_seed(seed: int):
72
+ torch.manual_seed(seed)
73
+ torch.cuda.manual_seed(seed)
74
+ torch.cuda.manual_seed_all(seed)
75
+ random.seed(seed)
76
+ np.random.seed(seed)
77
 
 
 
 
 
 
 
 
 
78
 
79
+ def load_model():
80
+ global MODEL
81
+ print(f"Loading Chatterbox-Turbo on {DEVICE}...")
82
+ MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
83
+ return MODEL
84
 
85
+ @spaces.GPU
86
+ def generate(
87
+ text,
88
+ audio_prompt_path,
89
+ temperature,
90
+ seed_num,
91
+ min_p,
92
+ top_p,
93
+ top_k,
94
+ repetition_penalty,
95
+ norm_loudness
96
+ ):
97
+ global MODEL
98
+ # Reload if the worker lost the global state
99
+ if MODEL is None:
100
+ MODEL = ChatterboxTurboTTS.from_pretrained("cpu")
101
+
102
+ # --- MOVE TO GPU HERE ---
103
+ MODEL.to("cuda")
104
+
105
+ if seed_num != 0:
106
+ set_seed(int(seed_num))
107
+
108
+ wav = MODEL.generate(
109
+ text,
110
+ audio_prompt_path=audio_prompt_path,
111
+ temperature=temperature,
112
+ min_p=min_p,
113
+ top_p=top_p,
114
+ top_k=int(top_k),
115
+ repetition_penalty=repetition_penalty,
116
+ norm_loudness=norm_loudness,
117
+ )
118
+
119
+ return (MODEL.sr, wav.squeeze(0).cpu().numpy())
120
+
121
+
122
+ with gr.Blocks(title="Chatterbox Turbo") as demo:
123
+ gr.Markdown("# ⚡ Chatterbox Turbo")
124
+
125
+ with gr.Row():
126
+ with gr.Column():
127
+ text = gr.Textbox(
128
+ value="Congratulations Miss Connor! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
129
+ label="Text to synthesize (max chars 300)",
130
+ max_lines=5,
131
+ elem_id="main_textbox"
132
+ )
133
+
134
+ with gr.Row(elem_classes=["tag-container"]):
135
+ for tag in EVENT_TAGS:
136
+ btn = gr.Button(tag, elem_classes=["tag-btn"])
137
+ btn.click(
138
+ fn=None,
139
+ inputs=[btn, text],
140
+ outputs=text,
141
+ js=INSERT_TAG_JS
142
+ )
143
+
144
+ ref_wav = gr.Audio(
145
+ sources=["upload", "microphone"],
146
+ type="filepath",
147
+ label="Reference Audio File",
148
+ value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
149
+ )
150
+
151
+ run_btn = gr.Button("Generate ⚡", variant="primary")
152
+
153
+ with gr.Column():
154
+ audio_output = gr.Audio(label="Output Audio")
155
+
156
+ with gr.Accordion("Advanced Options", open=False):
157
+ seed_num = gr.Number(value=0, label="Random seed (0 for random)")
158
+ temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
159
+ top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
160
+ top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
161
+ repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
162
+ min_p = gr.Slider(0.00, 1.00, step=0