- OpenAI-compatible Flask proxy (POST /audio/speech, GET /models) - faster-qwen3-tts HIP graph acceleration: GPU LLM at 1.78x RTF - CPU speech tokenizer decoder: bypasses MIOpen ConvDirectNaiveConvFwd, eliminates 4-40s per-request decode overhead - attn_implementation=sdpa for transformer attention - AOTRITON env var toggle (off=short sentences, on=long-form/novel chapters) - HIP_GRAPHS env var toggle (default on) - Startup warmup with HIP graph capture (~5s) - CORS support for browser extension requests - RTF: 0.9-1.5x on AMD RX 7900 XTX (gfx1100, ROCm 6.3) Performance vs baseline (CPU-only, ~3 min/sentence): 12c: 3.2s | 44c: 2.7s | 115c: 6.6s
207 lines
9.2 KiB
Python
207 lines
9.2 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""OpenAI-compatible TTS proxy backed by Qwen3-TTS.

Implements the two endpoints that Read-Aloud's OpenAI engine uses:
    GET  /models        — connection test
    POST /audio/speech  — synthesise text → mp3 (or WAV when the request's
                          "response_format" is anything other than "mp3")

Set env vars to override defaults:
    QWEN_MODEL  — HuggingFace model id or local path
                  (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
    PROXY_PORT  — listening port (default 5000)
    DEVICE      — torch device (default: cuda:0 if available, else cpu)
    AOTRITON    — "1" to enable AOTriton flash attention on gfx1100.
                  Faster for long text (>~80 chars, e.g. novel chapters).
                  Slower for short sentences (e.g. read-aloud). Default: 0.
    HIP_GRAPHS  — "1" to use faster-qwen3-tts (HIP/CUDA graph acceleration).
                  Eliminates Python overhead per autoregressive token — 3-4x
                  faster than the standard path. Requires GPU; ignored when
                  no GPU is available. Default: 1.

The speech-tokenizer decoder is always run on the CPU. mp3 output requires an
``ffmpeg`` executable on PATH.
"""
import os

# Opt in to AOTriton flash attention when AOTRITON=1. Torch reads this flag
# lazily (at the first SDPA call, not at import time), but setting it here,
# before torch is even imported, guarantees it is in place in time.
if os.environ.get("AOTRITON", "0") == "1":
    os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
import io
import logging
import subprocess
import tempfile
import threading
import time

import soundfile as sf
import torch
from flask import Flask, abort, jsonify, request, send_file
from flask_cors import CORS
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

app = Flask(__name__)
# Browser extensions (chrome-extension:// origins and the like) make
# cross-origin requests, so enable CORS on the whole app with flask-cors'
# default settings.
CORS(app)
# ── Configuration ──────────────────────────────────────────────────────────────
MODEL_PATH = os.getenv("QWEN_MODEL", "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice")
DEVICE = os.getenv("DEVICE", "cuda:0" if torch.cuda.is_available() else "cpu")

# Precision and the HIP-graph switch follow the device actually selected, not
# merely whether a GPU exists: DEVICE=cpu on a GPU machine must get float32 and
# the standard path, because graph capture needs a GPU. (ROCm builds of PyTorch
# expose AMD GPUs through the "cuda" device type.)
_ON_GPU = torch.cuda.is_available() and torch.device(DEVICE).type == "cuda"
DTYPE = torch.bfloat16 if _ON_GPU else torch.float32
USE_GRAPHS = os.getenv("HIP_GRAPHS", "1") == "1" and _ON_GPU

# Map OpenAI voice names (the request's "voice" field) → Qwen3-TTS speaker,
# language and optional style instruction ("" means no instruction).
VOICE_MAP = {
    "alloy":   {"speaker": "Ryan",   "language": "English", "instruct": ""},
    "echo":    {"speaker": "Ryan",   "language": "English", "instruct": "Speak in a calm, measured tone."},
    "fable":   {"speaker": "Ryan",   "language": "English", "instruct": "Speak warmly and expressively."},
    "onyx":    {"speaker": "Ryan",   "language": "English", "instruct": "Speak with a deep, authoritative voice."},
    "nova":    {"speaker": "Vivian", "language": "Chinese", "instruct": ""},
    "shimmer": {"speaker": "Vivian", "language": "Chinese", "instruct": "Speak gently and softly."},
}
DEFAULT_VOICE = "alloy"
# ── Load model ─────────────────────────────────────────────────────────────────

# Flask's development server (app.run) serves each request on its own thread by
# default, so without this lock concurrent requests would drive the single
# model instance at the same time. HIP/CUDA graph replay reuses buffers fixed
# at capture time, so overlapping calls could clobber each other's state; on a
# single GPU nothing is gained by interleaving them anyway.
_TTS_LOCK = threading.Lock()


def _token_budget(text):
    """Return the ``max_new_tokens`` cap used for *text*.

    Audio length is capped in proportion to input length. At the 12 Hz token
    rate, ~2.5 tokens per character is a generous ceiling that stops stochastic
    generation from producing absurdly long audio (e.g. "Hello world."
    generating 16 s of audio under the default max_new_tokens=2048). Very short
    inputs get a floor of 60 tokens (5 s at 12 Hz).
    """
    # NOTE(review): 2.5 tokens/char looks tuned for English; denser scripts such
    # as Chinese (the "nova"/"shimmer" voices) may need more tokens per
    # character — confirm long Chinese inputs are not truncated.
    return max(60, int(len(text) * 2.5))


if USE_GRAPHS:
    from faster_qwen3_tts import FasterQwen3TTS

    log.info("Loading FasterQwen3TTS (HIP graph mode) %s on %s …", MODEL_PATH, DEVICE)
    tts = FasterQwen3TTS.from_pretrained(MODEL_PATH, device=DEVICE, dtype=DTYPE)

    def _synthesise(text, language, speaker, instruct):
        """Synthesise *text* with sampling; return the model's ``(wavs, sr)``.

        An empty *instruct* is sent as ``None``; output length is capped by
        ``_token_budget``. Calls are serialised through ``_TTS_LOCK``.
        """
        with _TTS_LOCK:
            wavs, sr = tts.generate_custom_voice(
                text=text, language=language, speaker=speaker,
                instruct=instruct or None,
                max_new_tokens=_token_budget(text),
            )
        return wavs, sr

    def _synthesise_greedy(text, language, speaker):
        """Deterministic synthesis for warm-up (``do_sample=False``, no instruct)."""
        with _TTS_LOCK:
            wavs, sr = tts.generate_custom_voice(
                text=text, language=language, speaker=speaker,
                instruct=None, do_sample=False,
                max_new_tokens=_token_budget(text),
            )
        return wavs, sr

else:
    from qwen_tts import Qwen3TTSModel

    log.info("Loading Qwen3TTSModel (standard mode) %s on %s …", MODEL_PATH, DEVICE)
    tts = Qwen3TTSModel.from_pretrained(
        MODEL_PATH, device_map=DEVICE, dtype=DTYPE, attn_implementation="sdpa",
    )

    def _synthesise(text, language, speaker, instruct):
        """Synthesise *text*; return the model's ``(wavs, sr)``.

        An empty *instruct* is sent as ``None``, matching the HIP-graph path.
        No token cap is applied in this mode. Calls are serialised through
        ``_TTS_LOCK``.
        """
        with _TTS_LOCK:
            wavs, sr = tts.generate_custom_voice(
                text=text, language=language, speaker=speaker,
                instruct=instruct or None,
            )
        return wavs, sr

    def _synthesise_greedy(text, language, speaker):
        """Warm-up synthesis with no instruct.

        NOTE: unlike the HIP-graph path this does not disable sampling; it is
        an ordinary ``_synthesise`` call.
        """
        return _synthesise(text, language, speaker, "")
# ── Patch: run the speech tokenizer decoder on CPU ────────────────────────────
# The 12 Hz speech-tokenizer decoder is built purely from Conv1d/ConvTranspose1d
# layers. On AMD ROCm, MIOpen's solver for these ops falls back to
# ConvDirectNaiveConvFwd, costing 4-40 s of GPU decode time per request.
#
# Keeping the decoder on the CPU sidesteps MIOpen entirely; per the author's
# measurements the Ryzen's AVX2 path handles these small 1-D convolutions in
# under 100 ms, for end-to-end RTF > 1.0x on typical text.

def _move_decoder_to_cpu(model_obj):
    """Relocate *model_obj*'s speech tokenizer to the CPU, in place.

    FasterQwen3TTS nests the tokenizer one level deeper
    (``.model.model.speech_tokenizer``) than Qwen3TTSModel
    (``.model.speech_tokenizer``); both layouts are handled.
    """
    try:
        tokenizer = model_obj.model.model.speech_tokenizer
    except AttributeError:
        tokenizer = model_obj.model.speech_tokenizer

    cpu = torch.device("cpu")
    tokenizer.model.to(cpu)
    # Also update the tokenizer's own device attribute — presumably what it
    # uses to place its inputs; confirm against the tokenizer implementation.
    tokenizer.device = cpu
    log.info("Speech tokenizer decoder moved to CPU (bypasses MIOpen)")
_move_decoder_to_cpu(tts)

# Warm-up: run one synthesis before serving traffic. In HIP-graph mode this
# single call captures both HIP graphs (talker + predictor), so the one-time
# capture cost is paid at startup; there _synthesise_greedy decodes with
# do_sample=False, keeping the run deterministic. In standard mode it is an
# ordinary sampled call. No MIOpen warm-up is needed: the decoder runs on CPU.
log.info("Warming up — HIP graph capture …" if USE_GRAPHS else "Warming up …")
_t = time.monotonic()
_synthesise_greedy("Hello.", "English", "Ryan")
log.info("Warm-up done in %.1fs — proxy ready. mode=%s",
         time.monotonic() - _t, "HIP-graphs" if USE_GRAPHS else "standard-sdpa")
# ── Helpers ────────────────────────────────────────────────────────────────────
def wav_to_mp3(wav_bytes: bytes) -> bytes:
    """Transcode in-memory WAV data to MP3 with the ``ffmpeg`` CLI.

    Encodes with libmp3lame at VBR quality 4 (``-q:a 4``) via a pair of
    temporary files, both of which are removed before returning, whether or
    not encoding succeeds.

    Raises:
        FileNotFoundError: if no ``ffmpeg`` executable is found on PATH.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    fd, in_path = tempfile.mkstemp(suffix=".wav")
    # Swap only the extension: a str.replace over the whole path would also
    # rewrite any ".wav" inside the temp directory's name.
    out_path = os.path.splitext(in_path)[0] + ".mp3"
    try:
        # Writing inside the try ensures the input file is removed even if
        # the write itself fails.
        with os.fdopen(fd, "wb") as tmp_in:
            tmp_in.write(wav_bytes)
        subprocess.run(
            ["ffmpeg", "-y", "-i", in_path, "-codec:a", "libmp3lame", "-q:a", "4", out_path],
            check=True, capture_output=True,
            # ffmpeg reads interactive commands from stdin by default; never
            # let it consume the server's terminal input.
            stdin=subprocess.DEVNULL,
        )
        with open(out_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(in_path)
        if os.path.exists(out_path):
            os.unlink(out_path)
# ── Endpoints ──────────────────────────────────────────────────────────────────
@app.route("/models", methods=["GET"])
def models():
    """Advertise one pseudo-model ("tts-1") so OpenAI clients' connection test succeeds."""
    model_entry = {"id": "tts-1", "object": "model"}
    return jsonify({"object": "list", "data": [model_entry]})
@app.route("/audio/speech", methods=["POST"])
def speech():
    """Synthesise the request's ``input`` text and return it as audio.

    Reads an OpenAI ``/audio/speech``-style JSON body (parsed regardless of
    Content-Type):
        input            — text to speak (required, non-empty string)
        voice            — OpenAI voice name; unknown or non-string values
                           fall back to DEFAULT_VOICE
        response_format  — "mp3" (default) returns audio/mpeg; any other
                           value returns audio/wav

    Aborts with 400 when ``input`` is missing, empty or not a string (a body
    that is not a JSON object counts as missing), and with 500 when synthesis
    or mp3 encoding fails.
    """
    data = request.get_json(force=True, silent=True)
    # A non-object body (e.g. a JSON list) would make data.get() raise.
    if not isinstance(data, dict):
        data = {}

    # `or ""` also covers an explicit JSON null.
    text = data.get("input") or ""
    if not isinstance(text, str):
        abort(400, description="'input' must be a string")
    text = text.strip()
    if not text:
        abort(400, description="'input' field is required")

    voice = data.get("voice", DEFAULT_VOICE)
    fmt = data.get("response_format", "mp3")

    # The isinstance check also keeps an unhashable JSON value (list/object)
    # from raising TypeError in the dict lookup.
    if isinstance(voice, str) and voice in VOICE_MAP:
        info = VOICE_MAP[voice]
    else:
        info = VOICE_MAP[DEFAULT_VOICE]
    log.info("Synthesising %d chars | voice=%s speaker=%s", len(text), voice, info["speaker"])

    try:
        t0 = time.monotonic()
        wavs, sr = _synthesise(text, info["language"], info["speaker"], info["instruct"])
        elapsed = time.monotonic() - t0
        audio_s = len(wavs[0]) / sr
        log.info("Synthesis done in %.1fs audio=%.1fs RTF=%.2fx",
                 elapsed, audio_s, audio_s / elapsed)
    except Exception as exc:
        log.exception("TTS generation failed")
        abort(500, description=str(exc))

    wav_buf = io.BytesIO()
    sf.write(wav_buf, wavs[0], sr, format="WAV")
    wav_bytes = wav_buf.getvalue()

    if fmt == "mp3":
        try:
            audio_bytes = wav_to_mp3(wav_bytes)
        except (OSError, subprocess.CalledProcessError) as exc:
            # OSError covers a missing ffmpeg binary; give the client an
            # actionable message instead of a generic 500 page.
            log.exception("mp3 encoding failed")
            abort(500, description=f"mp3 encoding failed: {exc}")
        mimetype = "audio/mpeg"
    else:
        audio_bytes = wav_bytes
        mimetype = "audio/wav"

    return send_file(io.BytesIO(audio_bytes), mimetype=mimetype)
# ── Error handlers ─────────────────────────────────────────────────────────────
@app.errorhandler(400)
@app.errorhandler(404)
@app.errorhandler(405)  # wrong method, e.g. GET /audio/speech — HTML otherwise
@app.errorhandler(500)
@app.errorhandler(502)
def json_error(e):
    """Render the registered HTTP errors as OpenAI-style JSON error bodies.

    ``e`` is the HTTP exception Flask hands to the handler; its ``code`` is
    reused as the response status.
    """
    return jsonify({"error": {"message": str(e), "type": "proxy_error"}}), e.code
if __name__ == "__main__":
    port = int(os.getenv("PROXY_PORT", "5000"))
    # Bind address, defaulting to all interfaces. Set PROXY_HOST=127.0.0.1 to
    # keep this CORS-open, GPU-backed endpoint reachable from this machine only.
    host = os.getenv("PROXY_HOST", "0.0.0.0")
    log.info("Starting proxy on port %d", port)
    app.run(host=host, port=port, debug=False)