# Indic Parler-TTS – ONNX
This is a verbatim ONNX export of ai4bharat/indic-parler-tts for in-browser inference via onnxruntime-web or transformers.js.
The upstream model is the multilingual Indic-language fine-tune of Parler-TTS (huggingface/parler-tts) by AI4Bharat. This repo re-exports the same weights as ONNX graphs so they can be loaded into a browser without a Python runtime.
## Files – fp32 + fp16 variants
Both precisions ship in this repo, following the onnx-community/dac_44khz-ONNX convention. The fp16 graphs use `keep_io_types=True`, so they are drop-in replacements for fp32: consumer code does not change tensor dtypes between sessions when switching variants.
| File | Size | Purpose |
|---|---|---|
| `text_encoder.onnx` + `.data` | 1.30 GB | fp32 – T5 (flan-t5-large) description encoder |
| `text_encoder_fp16.onnx` + `.data` | 651 MB | fp16 – same graph, fp32 I/O |
| `decoder_model.onnx` + `.data` | 1.98 GB | fp32 – first AR step (codec embed + prompt prefix + cross-attn + 24-layer decoder + fused 9-codebook lm-head) |
| `decoder_model_fp16.onnx` + `.data` | 992 MB | fp16 |
| `decoder_with_past_model.onnx` + `.data` | 1.25 GB | fp32 – subsequent AR steps with past KV cache |
| `decoder_with_past_model_fp16.onnx` + `.data` | 622 MB | fp16 |
| `tokenizer.json`, `tokenizer.model`, `tokenizer_config.json` | 12 MB | `LlamaTokenizerFast`, vocab 90,714 – used for both description and prompt |
| `config.json`, `generation_config.json`, `special_tokens_map.json` | small | Model + generation config |
Bundle totals:
- fp16: ~2.3 GB – recommended for browser consumers (halves the IDB cache and WebGPU memory)
- fp32: ~4.5 GB – sanity reference; use only for parity testing or non-WebGPU paths
The DAC 44.1 kHz vocoder (onnx-community/dac_44khz-ONNX, referenced separately – use its `decoder_model_fp16.onnx`, 104 MB) is needed to turn the decoder's codec tokens into an audio waveform.
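A minimal sketch of that final vocoder step, assuming the DAC decoder graph takes an `audio_codes` input (`[B, 9, T]`, int64), returns an `audio_values` waveform, and that the file sits under `onnx/` in that repo – check `dacSess.inputNames` / `outputNames` and the actual file layout before relying on these names:

```js
import * as ort from 'onnxruntime-web';

// Assumed location inside onnx-community/dac_44khz-ONNX; adjust if the repo layout differs.
const DAC_URL =
  'https://huggingface.co/onnx-community/dac_44khz-ONNX/resolve/main/onnx/decoder_model_fp16.onnx';

const dacSess = await ort.InferenceSession.create(
  new Uint8Array(await fetch(DAC_URL).then((r) => r.arrayBuffer())),
  { executionProviders: ['webgpu', 'wasm'] }
);

// `codes` = generated codec tokens, de-delayed and flattened to a BigInt64Array of length 9 * T.
const out = await dacSess.run({
  audio_codes: new ort.Tensor('int64', codes, [1, 9, T]), // assumed input name
});
const waveform = out.audio_values.data; // assumed output name; Float32Array at 44100 Hz
```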
## fp16 parity vs fp32
Conversion was done via `onnxruntime.transformers.float16.convert_float_to_float16(..., keep_io_types=True)` with the default ORT op block list. Verified on the CPU EP:
| Graph | Max abs diff |
|---|---|
| `text_encoder_fp16.onnx` | 8.6×10⁻³ |
| `decoder_model_fp16.onnx` | 1.8×10⁻² |
| `decoder_with_past_model_fp16.onnx` | (sanity only – finite outputs across a full step) |
All within typical fp16 transformer tolerances (the gate is ≤ 5×10⁻²). Audio quality is indistinguishable in practice for this codec-token TTS pipeline.
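If you want to reproduce the gate in the browser, a minimal comparison helper looks like this (`out32` / `out16` are assumed to be the fp32 and fp16 session outputs for identical feeds):

```js
// Compare two output buffers element-wise and return the largest absolute difference.
function maxAbsDiff(a, b) {
  let max = 0;
  for (let i = 0; i < a.length; i++) {
    const d = Math.abs(a[i] - b[i]);
    if (d > max) max = d;
  }
  return max;
}

// e.g. maxAbsDiff(out32.last_hidden_state.data, out16.last_hidden_state.data) <= 5e-2
```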
## Loading by dtype
```js
import * as ort from 'onnxruntime-web';

// Browser: prefer fp16 unless you need bit-exact PyTorch reproduction.
const url = (file) =>
  `https://huggingface.co/naklitechie/indic-parler-tts-ONNX/resolve/main/${file}`;

const dtype = 'fp16'; // or 'fp32'
const suffix = dtype === 'fp16' ? '_fp16' : '';

// The weights live in the external .data file, so the graph and its data
// must be fetched and handed to ORT together.
const fetchBytes = async (file) =>
  new Uint8Array(await fetch(url(file)).then((r) => r.arrayBuffer()));

const sess = await ort.InferenceSession.create(
  await fetchBytes(`text_encoder${suffix}.onnx`),
  {
    externalData: [{
      data: await fetchBytes(`text_encoder${suffix}.onnx.data`),
      path: `text_encoder${suffix}.onnx.data`,
    }],
    executionProviders: ['webgpu', 'wasm'],
  }
);
```
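The two decoder graphs load the same way. A small wrapper (a hypothetical `loadSession` helper, building on `fetchBytes` and `suffix` from the snippet above) avoids repeating the externalData boilerplate:

```js
// Hypothetical convenience wrapper around the loading pattern above.
const loadSession = async (name) => {
  const model = await fetchBytes(`${name}${suffix}.onnx`);
  const data = await fetchBytes(`${name}${suffix}.onnx.data`);
  return ort.InferenceSession.create(model, {
    externalData: [{ data, path: `${name}${suffix}.onnx.data` }],
    executionProviders: ['webgpu', 'wasm'],
  });
};

const sessText = await loadSession('text_encoder');
const sessDecNoPast = await loadSession('decoder_model');
const sessDecWithPast = await loadSession('decoder_with_past_model');
```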
## Architecture facts (consumer cheat sheet)
- Single tokenizer for description + prompt (the upstream README's mention of a `prompt_tokenizer/` subfolder is a documentation artifact – both inputs use the same `LlamaTokenizerFast`, vocab 90,714).
- `prompt_cross_attention=False` – prompt embeddings are prefixed to the decoder's input embeddings, not cross-attended. The decoder's self-attention KV cache region therefore covers `(prompt_len + codec_len)` positions after step 1.
- Decoder: 24 layers × 16 heads × 64 head_dim (1024 hidden), 9 codebooks, vocab 1088 per codebook, max position 4096.
- Fused lm_head: a single `Linear: 1024 → 9 × 1088` projection (not 9 separate heads); reshape + transpose happen inside the graph (a greedy-sampling sketch over this layout follows the list).
- PAD/BOS/EOS/START: 1024 / 1025 / 1024 / 1025.
- Sample rate: 44100 Hz from DAC; ~10 ms per codec token end-to-end.
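A minimal greedy-sampling sketch over the fused head's output, assuming batch size 1 (this is the `argmaxLogits` helper used in the end-to-end sketch further down; treat it as illustrative, not canonical):

```js
// `logits` is the ort.Tensor with dims [9, T, 1088] (9 * B with B = 1).
// Take the last time step of each codebook and argmax over the 1088-entry vocab.
function argmaxLogits(logits) {
  const [cb, T, V] = logits.dims;
  const data = logits.data; // Float32Array
  const next = new BigInt64Array(cb);
  for (let k = 0; k < cb; k++) {
    const off = (k * T + (T - 1)) * V; // offset of codebook k's last position
    let best = 0;
    let bestVal = -Infinity;
    for (let v = 0; v < V; v++) {
      if (data[off + v] > bestVal) {
        bestVal = data[off + v];
        best = v;
      }
    }
    next[k] = BigInt(best);
  }
  return next; // one token id per codebook; wrap as [1, 9, 1] before feeding
}
```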
## I/O signatures
### `text_encoder.onnx`
| Input | Shape | Type |
|---|---|---|
| `input_ids` | [B, S] | int64 |
| `attention_mask` | [B, S] | int64 |

| Output | Shape | Type |
|---|---|---|
| `last_hidden_state` | [B, S, 1024] | float32 |
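A sketch of producing these feeds with transformers.js (assuming the tokenizer files in this repo load via `AutoTokenizer`; adapt if you tokenize elsewhere):

```js
import { AutoTokenizer } from '@huggingface/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('naklitechie/indic-parler-tts-ONNX');
const description = 'A female speaker delivers speech at a moderate pace with a clear voice.';
const enc = await tokenizer(description); // { input_ids, attention_mask } as transformers.js tensors

// Convert to onnxruntime-web tensors ([B, S], int64).
const input_ids = new ort.Tensor('int64', enc.input_ids.data, enc.input_ids.dims);
const attention_mask = new ort.Tensor('int64', enc.attention_mask.data, enc.attention_mask.dims);

const { last_hidden_state } = await sessText.run({ input_ids, attention_mask });
```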
### `decoder_model.onnx` (first AR step)
| Input | Shape | Type |
|---|---|---|
| `codec_input_ids` | [B, 9, T] | int64 |
| `prompt_input_ids` | [B, P] | int64 |
| `prompt_attention_mask` | [B, P] | int64 |
| `encoder_hidden_states` | [B, S, 1024] | float32 |
| `encoder_attention_mask` | [B, S] | int64 |

| Output | Shape | Type |
|---|---|---|
| `logits` | [9*B, P+T, 1088] | float32 |
| `present.{i}.{decoder.key, decoder.value, encoder.key, encoder.value}` × 24 layers | [B, 16, P+T or S, 64] | float32 |
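For a from-scratch generation with batch size 1, the first-step codec input is one START token per codebook. A minimal feed sketch (the prompt tensors are assumed to come from the shared tokenizer applied to the transcript):

```js
// Start token 1025 for each of the 9 codebooks, shape [1, 9, 1].
const codec_input_ids = new ort.Tensor(
  'int64',
  new BigInt64Array(9).fill(1025n),
  [1, 9, 1]
);

const firstStep = await sessDecNoPast.run({
  codec_input_ids,
  prompt_input_ids,                          // [1, P] int64, transcript through the shared tokenizer
  prompt_attention_mask,                     // [1, P] int64
  encoder_hidden_states: last_hidden_state,  // [1, S, 1024] from text_encoder
  encoder_attention_mask: attention_mask,    // [1, S] int64, same mask as the text encoder
});
// firstStep.logits has dims [9, P + 1, 1088]
```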
### `decoder_with_past_model.onnx` (subsequent AR steps)
| Input | Shape | Type |
|---|---|---|
| `codec_input_ids` | [B, 9, 1] | int64 |
| `attention_mask` | [B, full_t] | int64 – covers prompt + prior codec + new token |
| `encoder_attention_mask` | [B, S] | int64 |
| `cache_position` | [1] | int64 – past_kv_length scalar |
| `past_key_values.{i}.{decoder.key, decoder.value, encoder.key, encoder.value}` × 24 layers | as in `present` above | float32 |

| Output | Shape | Type |
|---|---|---|
| `logits` | [9*B, 1, 1088] | float32 |
| `present.{i}.{...}` × 24 layers | shifted by 1 in the self-attn time dim | float32 |
KV layout per layer: `(decoder_self.key, decoder_self.value, encoder_cross.key, encoder_cross.value)` – IT2-compatible.
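Between steps, the `present.*` outputs have to be re-fed under the `past_key_values.*` input names. A minimal renaming helper, assuming the name patterns in the tables above (verify against `sessDecWithPast.inputNames`):

```js
// Map first-step / previous-step outputs onto the with-past graph's input names.
function presentToPast(outputs) {
  const feeds = {};
  for (const [name, tensor] of Object.entries(outputs)) {
    if (name.startsWith('present.')) {
      feeds[name.replace(/^present\./, 'past_key_values.')] = tensor;
    }
  }
  return feeds; // 24 layers × 4 tensors = 96 entries
}
```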
## End-to-end inference (sketch)
```js
// 1. Encode the description
const encOut = await sessText.run({ input_ids, attention_mask });
const enc_h = encOut.last_hidden_state;

// 2. Build the initial codec ids ([B, 9, 1], all START_ID = 1025) and the
//    delay-pattern mask (JS port of parler_tts.build_delay_pattern_mask)
const { initialIds, patternMask } = buildDelayPatternMask({
  bos: 1025, pad: 1024, maxLen: 256, numCodebooks: 9,
});

// 3. First AR step (no past KV)
let { logits, ...present } = await sessDecNoPast.run({
  codec_input_ids: new ort.Tensor('int64', initialIds, [1, 9, 1]), // [B, 9, T_init]
  prompt_input_ids,
  prompt_attention_mask,
  encoder_hidden_states: enc_h,
  encoder_attention_mask: attention_mask,
});

// 4. Greedy AR loop – apply the delay pattern between steps
for (let step = 1; step < maxSteps; step++) {
  const next = applyDelayPatternMask(argmaxLogits(logits), patternMask, step);
  ({ logits, ...present } = await sessDecWithPast.run({
    codec_input_ids: new ort.Tensor('int64', next, [1, 9, 1]),
    attention_mask: buildFullAttnMask(promptLen + step),          // prompt + prior codec + new
    encoder_attention_mask,
    cache_position: new ort.Tensor('int64', BigInt64Array.from([BigInt(promptLen + step)]), [1]),
    ...presentToPast(present), // rename present.* -> past_key_values.* and re-feed all 96 KV tensors
  }));
}

// 5. Decode codec tokens into audio with the DAC vocoder
//    (onnx-community/dac_44khz-ONNX decoder, see above)
```
The `apply_delay_pattern_mask` and `build_delay_pattern_mask` functions need porting to JS; reference the Python source in `parler_tts/modeling_parler_tts.py`. A minimal sketch of the from-scratch case follows.
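The sketch below assumes batch size 1 and generation from scratch (no audio prompt); `-1` marks positions that are free to generate, everything else is forced to BOS/PAD by the pattern. Verify the exact semantics against the Python reference before relying on it.

```js
const FREE = -1n;

function buildDelayPatternMask({ bos, pad, maxLen, numCodebooks }) {
  // patternMask[k * maxLen + t] = forced token at (codebook k, position t), or FREE.
  const patternMask = new BigInt64Array(numCodebooks * maxLen).fill(FREE);
  for (let k = 0; k < numCodebooks; k++) {
    for (let t = 0; t < maxLen; t++) {
      if (t <= k) {
        patternMask[k * maxLen + t] = BigInt(bos);      // leading BOS, delayed per codebook
      } else if (t >= maxLen - (numCodebooks - 1) + k) {
        patternMask[k * maxLen + t] = BigInt(pad);      // trailing PAD once the codebook has finished
      }
    }
  }
  // From scratch, generation starts right after position 0, so the initial ids
  // are one BOS per codebook ([1, 9, 1] once wrapped in an ort.Tensor).
  const initialIds = new BigInt64Array(numCodebooks).fill(BigInt(bos));
  return { initialIds, patternMask };
}

function applyDelayPatternMask(sampledIds, patternMask, step, maxLen = 256, numCodebooks = 9) {
  // Keep the sampled token at position `step` only where the pattern is FREE;
  // otherwise take the forced BOS/PAD value.
  const out = new BigInt64Array(numCodebooks);
  for (let k = 0; k < numCodebooks; k++) {
    const forced = patternMask[k * maxLen + step];
    out[k] = forced === FREE ? sampledIds[k] : forced;
  }
  return out;
}
```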
## Validated parity vs PyTorch
- Text encoder: max abs diff 4.72e-6
- Decoder step 1 (no-past): max abs diff 1.91e-5
- Decoder step 2 (with-past): max abs diff 1.68e-4
All well below the 1e-3 tolerance typically considered sufficient for fp32 exports.
## Citation
If you use this in research or production, please cite the original AI4Bharat work:
```bibtex
@misc{indic-parler-tts,
  title  = {Indic Parler-TTS},
  author = {AI4Bharat},
  year   = {2024},
  url    = {https://huggingface.co/ai4bharat/indic-parler-tts}
}
```
And the upstream Parler-TTS:
```bibtex
@misc{lacombe-etal-2024-parler-tts,
  author       = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
  title        = {Parler-TTS},
  year         = {2024},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/parler-tts}}
}
```
## License
Apache 2.0 – same as the upstream ai4bharat/indic-parler-tts. See NOTICE.md for attribution.