# Indic Parler-TTS (ONNX)

This is a verbatim ONNX export of ai4bharat/indic-parler-tts for in-browser inference via onnxruntime-web or transformers.js.

The upstream model is the multilingual Indic-language fine-tune of Parler-TTS (huggingface/parler-tts) by AI4Bharat. This repo re-exports the same weights as ONNX graphs so they can be loaded into a browser without a Python runtime.

## Files: fp32 + fp16 variants

Both precisions ship in this repo, following the onnx-community/dac_44khz-ONNX convention. The fp16 graphs use `keep_io_types=True`, so they are drop-in replacements for fp32: consumer code doesn't change tensor dtypes between sessions when switching variants.

| File | Size | Purpose |
|---|---|---|
| `text_encoder.onnx` + `.data` | 1.30 GB | fp32; T5 (flan-t5-large) description encoder |
| `text_encoder_fp16.onnx` + `.data` | 651 MB | fp16; same graph, fp32 I/O |
| `decoder_model.onnx` + `.data` | 1.98 GB | fp32; first AR step (codec embed + prompt prefix + cross-attn + 24-layer decoder + fused 9-codebook lm-head) |
| `decoder_model_fp16.onnx` + `.data` | 992 MB | fp16 |
| `decoder_with_past_model.onnx` + `.data` | 1.25 GB | fp32; subsequent AR steps with past KV cache |
| `decoder_with_past_model_fp16.onnx` + `.data` | 622 MB | fp16 |
| `tokenizer.json`, `tokenizer.model`, `tokenizer_config.json` | 12 MB | LlamaTokenizerFast, vocab 90,714; used for both description AND prompt |
| `config.json`, `generation_config.json`, `special_tokens_map.json` | small | Model + generation config |

Bundle totals:

- fp16: ~2.3 GB (recommended for browser consumers; halves the IndexedDB cache and WebGPU memory)
- fp32: ~4.5 GB (sanity reference; use only for parity testing or non-WebGPU paths)

The DAC 44.1 kHz vocoder (onnx-community/dac_44khz-ONNX, referenced separately; pick its `decoder_model_fp16.onnx` at 104 MB) is needed to turn the decoder's codec tokens into an audio waveform.

## fp16 parity vs fp32

Converted via `onnxruntime.transformers.float16.convert_float_to_float16(..., keep_io_types=True)` with the default ORT op block list. Verified on the CPU EP:

| Graph | Max abs diff |
|---|---|
| `text_encoder_fp16.onnx` | 8.6×10⁻³ |
| `decoder_model_fp16.onnx` | 1.8×10⁻² |
| `decoder_with_past_model_fp16.onnx` | sanity only (finite outputs across a full step) |

All within typical fp16 transformer tolerances (≤ 5×10⁻² is the gate). Audio quality is indistinguishable in practice for this codec-token TTS pipeline.
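
If you want to reproduce the check in the browser, a minimal sketch, assuming fp32 and fp16 text-encoder sessions created as in the loading snippet below (`maxAbsDiff` is an illustrative helper, not part of this repo):

```js
// Compare fp32 vs fp16 text-encoder outputs on identical feeds.
// Both graphs emit float32 (keep_io_types=True), so .data is a
// Float32Array in both cases.
async function maxAbsDiff(sessFp32, sessFp16, feeds) {
  const a = (await sessFp32.run(feeds)).last_hidden_state.data;
  const b = (await sessFp16.run(feeds)).last_hidden_state.data;
  let max = 0;
  for (let i = 0; i < a.length; i++) {
    max = Math.max(max, Math.abs(a[i] - b[i]));
  }
  return max; // expect well under the 5e-2 gate above
}
```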

## Loading by dtype

```js
// Browser: prefer fp16 unless you need bit-exact PyTorch reproduction
const url = (file) =>
  `https://huggingface.co/naklitechie/indic-parler-tts-ONNX/resolve/main/${file}`;
const buf = async (file) =>
  new Uint8Array(await (await fetch(url(file))).arrayBuffer());

const dtype = 'fp16';   // or 'fp32'
const suffix = dtype === 'fp16' ? '_fp16' : '';
const sess = await ort.InferenceSession.create(
  await buf(`text_encoder${suffix}.onnx`),
  {
    externalData: [{
      data: await buf(`text_encoder${suffix}.onnx.data`),
      path: `text_encoder${suffix}.onnx.data`,
    }],
    executionProviders: ['webgpu', 'wasm'],
  }
);
```
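
At these file sizes, re-downloading on every page load is impractical. One option is to persist the fetched buffers with the standard Cache API; a minimal sketch (the cache name and helper are illustrative):

```js
// Fetch a model file through the Cache API so multi-GB downloads
// survive page reloads; 'indic-parler-tts' is an arbitrary cache name.
async function cachedBuf(fileUrl) {
  const cache = await caches.open('indic-parler-tts');
  let res = await cache.match(fileUrl);
  if (!res) {
    res = await fetch(fileUrl);
    await cache.put(fileUrl, res.clone());
  }
  return new Uint8Array(await res.arrayBuffer());
}
```

For bundles this large, also consider `navigator.storage.persist()` so the browser treats the cache as persistent rather than best-effort.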

## Architecture facts (consumer cheat sheet)

- Single tokenizer for description + prompt (the upstream README's mention of a `prompt_tokenizer/` subfolder is a documentation artifact; both inputs use the same LlamaTokenizerFast, vocab 90,714).
- `prompt_cross_attention=False`: prompt embeddings are prefixed to the decoder's input embeddings, not cross-attended. The decoder's self-attention KV-cache region therefore covers `prompt_len + codec_len` positions after step 1.
- Decoder: 24 layers × 16 heads × 64 head_dim (1024 hidden), 9 codebooks, vocab 1088 per codebook, max position 4096.
- Fused lm_head: a single `Linear: 1024 → 9 × 1088` projection (not 9 separate heads), with the reshape + transpose inside the graph; see the logits sketch after this list.
- PAD/BOS/EOS/START: 1024 / 1025 / 1024 / 1025.
- Sample rate: 44100 Hz from DAC. ~10 ms per codec token end-to-end.
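
Because the 9 codebooks come back stacked along the batch axis of `logits`, consumer code slices the tensor itself. A minimal greedy-pick sketch, assuming B = 1 (`argmaxLastStep` is an illustrative name):

```js
// Greedy-pick one token per codebook from logits of shape [9*B, T, 1088].
// Assumes B = 1; returns a BigInt64Array of 9 token ids for the last step.
function argmaxLastStep(logits) {
  const [nCb, T, V] = logits.dims;       // [9, T, 1088] when B = 1
  const data = logits.data;              // Float32Array, row-major
  const out = new BigInt64Array(nCb);
  for (let cb = 0; cb < nCb; cb++) {
    const base = (cb * T + (T - 1)) * V; // offset of the last time step
    let best = 0;
    for (let v = 1; v < V; v++) {
      if (data[base + v] > data[base + best]) best = v;
    }
    out[cb] = BigInt(best);
  }
  return out;
}
```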

## I/O signatures

### `text_encoder.onnx`

| Input | Shape | Type |
|---|---|---|
| `input_ids` | [B, S] | int64 |
| `attention_mask` | [B, S] | int64 |

| Output | Shape | Type |
|---|---|---|
| `last_hidden_state` | [B, S, 1024] | float32 |
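
A minimal invocation, assuming `ids` holds token ids from the repo's tokenizer (e.g. obtained via transformers.js) and `sessText` is the session from the loading snippet:

```js
// Build int64 feeds for one description and run the T5 encoder.
const S = ids.length;
const feeds = {
  input_ids: new ort.Tensor('int64', BigInt64Array.from(ids, BigInt), [1, S]),
  attention_mask: new ort.Tensor('int64', new BigInt64Array(S).fill(1n), [1, S]),
};
const { last_hidden_state } = await sessText.run(feeds); // [1, S, 1024] float32
```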

### `decoder_model.onnx` (first AR step)

| Input | Shape | Type |
|---|---|---|
| `codec_input_ids` | [B, 9, T] | int64 |
| `prompt_input_ids` | [B, P] | int64 |
| `prompt_attention_mask` | [B, P] | int64 |
| `encoder_hidden_states` | [B, S, 1024] | float32 |
| `encoder_attention_mask` | [B, S] | int64 |

| Output | Shape | Type |
|---|---|---|
| `logits` | [9*B, P+T, 1088] | float32 |
| `present.{i}.{decoder.key, decoder.value, encoder.key, encoder.value}` × 24 layers | [B, 16, P+T or S, 64] | float32 |

### `decoder_with_past_model.onnx` (subsequent AR steps)

| Input | Shape | Type |
|---|---|---|
| `codec_input_ids` | [B, 9, 1] | int64 |
| `attention_mask` | [B, full_t] | int64 (covers prompt + prior codec + the new position) |
| `encoder_attention_mask` | [B, S] | int64 |
| `cache_position` | [1] | int64 (past_kv_length scalar) |
| `past_key_values.{i}.{decoder.key, decoder.value, encoder.key, encoder.value}` × 24 layers | as in `present` above | float32 |

| Output | Shape | Type |
|---|---|---|
| `logits` | [9*B, 1, 1088] | float32 |
| `present.{i}.{...}` × 24 layers | self-attn time dim grown by 1 | float32 |
KV layout per layer: (`decoder_self.key`, `decoder_self.value`, `encoder_cross.key`, `encoder_cross.value`), IT2-compatible.
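
Note the rename between steps: KV tensors come out as `present.{i}.*` but go back in as `past_key_values.{i}.*`. A small remapping helper (illustrative name; used in the loop below):

```js
// Rename present.{i}.* outputs to past_key_values.{i}.* inputs
// for the next decoder_with_past_model.onnx step.
function presentToPast(outputs) {
  const past = {};
  for (const [name, tensor] of Object.entries(outputs)) {
    if (name.startsWith('present.')) {
      past['past_key_values.' + name.slice('present.'.length)] = tensor;
    }
  }
  return past;
}
```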

## End-to-end inference (sketch)

```js
// 1. Encode the description
const encOut = await sessText.run({ input_ids, attention_mask });
const enc_h = encOut.last_hidden_state;

// 2. Build the delay-pattern mask (port of parler_tts.build_delay_pattern_mask);
//    it also yields the initial codec ids, seeded with START_ID = 1025
const { initialIds, patternMask } = buildDelayPatternMask({
  bos: 1025, pad: 1024, maxLen: 256, numCodebooks: 9,
});

// 3. First step (no past)
let { logits, ...present } = await sessDecNoPast.run({
  codec_input_ids: initialIds,           // [B, 9, T_init]
  prompt_input_ids,
  prompt_attention_mask,
  encoder_hidden_states: enc_h,
  encoder_attention_mask: attention_mask,
});

// 4. Greedy AR loop; apply the delay pattern between steps
for (let step = 1; step < maxSteps; step++) {
  const next = applyDelayPatternMask(argmaxLogits(logits), patternMask, step);
  ({ logits, ...present } = await sessDecWithPast.run({
    codec_input_ids: next,                               // [B, 9, 1]
    attention_mask: buildFullAttnMask(promptLen + step), // [B, full_t]
    encoder_attention_mask,
    cache_position: new ort.Tensor(
      'int64', BigInt64Array.from([BigInt(promptLen + step)]), [1]),
    ...presentToPast(present),   // re-feed all 96 KV tensors, renamed
  }));
}

// 5. Decode codec tokens into audio via DAC:
//    use the onnx-community/dac_44khz-ONNX decoder (see sketch below)
```

The `apply_delay_pattern_mask` and `build_delay_pattern_mask` functions need porting to JS; reference the Python source in `parler_tts/modeling_parler_tts.py`. The other helpers above (`buildFullAttnMask`, `argmaxLogits`) are placeholders you supply.
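
For the final DAC step, a hedged sketch; the input/output names below are assumptions (as is `dacModelBuffer`), so confirm them against the actual graph via `inputNames` / `outputNames` before relying on them:

```js
// Decode de-delayed codec tokens into a waveform with the DAC decoder.
const sessDac = await ort.InferenceSession.create(dacModelBuffer);
console.log(sessDac.inputNames, sessDac.outputNames); // verify I/O names

// codes: BigInt64Array of de-delayed tokens laid out as [1, 9, T]
const out = await sessDac.run({
  audio_codes: new ort.Tensor('int64', codes, [1, 9, T]), // name assumed
});
const audio = out[sessDac.outputNames[0]].data; // Float32Array at 44.1 kHz
```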

## Validated parity vs PyTorch

- Text encoder: max abs diff 4.72e-6
- Decoder step 1 (no-past): max abs diff 1.91e-5
- Decoder step 2 (with-past): max abs diff 1.68e-4

All well below the 1e-3 tolerance typically considered sufficient for an fp32 export.

## Citation

If you use this in research or production, please cite the original AI4Bharat work:

```bibtex
@misc{indic-parler-tts,
  title = {Indic Parler-TTS},
  author = {AI4Bharat},
  year = {2024},
  url = {https://huggingface.co/ai4bharat/indic-parler-tts}
}
```

And the upstream Parler-TTS:

```bibtex
@misc{lacombe-etal-2024-parler-tts,
  author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
  title = {Parler-TTS},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/parler-tts}}
}
```

## License

Apache 2.0, same as the upstream ai4bharat/indic-parler-tts. See NOTICE.md for attribution.
