Skip to content

feat(sortformer): optional 6-bit head palettization + variant build driver#73

Open
Alex-Wengg wants to merge 2 commits into
mainfrom
feat/sortformer-palettization
Open

feat(sortformer): optional 6-bit head palettization + variant build driver#73
Alex-Wengg wants to merge 2 commits into
mainfrom
feat/sortformer-palettization

Conversation

@Alex-Wengg

Copy link
Copy Markdown
Member

Summary

Adds optional 6-bit weight palettization to the Sortformer head conversion, plus a driver to rebuild all variants. Motivated by FluidInference/FluidAudio#726 (BNNS crash + RAM crash on older devices).

Changes

  • convert_to_coreml.py: new --palettize_head_nbits N. Palettizes the head (conformer + transformer, ~98% of model size) to an N-bit kmeans LUT (constexpr_lut_to_dense) before make_pipeline.
    • 6-bit → model ~2.5× smaller (e.g. highContext 243 → 99 MB), matching Argmax's speakerkit-pro recipe.
    • No measurable per-call speed change on GPU; preserves speaker-argmax decisions.
    • LUT palettization is GPU-safelinear_quantize_weights (int8) crashes MPSGraph (MLIR pass manager failed), palettization does not.
  • build_all_variants.py: rebuilds all 7 variants (Default / NvidiaLow / NvidiaHigh × v2/v2.1, plus Efficient cl=25), compiles to .mlmodelc, and verifies each has no BNNS input==output alias and loads on ComputeUnit.ALL. PALETTIZE_NBITS / OUT_DIR env vars select fp16 vs palettized output sets.

Validation

  • Rebuilt all 7 variants both fp16 and 6-bit; all verified no-alias + ANE-loadable.
  • Parity vs NeMo PyTorch reference: 100% speaker-argmax agreement (fp16 and 6-bit).
  • Full AMI-SDM DER (forced-alignment GT, collar 0.25): 6-bit = +0.9 pp avg vs fp16.
  • Models uploaded to FluidInference/diar-streaming-sortformer-coreml/v3/{fp16,palettized}/.

…river

- convert_to_coreml.py: add --palettize_head_nbits N. Palettizes the head (conformer+
  transformer, ~98% of model size) to N-bit kmeans LUT before make_pipeline. 6-bit cuts
  the model ~2.5x (matches Argmax speakerkit), no measurable speed change, preserves
  speaker-argmax decisions. LUT palettization is GPU-safe — int8 linear quantization
  crashes MPSGraph. Helps the FluidAudio #726 RAM crash on older devices.
- build_all_variants.py: driver that rebuilds all 7 Sortformer variants (Default/NvidiaLow/
  NvidiaHigh x v2/v2.1 + Efficient cl=25), verifies no BNNS input==output alias + ANE load.
  Env PALETTIZE_NBITS / OUT_DIR select fp16 vs palettized output sets.
… faster)

Reproducible interleaved A/B: FluidAudio fused offline Sortformer vs Argmax's
3-model chain, same 30.72s window, ComputeUnit.ALL. Backs the numbers in
FluidAudio Documentation/Diarization/Sortformer.md#benchmarks (issue #726).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant