feat: Qwen3-TTS proxy with HIP graph + CPU decoder optimisations

- OpenAI-compatible Flask proxy (POST /audio/speech, GET /models)
- faster-qwen3-tts HIP graph acceleration: GPU LLM at 1.78x RTF
- CPU speech tokenizer decoder: bypasses MIOpen ConvDirectNaiveConvFwd,
  eliminates 4-40s per-request decode overhead
- attn_implementation=sdpa for transformer attention
- AOTRITON env var toggle (off=short sentences, on=long-form/novel chapters)
- HIP_GRAPHS env var toggle (default on)
- Startup warmup with HIP graph capture (~5s)
- CORS support for browser extension requests
- RTF: 0.9-1.5x on AMD RX 7900 XTX (gfx1100, ROCm 6.3)

Performance vs baseline (CPU-only, ~3 min/sentence):
  12c: 3.2s | 44c: 2.7s | 115c: 6.6s
This commit is contained in:
2026-03-25 21:18:42 -07:00
commit d3ca5ab0b2
5 changed files with 627 additions and 0 deletions

206
qwen3-proxy/app.py Normal file
View File

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""OpenAI-compatible TTS proxy backed by Qwen3-TTS.
Implements the two endpoints that Read-Aloud's OpenAI engine uses:
GET /models — connection test
POST /audio/speech — synthesise text → mp3
Set env vars to override defaults:
QWEN_MODEL — HuggingFace model id or local path
PROXY_PORT — listening port (default 5000)
DEVICE — torch device (default: cuda:0 if available, else cpu)
AOTRITON — "1" to enable AOTriton flash attention on gfx1100.
Faster for long text (>~80 chars, e.g. novel chapters).
Slower for short sentences (e.g. read-aloud). Default: 0.
HIP_GRAPHS — "1" to use faster-qwen3-tts (HIP/CUDA graph acceleration).
Eliminates Python overhead per autoregressive token — 3-4x
faster than the standard path. Requires GPU. Default: 1.
"""
import os
# Must be set before the first torch SDPA call (checked lazily, not at import).
if os.getenv("AOTRITON", "0") == "1":
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
import io, time, logging, subprocess, tempfile
import torch, soundfile as sf
from flask import Flask, request, jsonify, abort, send_file
from flask_cors import CORS
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
app = Flask(__name__)
CORS(app) # allow requests from browser extensions (chrome-extension:// etc.)
# ── Configuration ──────────────────────────────────────────────────────────────
MODEL_PATH = os.getenv("QWEN_MODEL", "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice")
DEVICE = os.getenv("DEVICE", "cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
USE_GRAPHS = os.getenv("HIP_GRAPHS", "1") == "1" and torch.cuda.is_available()
# Map OpenAI voice names → Qwen3-TTS speaker + language + optional instruct
VOICE_MAP = {
"alloy": {"speaker": "Ryan", "language": "English", "instruct": ""},
"echo": {"speaker": "Ryan", "language": "English", "instruct": "Speak in a calm, measured tone."},
"fable": {"speaker": "Ryan", "language": "English", "instruct": "Speak warmly and expressively."},
"onyx": {"speaker": "Ryan", "language": "English", "instruct": "Speak with a deep, authoritative voice."},
"nova": {"speaker": "Vivian", "language": "Chinese", "instruct": ""},
"shimmer": {"speaker": "Vivian", "language": "Chinese", "instruct": "Speak gently and softly."},
}
DEFAULT_VOICE = "alloy"
# ── Load model ─────────────────────────────────────────────────────────────────
if USE_GRAPHS:
from faster_qwen3_tts import FasterQwen3TTS
log.info("Loading FasterQwen3TTS (HIP graph mode) %s on %s", MODEL_PATH, DEVICE)
tts = FasterQwen3TTS.from_pretrained(MODEL_PATH, device=DEVICE, dtype=DTYPE)
def _synthesise(text, language, speaker, instruct):
# Cap audio length proportional to input text length.
# At 12Hz token rate, ~2.5 tokens per character is a generous ceiling.
# This prevents stochastic generation from producing absurdly long audio
# (e.g. "Hello world." generating 16s of audio with default max_new_tokens=2048).
max_new_tokens = max(60, int(len(text) * 2.5))
wavs, sr = tts.generate_custom_voice(
text=text, language=language, speaker=speaker,
instruct=instruct or None,
max_new_tokens=max_new_tokens,
)
return wavs, sr
def _synthesise_greedy(text, language, speaker):
"""Deterministic synthesis for warmup — uses tight token budget."""
max_new_tokens = max(60, int(len(text) * 2.5))
wavs, sr = tts.generate_custom_voice(
text=text, language=language, speaker=speaker,
instruct=None, do_sample=False,
max_new_tokens=max_new_tokens,
)
return wavs, sr
else:
from qwen_tts import Qwen3TTSModel
log.info("Loading Qwen3TTSModel (standard mode) %s on %s", MODEL_PATH, DEVICE)
tts = Qwen3TTSModel.from_pretrained(
MODEL_PATH, device_map=DEVICE, dtype=DTYPE, attn_implementation="sdpa",
)
def _synthesise(text, language, speaker, instruct):
wavs, sr = tts.generate_custom_voice(
text=text, language=language, speaker=speaker, instruct=instruct,
)
return wavs, sr
def _synthesise_greedy(text, language, speaker):
return _synthesise(text, language, speaker, "")
# ── Patch: run the speech tokenizer decoder on CPU ────────────────────────────
# The 12Hz decoder is pure Conv1d/ConvTranspose1d. On AMD ROCm, MIOpen's solver
# for these ops falls back to ConvDirectNaiveConvFwd (named "naive" for a reason),
# causing 4-40s of GPU decode time per request.
#
# Moving to CPU sidesteps MIOpen entirely. The Ryzen's AVX2 path handles these
# small 1D convolutions in <100ms, giving end-to-end RTF > 1.0x on typical text.
def _move_decoder_to_cpu(model_obj):
try:
st = model_obj.model.model.speech_tokenizer # FasterQwen3TTS path
except AttributeError:
st = model_obj.model.speech_tokenizer # Qwen3TTSModel path
st.model.to("cpu")
st.device = torch.device("cpu")
log.info("Speech tokenizer decoder moved to CPU (bypasses MIOpen)")
_move_decoder_to_cpu(tts)
# Use greedy (deterministic) decoding so warmup produces consistent audio lengths
# and MIOpen compiles the exact shapes that common inputs will hit at runtime.
# The 3 texts below produce ~1s, ~4s, and ~6s of audio deterministically.
log.info("Warming up — HIP graph capture …")
_t = time.monotonic()
# One synthesis call captures both HIP graphs (talker + predictor).
# No MIOpen warmup needed — decoder runs on CPU now.
_synthesise_greedy("Hello.", "English", "Ryan")
log.info("Warm-up done in %.1fs — proxy ready. mode=%s",
time.monotonic() - _t, "HIP-graphs" if USE_GRAPHS else "standard-sdpa")
# ── Helpers ────────────────────────────────────────────────────────────────────
def wav_to_mp3(wav_bytes: bytes) -> bytes:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
tmp_in.write(wav_bytes)
tmp_in_path = tmp_in.name
tmp_out_path = tmp_in_path.replace(".wav", ".mp3")
try:
subprocess.run(
["ffmpeg", "-y", "-i", tmp_in_path, "-codec:a", "libmp3lame", "-q:a", "4", tmp_out_path],
check=True, capture_output=True,
)
with open(tmp_out_path, "rb") as f:
return f.read()
finally:
os.unlink(tmp_in_path)
if os.path.exists(tmp_out_path):
os.unlink(tmp_out_path)
# ── Endpoints ──────────────────────────────────────────────────────────────────
@app.route("/models", methods=["GET"])
def models():
return jsonify({"object": "list", "data": [{"id": "tts-1", "object": "model"}]})
@app.route("/audio/speech", methods=["POST"])
def speech():
data = request.get_json(force=True, silent=True) or {}
text = data.get("input", "").strip()
voice = data.get("voice", DEFAULT_VOICE)
fmt = data.get("response_format", "mp3")
if not text:
abort(400, description="'input' field is required")
info = VOICE_MAP.get(voice, VOICE_MAP[DEFAULT_VOICE])
log.info("Synthesising %d chars | voice=%s speaker=%s", len(text), voice, info["speaker"])
try:
t0 = time.monotonic()
wavs, sr = _synthesise(text, info["language"], info["speaker"], info["instruct"])
elapsed = time.monotonic() - t0
audio_s = len(wavs[0]) / sr
log.info("Synthesis done in %.1fs audio=%.1fs RTF=%.2fx",
elapsed, audio_s, audio_s / elapsed)
except Exception as exc:
log.exception("TTS generation failed")
abort(500, description=str(exc))
wav_buf = io.BytesIO()
sf.write(wav_buf, wavs[0], sr, format="WAV")
wav_bytes = wav_buf.getvalue()
if fmt == "mp3":
audio_bytes = wav_to_mp3(wav_bytes)
mimetype = "audio/mpeg"
else:
audio_bytes = wav_bytes
mimetype = "audio/wav"
return send_file(io.BytesIO(audio_bytes), mimetype=mimetype)
# ── Error handlers ─────────────────────────────────────────────────────────────
@app.errorhandler(400)
@app.errorhandler(404)
@app.errorhandler(500)
@app.errorhandler(502)
def json_error(e):
return jsonify({"error": {"message": str(e), "type": "proxy_error"}}), e.code
if __name__ == "__main__":
port = int(os.getenv("PROXY_PORT", "5000"))
log.info("Starting proxy on port %d", port)
app.run(host="0.0.0.0", port=port, debug=False)