Fix PCM streaming in the TTS proxy:
- Add Response to the flask imports (its absence caused a NameError on every PCM request).
- Unpack the (audio, sr, timing) tuple yielded by generate_custom_voice_streaming correctly (the loop was iterating over the tuple itself, passing a 3-element object to np.clip).
- Move the elapsed-time/chunk-count logging inside the generator so it fires after the stream ends.
- PCM streaming now works: a 12-char test produced 2.3 s of audio in 1.8 s, in 3 chunks.
286 lines · 13 KiB · Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""OpenAI-compatible TTS proxy backed by Qwen3-TTS.

Implements the two endpoints that Read-Aloud's OpenAI engine uses:
    GET  /models        — connection test
    POST /audio/speech  — synthesise text to audio; the request's
                          ``response_format`` selects the encoding
                          (MP3 by default — see ``speech()`` for the
                          other formats)

Set env vars to override defaults:
    QWEN_MODEL  — HuggingFace model id or local path
                  (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
    PROXY_PORT  — listening port (default 5000)
    DEVICE      — torch device (default: cuda:0 if available, else cpu)
    AOTRITON    — "1" to enable AOTriton flash attention on gfx1100.
                  Faster for long text (>~80 chars, e.g. novel chapters).
                  Slower for short sentences (e.g. read-aloud). Default: 0.
    HIP_GRAPHS  — "1" to use faster-qwen3-tts (HIP/CUDA graph acceleration).
                  Eliminates Python overhead per autoregressive token — 3-4x
                  faster than the standard path. Requires GPU. Default: 1.
"""

import os

# Must be set before the first torch SDPA call (checked lazily, not at import),
# which is why this runs before torch is imported below.
if os.getenv("AOTRITON", "0") == "1":
    os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
|
|
|
|
import io, time, logging, subprocess, tempfile
|
|
import torch, soundfile as sf
|
|
import numpy as np
|
|
from flask import Flask, request, jsonify, abort, send_file, stream_with_context, Response
|
|
from flask_cors import CORS
|
|
|
|
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
log = logging.getLogger(__name__)

# The Flask application. CORS is enabled so that browser extensions, whose
# requests come from origins such as chrome-extension://, may call the proxy.
app = Flask(__name__)
CORS(app)
|
|
|
|
# ── Configuration ──────────────────────────────────────────────────────────────
MODEL_PATH = os.getenv("QWEN_MODEL", "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice")
DEVICE = os.getenv("DEVICE", "cuda:0" if torch.cuda.is_available() else "cpu")

# GPU-only settings follow the device actually selected, not merely whether a
# GPU exists: with DEVICE=cpu on a GPU machine we must not pick bfloat16 or the
# HIP-graph backend (which requires a GPU). ROCm builds of torch also address
# AMD GPUs through "cuda" device strings, so the prefix check covers both.
_ON_GPU = torch.cuda.is_available() and DEVICE.startswith("cuda")
DTYPE = torch.bfloat16 if _ON_GPU else torch.float32
USE_GRAPHS = os.getenv("HIP_GRAPHS", "1") == "1" and _ON_GPU
|
|
|
|
# Map OpenAI voice names → Qwen3-TTS speaker + language + optional instruct.
# Each entry is its own dict of the form
#     {"speaker": ..., "language": ..., "instruct": ...}
# where an empty "instruct" adds no style instruction.

# Style instructions reused by several aliases.
_INSTRUCT_CALM = "Speak in a calm, measured tone."
_INSTRUCT_WARM = "Speak warmly and expressively."
_INSTRUCT_DEEP = "Speak with a deep, authoritative voice."
_INSTRUCT_GENTLE = "Speak gently and softly."


def _english(instruct=""):
    """Return a fresh config for the English speaker Ryan."""
    return {"speaker": "Ryan", "language": "English", "instruct": instruct}


def _chinese(instruct=""):
    """Return a fresh config for the Chinese speaker Vivian."""
    return {"speaker": "Vivian", "language": "Chinese", "instruct": instruct}


VOICE_MAP = {
    # ── Standard OpenAI voices ──────────────────────────────────────────────
    "alloy": _english(),
    "echo": _english(_INSTRUCT_CALM),
    "fable": _english(_INSTRUCT_WARM),
    "onyx": _english(_INSTRUCT_DEEP),
    "nova": _chinese(),
    "shimmer": _chinese(_INSTRUCT_GENTLE),
    # ── Kokoro voice aliases (customtts extension) ──────────────────────────
    # Kokoro names follow {af|bf|am|bm}_{name}
    # (a/b = American/British, f/m = female/male).
    # Every English alias, female or male, maps to Ryan (per the original
    # author, the only English speaker in the 0.6B model); Chinese voices map
    # to Vivian. Individual names get a personality instruct where fitting.
    "af_bella": _english(_INSTRUCT_WARM),
    "af_nicole": _english(_INSTRUCT_CALM),
    "af_sarah": _english(),
    "af_sky": _english(_INSTRUCT_GENTLE),
    "bf_emma": _english(_INSTRUCT_WARM),
    "bf_isabella": _english(_INSTRUCT_GENTLE),
    "am_adam": _english(_INSTRUCT_DEEP),
    "am_michael": _english(),
    "bm_george": _english(_INSTRUCT_DEEP),
    "bm_lewis": _english(_INSTRUCT_CALM),
}
DEFAULT_VOICE = "alloy"

# Fallback for Kokoro voices not listed explicitly above, keyed by name prefix:
# af_/bf_/am_/bm_ → English (Ryan), zf_/zm_ → Chinese (Vivian).
# Insertion order is the order in which resolve_voice tries the prefixes.
_KOKORO_PREFIXES = {
    **{prefix: _english() for prefix in ("af_", "bf_", "am_", "bm_")},
    **{prefix: _chinese() for prefix in ("zf_", "zm_")},
}
|
|
|
|
def resolve_voice(raw: str) -> dict:
    """Map a requested voice string onto a Qwen3-TTS speaker config.

    Resolution order:
      1. Keep only the first member of a Kokoro "+" blend
         ("af_bella+bf_emma+af_nicole" → "af_bella"), stripped and lower-cased.
      2. Return the exact VOICE_MAP entry if there is one
         (standard names such as "alloy", or listed Kokoro names).
      3. Otherwise return the entry for the first _KOKORO_PREFIXES prefix the
         name starts with.
      4. Otherwise log a warning and return VOICE_MAP[DEFAULT_VOICE].

    The returned dict is the shared table entry itself, not a copy.
    """
    head, _sep, _rest = raw.partition("+")
    name = head.strip().lower()

    exact = VOICE_MAP.get(name)
    if exact is not None:
        return exact

    by_prefix = next(
        (cfg for prefix, cfg in _KOKORO_PREFIXES.items() if name.startswith(prefix)),
        None,
    )
    if by_prefix is not None:
        log.debug("Kokoro prefix match %r → %s", name, by_prefix["speaker"])
        return by_prefix

    log.warning("Unknown voice %r, falling back to %s", raw, DEFAULT_VOICE)
    return VOICE_MAP[DEFAULT_VOICE]
|
|
|
|
# ── Load model ─────────────────────────────────────────────────────────────────
# Whichever backend is selected, this section defines `tts` plus two helpers:
#   _synthesise(text, language, speaker, instruct) -> (wavs, sr)  per request
#   _synthesise_greedy(text, language, speaker)    -> (wavs, sr)  for warm-up


def _token_budget(text):
    """Return the max_new_tokens cap to use when synthesising *text*.

    At the 12Hz token rate, ~2.5 tokens per character is a generous ceiling.
    The cap prevents stochastic generation from producing absurdly long audio
    (e.g. "Hello world." generating 16s of audio with the default
    max_new_tokens=2048). Very short inputs still get at least 60 tokens.
    """
    return max(60, int(len(text) * 2.5))


if USE_GRAPHS:
    from faster_qwen3_tts import FasterQwen3TTS

    log.info("Loading FasterQwen3TTS (HIP graph mode) %s on %s …", MODEL_PATH, DEVICE)
    tts = FasterQwen3TTS.from_pretrained(MODEL_PATH, device=DEVICE, dtype=DTYPE)
else:
    from qwen_tts import Qwen3TTSModel

    log.info("Loading Qwen3TTSModel (standard mode) %s on %s …", MODEL_PATH, DEVICE)
    tts = Qwen3TTSModel.from_pretrained(
        MODEL_PATH, device_map=DEVICE, dtype=DTYPE, attn_implementation="sdpa",
    )

# Both backends are now driven identically, so the token cap and the
# ""-means-no-instruct convention apply in standard mode too (previously only
# the HIP-graph path had them).
# NOTE(review): assumes Qwen3TTSModel.generate_custom_voice forwards
# max_new_tokens / do_sample to generation the same way FasterQwen3TTS does —
# confirm against the installed qwen_tts version.


def _synthesise(text, language, speaker, instruct):
    """Synthesise *text* for a request (sampled decoding).

    An empty *instruct* is sent as None, i.e. no style instruction.
    Returns the backend's ``(wavs, sr)`` pair unchanged.
    """
    wavs, sr = tts.generate_custom_voice(
        text=text, language=language, speaker=speaker,
        instruct=instruct or None,
        max_new_tokens=_token_budget(text),
    )
    return wavs, sr


def _synthesise_greedy(text, language, speaker):
    """Deterministic (do_sample=False) synthesis with no instruct, for warm-up."""
    wavs, sr = tts.generate_custom_voice(
        text=text, language=language, speaker=speaker,
        instruct=None, do_sample=False,
        max_new_tokens=_token_budget(text),
    )
    return wavs, sr
|
|
|
|
# ── Patch: run the speech tokenizer decoder on CPU ────────────────────────────
# The 12Hz decoder is pure Conv1d/ConvTranspose1d. On AMD ROCm, MIOpen's solver
# for these ops falls back to ConvDirectNaiveConvFwd (named "naive" for a reason),
# causing 4-40s of GPU decode time per request.
#
# Moving it to the CPU sidesteps MIOpen entirely. The Ryzen's AVX2 path handles
# these small 1D convolutions in <100ms, giving end-to-end RTF > 1.0x on typical
# text.

def _move_decoder_to_cpu(model_obj):
    """Relocate *model_obj*'s speech-tokenizer model to the CPU.

    The tokenizer is looked up at ``model.model.speech_tokenizer`` (the
    FasterQwen3TTS layout) and, if that attribute chain is missing, at
    ``model.speech_tokenizer`` (the Qwen3TTSModel layout). Its ``device``
    attribute is updated to match the move.
    """
    try:
        tokenizer = model_obj.model.model.speech_tokenizer  # FasterQwen3TTS path
    except AttributeError:
        tokenizer = model_obj.model.speech_tokenizer  # Qwen3TTSModel path

    cpu = torch.device("cpu")
    tokenizer.model.to(cpu)
    tokenizer.device = cpu
    log.info("Speech tokenizer decoder moved to CPU (bypasses MIOpen)")
|
|
|
|
_move_decoder_to_cpu(tts)

# Warm up with a single short synthesis (via _synthesise_greedy) before serving
# requests. In HIP-graph mode this one call captures both graphs (talker +
# predictor). No MIOpen warm-up is needed because the decoder now runs on CPU.
_mode = "HIP-graphs" if USE_GRAPHS else "standard-sdpa"
log.info("Warming up — %s …", "HIP graph capture" if USE_GRAPHS else "standard mode")
_t = time.monotonic()
_synthesise_greedy("Hello.", "English", "Ryan")
log.info("Warm-up done in %.1fs — proxy ready. mode=%s", time.monotonic() - _t, _mode)
|
|
|
|
|
|
# ── Helpers ────────────────────────────────────────────────────────────────────
def wav_to_mp3(wav_bytes: bytes) -> bytes:
    """Encode an in-memory WAV file as MP3 with the ``ffmpeg`` CLI (libmp3lame, -q:a 4).

    The input and output files live in a private temporary directory, which is
    removed afterwards even if writing the input or running ffmpeg fails.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: if ffmpeg exits non-zero (e.g. the input
            is not valid audio); ffmpeg's captured output is on the exception.
    """
    # A dedicated directory gives fixed, unambiguous file names (deriving the
    # .mp3 path with str.replace would also rewrite any ".wav" elsewhere in the
    # temp path) and makes cleanup a single step.
    with tempfile.TemporaryDirectory(prefix="qwen-tts-") as workdir:
        in_path = os.path.join(workdir, "input.wav")
        out_path = os.path.join(workdir, "output.mp3")
        with open(in_path, "wb") as f:
            f.write(wav_bytes)
        subprocess.run(
            ["ffmpeg", "-y", "-i", in_path, "-codec:a", "libmp3lame", "-q:a", "4", out_path],
            check=True, capture_output=True,
        )
        with open(out_path, "rb") as f:
            return f.read()
|
|
|
|
|
|
# ── Endpoints ──────────────────────────────────────────────────────────────────
@app.route("/models", methods=["GET"])
def models():
    """Advertise the single model id this proxy serves (clients use it as a connection test)."""
    catalogue = [{"id": "tts-1", "object": "model"}]
    payload = {"object": "list", "data": catalogue}
    return jsonify(payload)
|
|
|
|
|
|
def _float_to_pcm16(audio) -> bytes:
    """Clip float samples to [-1, 1] and return them as signed 16-bit PCM bytes.

    Samples are scaled by 32767 and emitted in numpy's native int16 byte order.
    """
    return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()


@app.route("/audio/speech", methods=["POST"])
def speech():
    """OpenAI-compatible ``POST /audio/speech``.

    JSON body fields:
        input           — text to speak (required, non-empty string)
        voice           — voice name for resolve_voice (default DEFAULT_VOICE;
                          non-string values also fall back to DEFAULT_VOICE)
        response_format — "mp3" (default), "pcm", or anything else for WAV

    Responses:
        pcm, HIP-graph mode  → streamed ``audio/pcm``: raw signed 16-bit samples,
                               one write per generated chunk
        pcm, standard mode   → the same raw 16-bit samples in one ``audio/pcm`` body
        mp3                  → ffmpeg-encoded ``audio/mpeg``
        anything else        → ``audio/wav``

    Aborts with 400 when ``input`` is missing, empty or not a string, and with
    500 when synthesis raises.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        # Unparseable body, or valid JSON that isn't an object (e.g. a list):
        # treat as empty so the request is rejected with a 400 below instead of
        # crashing on .get() with a 500.
        data = {}

    text = data.get("input", "")
    text = text.strip() if isinstance(text, str) else ""
    voice = data.get("voice", DEFAULT_VOICE)
    if not isinstance(voice, str):
        voice = DEFAULT_VOICE
    fmt = data.get("response_format", "mp3")

    if not text:
        abort(400, description="'input' field is required")

    info = resolve_voice(voice)
    # Same length cap as the HIP-graph _synthesise path: bounds runaway
    # stochastic generation relative to the input length.
    max_new_tokens = max(60, int(len(text) * 2.5))

    # ── PCM streaming (HIP-graph backend only) ──
    if fmt == "pcm" and USE_GRAPHS:
        log.info("Streaming PCM | %d chars | voice=%s speaker=%s",
                 len(text), voice, info["speaker"])

        def generate_pcm():
            t0 = time.monotonic()
            chunks = 0
            try:
                for audio, _sr, _timing in tts.generate_custom_voice_streaming(
                    text=text,
                    language=info["language"],
                    speaker=info["speaker"],
                    instruct=info["instruct"] or None,
                    max_new_tokens=max_new_tokens,
                ):
                    chunks += 1
                    yield _float_to_pcm16(audio)
            except Exception:
                # The 200 status and headers have already been sent, so the
                # client can only observe a truncated stream; log and end it.
                log.exception("PCM stream error after %d chunks", chunks)
                return
            log.info("PCM stream done: %d chunks in %.1fs", chunks, time.monotonic() - t0)

        return Response(
            stream_with_context(generate_pcm()),
            mimetype="audio/pcm",
            headers={"Cache-Control": "no-cache"},
        )

    # ── Whole-utterance synthesis ──
    log.info("Synthesising %d chars | voice=%s speaker=%s", len(text), voice, info["speaker"])
    try:
        t0 = time.monotonic()
        wavs, sr = _synthesise(text, info["language"], info["speaker"], info["instruct"])
        elapsed = time.monotonic() - t0
        audio = wavs[0]
        audio_s = len(audio) / sr
        log.info("Synthesis done in %.1fs audio=%.1fs RTF=%.2fx",
                 elapsed, audio_s, audio_s / elapsed)
    except Exception as exc:
        log.exception("TTS generation failed")
        abort(500, description=str(exc))

    if fmt == "pcm":
        # The standard backend cannot stream, but a client asking for pcm
        # expects headerless samples, not a WAV container.
        return Response(_float_to_pcm16(audio), mimetype="audio/pcm")

    wav_buf = io.BytesIO()
    sf.write(wav_buf, audio, sr, format="WAV")
    wav_bytes = wav_buf.getvalue()

    if fmt == "mp3":
        return send_file(io.BytesIO(wav_to_mp3(wav_bytes)), mimetype="audio/mpeg")
    return send_file(io.BytesIO(wav_bytes), mimetype="audio/wav")
|
|
|
|
|
|
# ── Error handlers ─────────────────────────────────────────────────────────────
def json_error(e):
    """Render an HTTP error as an OpenAI-style JSON body with the error's status code."""
    body = {"error": {"message": str(e), "type": "proxy_error"}}
    return jsonify(body), e.code


# Equivalent to stacking @app.errorhandler(code) decorators on json_error.
for _status in (400, 404, 500, 502):
    app.register_error_handler(_status, json_error)
del _status
|
|
|
|
|
|
def _main():
    """Serve the proxy with Flask's built-in server on PROXY_PORT (default 5000)."""
    listen_port = int(os.getenv("PROXY_PORT", "5000"))
    log.info("Starting proxy on port %d", listen_port)
    # NOTE(review): threaded=True lets requests run concurrently against the one
    # shared `tts` model — confirm the backend tolerates concurrent calls.
    app.run(host="0.0.0.0", port=listen_port, debug=False, threaded=True)


if __name__ == "__main__":
    _main()
|