Skip to content

Commit c5cbb22

Browse files
authored
Add streaming support for Neuphonic (livekit#3182)
1 parent 5500c20 commit c5cbb22

File tree

1 file changed

+223
-45
lines changed
  • livekit-plugins/livekit-plugins-neuphonic/livekit/plugins/neuphonic

1 file changed

+223
-45
lines changed

livekit-plugins/livekit-plugins-neuphonic/livekit/plugins/neuphonic/tts.py

Lines changed: 223 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,89 +12,131 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from __future__ import annotations
15+
from __future__ import annotations # noqa: I001
1616

1717
import asyncio
1818
import base64
1919
import json
2020
import os
21+
import weakref
2122
from dataclasses import dataclass, replace
2223

2324
import aiohttp
24-
2525
from livekit.agents import (
2626
APIConnectionError,
2727
APIConnectOptions,
28+
APIError,
2829
APIStatusError,
2930
APITimeoutError,
31+
tokenize,
3032
tts,
3133
utils,
3234
)
3335
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
3436
from livekit.agents.utils import is_given
3537

36-
from .models import TTSLangCodes
38+
from .log import logger # noqa: I001
39+
from .models import TTSLangCodes # noqa: I001
3740

38-
API_BASE_URL = "api.neuphonic.com"
39-
AUTHORIZATION_HEADER = "X-API-KEY"
41+
API_AUTH_HEADER = "x-api-key"
4042

4143

4244
@dataclass
4345
class _TTSOptions:
44-
base_url: str
4546
lang_code: TTSLangCodes | str
46-
api_key: str
47+
encoding: str
4748
sample_rate: int
48-
speed: float
49-
voice_id: str | None
49+
voice_id: str
50+
speed: float | None
51+
api_key: str
52+
base_url: str
53+
word_tokenizer: tokenize.WordTokenizer
54+
55+
def get_http_url(self, path: str) -> str:
56+
return f"{self.base_url}{path}"
57+
58+
def get_ws_url(self, path: str) -> str:
59+
return f"{self.base_url.replace('http', 'ws', 1)}{path}"
5060

5161

5262
class TTS(tts.TTS):
5363
def __init__(
5464
self,
5565
*,
56-
voice_id: str = "8e9c4bc8-3979-48ab-8626-df53befc2090",
5766
api_key: str | None = None,
5867
lang_code: TTSLangCodes | str = "en",
59-
speed: float = 1.0,
68+
encoding: str = "pcm_linear",
69+
voice_id: str = "8e9c4bc8-3979-48ab-8626-df53befc2090",
70+
speed: float | None = 1.0,
6071
sample_rate: int = 22050,
6172
http_session: aiohttp.ClientSession | None = None,
62-
base_url: str = API_BASE_URL,
73+
word_tokenizer: NotGivenOr[tokenize.WordTokenizer] = NOT_GIVEN,
74+
tokenizer: NotGivenOr[tokenize.SentenceTokenizer] = NOT_GIVEN,
75+
base_url: str = "https://api.neuphonic.com",
6376
) -> None:
6477
"""
65-
Create a new instance of the Neuphonic TTS.
78+
Create a new instance of NeuPhonic TTS.
6679
67-
See https://docs.neuphonic.com for more documentation on all of these options, or go to https://app.neuphonic.com/ to test out different options.
80+
See https://docs.neuphonic.com for more details on the NeuPhonic API.
6881
6982
Args:
70-
voice_id (str, optional): The voice ID for the desired voice. Defaults to None.
71-
lang_code (TTSLanguages | str, optional): The language code for synthesis. Defaults to "en".
72-
encoding (TTSEncodings | str, optional): The audio encoding format. Defaults to "pcm_mulaw".
83+
lang_code (TTSLangCodes | str, optional): The language code for synthesis. Defaults to "en".
84+
encoding (str, optional): The audio encoding format. Defaults to "pcm_linear".
85+
voice_id (str, optional): The voice ID for the desired voice.
7386
speed (float, optional): The audio playback speed. Defaults to 1.0.
7487
sample_rate (int, optional): The audio sample rate in Hz. Defaults to 22050.
75-
api_key (str | None, optional): The Neuphonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
88+
api_key (str, optional): The NeuPhonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
7689
http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
77-
base_url (str, optional): The base URL for the Neuphonic API. Defaults to "api.neuphonic.com".
90+
word_tokenizer (tokenize.WordTokenizer, optional): The word tokenizer to use. Defaults to tokenize.basic.WordTokenizer().
91+
tokenizer (tokenize.SentenceTokenizer, optional): The sentence tokenizer to use. Defaults to tokenize.blingfire.SentenceTokenizer().
92+
base_url (str, optional): The base URL for the NeuPhonic API. Defaults to "https://api.neuphonic.com".
7893
""" # noqa: E501
94+
7995
super().__init__(
80-
capabilities=tts.TTSCapabilities(streaming=False),
96+
capabilities=tts.TTSCapabilities(streaming=True),
8197
sample_rate=sample_rate,
8298
num_channels=1,
8399
)
100+
neuphonic_api_key = api_key or os.environ.get("NEUPHONIC_API_KEY")
101+
if not neuphonic_api_key:
102+
raise ValueError("NEUPHONIC_API_KEY must be set")
84103

85-
api_key = api_key or os.environ.get("NEUPHONIC_API_KEY")
86-
if not api_key:
87-
raise ValueError("API key must be provided or set in NEUPHONIC_API_KEY")
104+
if not is_given(word_tokenizer):
105+
word_tokenizer = tokenize.basic.WordTokenizer(ignore_punctuation=False)
88106

89107
self._opts = _TTSOptions(
90-
voice_id=voice_id,
91108
lang_code=lang_code,
92-
api_key=api_key,
93-
speed=speed,
109+
encoding=encoding,
94110
sample_rate=sample_rate,
111+
voice_id=voice_id,
112+
speed=speed,
113+
api_key=neuphonic_api_key,
95114
base_url=base_url,
115+
word_tokenizer=word_tokenizer,
96116
)
97117
self._session = http_session
118+
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
119+
connect_cb=self._connect_ws,
120+
close_cb=self._close_ws,
121+
max_session_duration=300,
122+
mark_refreshed_on_get=True,
123+
)
124+
self._streams = weakref.WeakSet[SynthesizeStream]()
125+
self._sentence_tokenizer = (
126+
tokenizer if is_given(tokenizer) else tokenize.blingfire.SentenceTokenizer()
127+
)
128+
129+
async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
130+
session = self._ensure_session()
131+
url = self._opts.get_ws_url(
132+
f"/speak/en?api_key={self._opts.api_key}&speed={self._opts.speed}&lang_code={self._opts.lang_code}&sampling_rate={self._opts.sample_rate}&voice_id={self._opts.voice_id}"
133+
)
134+
135+
headers = {API_AUTH_HEADER: self._opts.api_key}
136+
return await asyncio.wait_for(session.ws_connect(url, headers=headers), timeout)
137+
138+
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
139+
await ws.close()
98140

99141
@property
100142
def model(self) -> str:
@@ -110,43 +152,56 @@ def _ensure_session(self) -> aiohttp.ClientSession:
110152

111153
return self._session
112154

155+
def prewarm(self) -> None:
156+
self._pool.prewarm()
157+
113158
def update_options(
114159
self,
115160
*,
161+
lang_code: NotGivenOr[TTSLangCodes | str] = NOT_GIVEN,
116162
voice_id: NotGivenOr[str] = NOT_GIVEN,
117-
lang_code: NotGivenOr[TTSLangCodes] = NOT_GIVEN,
118-
speed: NotGivenOr[float] = NOT_GIVEN,
119-
sample_rate: NotGivenOr[int] = NOT_GIVEN,
163+
speed: NotGivenOr[float | None] = NOT_GIVEN,
120164
) -> None:
121165
"""
122166
Update the Text-to-Speech (TTS) configuration options.
123167
124-
This method allows updating the TTS settings, including model type, voice_id, lang_code,
125-
encoding, speed and sample_rate. If any parameter is not provided, the existing value will be
126-
retained.
168+
This allows updating the TTS settings, including lang_code, voice_id, and speed.
169+
If any parameter is not provided, the existing value will be retained.
127170
128171
Args:
129-
model (TTSModels | str, optional): The Neuphonic model to use.
172+
lang_code (TTSLangCodes | str, optional): The language code for synthesis.
130173
voice_id (str, optional): The voice ID for the desired voice.
131-
lang_code (TTSLanguages | str, optional): The language code for synthesis..
132-
encoding (TTSEncodings | str, optional): The audio encoding format.
133174
speed (float, optional): The audio playback speed.
134-
sample_rate (int, optional): The audio sample rate in Hz.
135-
""" # noqa: E501
136-
if is_given(voice_id):
137-
self._opts.voice_id = voice_id
175+
"""
138176
if is_given(lang_code):
139177
self._opts.lang_code = lang_code
178+
if is_given(voice_id):
179+
self._opts.voice_id = voice_id
140180
if is_given(speed):
141181
self._opts.speed = speed
142-
if is_given(sample_rate):
143-
self._opts.sample_rate = sample_rate
144182

145183
def synthesize(
146-
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
184+
self,
185+
text: str,
186+
*,
187+
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
147188
) -> ChunkedStream:
148189
return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)
149190

191+
def stream(
192+
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
193+
) -> SynthesizeStream:
194+
stream = SynthesizeStream(tts=self, conn_options=conn_options)
195+
self._streams.add(stream)
196+
return stream
197+
198+
async def aclose(self) -> None:
199+
for stream in list(self._streams):
200+
await stream.aclose()
201+
202+
self._streams.clear()
203+
await self._pool.aclose()
204+
150205

151206
class ChunkedStream(tts.ChunkedStream):
152207
"""Synthesize chunked text using the SSE endpoint"""
@@ -165,8 +220,8 @@ def __init__(
165220
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
166221
try:
167222
async with self._tts._ensure_session().post(
168-
f"https://{self._opts.base_url}/sse/speak/{self._opts.lang_code}",
169-
headers={AUTHORIZATION_HEADER: self._opts.api_key},
223+
f"{self._opts.base_url}/sse/speak/{self._opts.lang_code}",
224+
headers={API_AUTH_HEADER: self._opts.api_key},
170225
json={
171226
"text": self._input_text,
172227
"voice_id": self._opts.voice_id,
@@ -235,7 +290,130 @@ def _parse_sse_message(message: str) -> dict | None:
235290

236291
if message_dict.get("errors") is not None:
237292
raise Exception(
238-
f"received error status {message_dict['status_code']}: {message_dict['errors']}"
293+
f"received error status {message_dict['status_code']}:{message_dict['errors']}"
239294
)
240295

241296
return message_dict
297+
298+
299+
class SynthesizeStream(tts.SynthesizeStream):
300+
def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
301+
super().__init__(tts=tts, conn_options=conn_options)
302+
self._tts: TTS = tts
303+
self._opts = replace(tts._opts)
304+
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
305+
306+
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
307+
request_id = utils.shortuuid()
308+
output_emitter.initialize(
309+
request_id=request_id,
310+
sample_rate=self._opts.sample_rate,
311+
num_channels=1,
312+
mime_type="audio/pcm",
313+
stream=True,
314+
)
315+
316+
async def _tokenize_input() -> None:
317+
word_stream = None
318+
async for input in self._input_ch:
319+
if isinstance(input, str):
320+
if word_stream is None:
321+
word_stream = self._opts.word_tokenizer.stream()
322+
self._segments_ch.send_nowait(word_stream)
323+
word_stream.push_text(input)
324+
elif isinstance(input, self._FlushSentinel):
325+
if word_stream:
326+
word_stream.end_input()
327+
word_stream = None
328+
329+
self._segments_ch.close()
330+
331+
async def _run_segments() -> None:
332+
async for word_stream in self._segments_ch:
333+
await self._run_ws(word_stream, output_emitter)
334+
335+
tasks = [
336+
asyncio.create_task(_tokenize_input()),
337+
asyncio.create_task(_run_segments()),
338+
]
339+
try:
340+
await asyncio.gather(*tasks)
341+
except asyncio.TimeoutError:
342+
raise APITimeoutError() from None
343+
except aiohttp.ClientResponseError as e:
344+
raise APIStatusError(
345+
message=e.message,
346+
status_code=e.status,
347+
request_id=request_id,
348+
body=None,
349+
) from None
350+
except Exception as e:
351+
raise APIConnectionError() from e
352+
finally:
353+
await utils.aio.gracefully_cancel(*tasks)
354+
355+
async def _run_ws(
356+
self, word_stream: tokenize.WordStream, output_emitter: tts.AudioEmitter
357+
) -> None:
358+
segment_id = utils.shortuuid()
359+
output_emitter.start_segment(segment_id=segment_id)
360+
361+
async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
362+
async for word in word_stream:
363+
text_msg = {"text": f"{word.token} "}
364+
self._mark_started()
365+
await ws.send_str(json.dumps(text_msg))
366+
367+
stop_msg = {"text": "<STOP>"}
368+
await ws.send_str(json.dumps(stop_msg))
369+
370+
async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
371+
while True:
372+
msg = await ws.receive()
373+
374+
if msg.type in (
375+
aiohttp.WSMsgType.CLOSE,
376+
aiohttp.WSMsgType.CLOSED,
377+
aiohttp.WSMsgType.CLOSING,
378+
):
379+
raise APIStatusError("NeuPhonic websocket connection closed unexpectedly")
380+
381+
if msg.type == aiohttp.WSMsgType.TEXT:
382+
try:
383+
resp = json.loads(msg.data)
384+
except json.JSONDecodeError:
385+
logger.warning("Invalid JSON from NeuPhonic")
386+
continue
387+
388+
if resp.get("type") == "error":
389+
raise APIError(f"NeuPhonic returned error: {resp}")
390+
391+
data = resp.get("data", {})
392+
audio_data = data.get("audio")
393+
if audio_data and audio_data != "":
394+
try:
395+
b64data = base64.b64decode(audio_data)
396+
if b64data:
397+
output_emitter.push(b64data)
398+
except Exception as e:
399+
logger.warning("Failed to decode NeuPhonic audio data: %s", e)
400+
401+
if data.get("stop"):
402+
output_emitter.end_segment()
403+
break
404+
405+
elif msg.type == aiohttp.WSMsgType.BINARY:
406+
pass
407+
else:
408+
logger.warning("Unexpected NeuPhonic message type: %s", msg.type)
409+
410+
async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
411+
tasks = [
412+
asyncio.create_task(send_task(ws)),
413+
asyncio.create_task(recv_task(ws)),
414+
]
415+
416+
try:
417+
await asyncio.gather(*tasks)
418+
finally:
419+
await utils.aio.gracefully_cancel(*tasks)

0 commit comments

Comments
 (0)