1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from __future__ import annotations
15+ from __future__ import annotations # noqa: I001
1616
1717import asyncio
1818import base64
1919import json
2020import os
21+ import weakref
2122from dataclasses import dataclass , replace
2223
2324import aiohttp
24-
2525from livekit .agents import (
2626 APIConnectionError ,
2727 APIConnectOptions ,
28+ APIError ,
2829 APIStatusError ,
2930 APITimeoutError ,
31+ tokenize ,
3032 tts ,
3133 utils ,
3234)
3335from livekit .agents .types import DEFAULT_API_CONNECT_OPTIONS , NOT_GIVEN , NotGivenOr
3436from livekit .agents .utils import is_given
3537
36- from .models import TTSLangCodes
38+ from .log import logger # noqa: I001
39+ from .models import TTSLangCodes # noqa: I001
3740
38- API_BASE_URL = "api.neuphonic.com"
39- AUTHORIZATION_HEADER = "X-API-KEY"
41+ API_AUTH_HEADER = "x-api-key"
4042
4143
4244@dataclass
4345class _TTSOptions :
44- base_url : str
4546 lang_code : TTSLangCodes | str
46- api_key : str
47+ encoding : str
4748 sample_rate : int
48- speed : float
49- voice_id : str | None
49+ voice_id : str
50+ speed : float | None
51+ api_key : str
52+ base_url : str
53+ word_tokenizer : tokenize .WordTokenizer
54+
55+ def get_http_url (self , path : str ) -> str :
56+ return f"{ self .base_url } { path } "
57+
58+ def get_ws_url (self , path : str ) -> str :
59+ return f"{ self .base_url .replace ('http' , 'ws' , 1 )} { path } "
5060
5161
5262class TTS (tts .TTS ):
5363 def __init__ (
5464 self ,
5565 * ,
56- voice_id : str = "8e9c4bc8-3979-48ab-8626-df53befc2090" ,
5766 api_key : str | None = None ,
5867 lang_code : TTSLangCodes | str = "en" ,
59- speed : float = 1.0 ,
68+ encoding : str = "pcm_linear" ,
69+ voice_id : str = "8e9c4bc8-3979-48ab-8626-df53befc2090" ,
70+ speed : float | None = 1.0 ,
6071 sample_rate : int = 22050 ,
6172 http_session : aiohttp .ClientSession | None = None ,
62- base_url : str = API_BASE_URL ,
73+ word_tokenizer : NotGivenOr [tokenize .WordTokenizer ] = NOT_GIVEN ,
74+ tokenizer : NotGivenOr [tokenize .SentenceTokenizer ] = NOT_GIVEN ,
75+ base_url : str = "https://api.neuphonic.com" ,
6376 ) -> None :
6477 """
65- Create a new instance of the Neuphonic TTS.
78+ Create a new instance of NeuPhonic TTS.
6679
67- See https://docs.neuphonic.com for more documentation on all of these options, or go to https://app.neuphonic.com/ to test out different options .
80+ See https://docs.neuphonic.com for more details on the NeuPhonic API .
6881
6982 Args:
70- voice_id ( str, optional): The voice ID for the desired voice . Defaults to None .
71- lang_code (TTSLanguages | str, optional): The language code for synthesis . Defaults to "en ".
72- encoding (TTSEncodings | str, optional): The audio encoding format. Defaults to "pcm_mulaw" .
83+ lang_code (TTSLangCodes | str, optional): The language code for synthesis . Defaults to "en" .
84+ encoding ( str, optional): The audio encoding format . Defaults to "pcm_linear ".
85+ voice_id ( str, optional): The voice ID for the desired voice .
7386 speed (float, optional): The audio playback speed. Defaults to 1.0.
7487 sample_rate (int, optional): The audio sample rate in Hz. Defaults to 22050.
75- api_key (str | None , optional): The Neuphonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
88+ api_key (str, optional): The NeuPhonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
7689 http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
77- base_url (str, optional): The base URL for the Neuphonic API. Defaults to "api.neuphonic.com".
90+ word_tokenizer (tokenize.WordTokenizer, optional): The word tokenizer to use. Defaults to tokenize.basic.WordTokenizer().
91+ tokenizer (tokenize.SentenceTokenizer, optional): The sentence tokenizer to use. Defaults to tokenize.blingfire.SentenceTokenizer().
92+ base_url (str, optional): The base URL for the NeuPhonic API. Defaults to "https://api.neuphonic.com".
7893 """ # noqa: E501
94+
7995 super ().__init__ (
80- capabilities = tts .TTSCapabilities (streaming = False ),
96+ capabilities = tts .TTSCapabilities (streaming = True ),
8197 sample_rate = sample_rate ,
8298 num_channels = 1 ,
8399 )
100+ neuphonic_api_key = api_key or os .environ .get ("NEUPHONIC_API_KEY" )
101+ if not neuphonic_api_key :
102+ raise ValueError ("NEUPHONIC_API_KEY must be set" )
84103
85- api_key = api_key or os .environ .get ("NEUPHONIC_API_KEY" )
86- if not api_key :
87- raise ValueError ("API key must be provided or set in NEUPHONIC_API_KEY" )
104+ if not is_given (word_tokenizer ):
105+ word_tokenizer = tokenize .basic .WordTokenizer (ignore_punctuation = False )
88106
89107 self ._opts = _TTSOptions (
90- voice_id = voice_id ,
91108 lang_code = lang_code ,
92- api_key = api_key ,
93- speed = speed ,
109+ encoding = encoding ,
94110 sample_rate = sample_rate ,
111+ voice_id = voice_id ,
112+ speed = speed ,
113+ api_key = neuphonic_api_key ,
95114 base_url = base_url ,
115+ word_tokenizer = word_tokenizer ,
96116 )
97117 self ._session = http_session
118+ self ._pool = utils .ConnectionPool [aiohttp .ClientWebSocketResponse ](
119+ connect_cb = self ._connect_ws ,
120+ close_cb = self ._close_ws ,
121+ max_session_duration = 300 ,
122+ mark_refreshed_on_get = True ,
123+ )
124+ self ._streams = weakref .WeakSet [SynthesizeStream ]()
125+ self ._sentence_tokenizer = (
126+ tokenizer if is_given (tokenizer ) else tokenize .blingfire .SentenceTokenizer ()
127+ )
128+
129+ async def _connect_ws (self , timeout : float ) -> aiohttp .ClientWebSocketResponse :
130+ session = self ._ensure_session ()
131+ url = self ._opts .get_ws_url (
132+ f"/speak/en?api_key={ self ._opts .api_key } &speed={ self ._opts .speed } &lang_code={ self ._opts .lang_code } &sampling_rate={ self ._opts .sample_rate } &voice_id={ self ._opts .voice_id } "
133+ )
134+
135+ headers = {API_AUTH_HEADER : self ._opts .api_key }
136+ return await asyncio .wait_for (session .ws_connect (url , headers = headers ), timeout )
137+
138+ async def _close_ws (self , ws : aiohttp .ClientWebSocketResponse ) -> None :
139+ await ws .close ()
98140
99141 @property
100142 def model (self ) -> str :
@@ -110,43 +152,56 @@ def _ensure_session(self) -> aiohttp.ClientSession:
110152
111153 return self ._session
112154
155+ def prewarm (self ) -> None :
156+ self ._pool .prewarm ()
157+
113158 def update_options (
114159 self ,
115160 * ,
161+ lang_code : NotGivenOr [TTSLangCodes | str ] = NOT_GIVEN ,
116162 voice_id : NotGivenOr [str ] = NOT_GIVEN ,
117- lang_code : NotGivenOr [TTSLangCodes ] = NOT_GIVEN ,
118- speed : NotGivenOr [float ] = NOT_GIVEN ,
119- sample_rate : NotGivenOr [int ] = NOT_GIVEN ,
163+ speed : NotGivenOr [float | None ] = NOT_GIVEN ,
120164 ) -> None :
121165 """
122166 Update the Text-to-Speech (TTS) configuration options.
123167
124- This method allows updating the TTS settings, including model type, voice_id, lang_code,
125- encoding, speed and sample_rate. If any parameter is not provided, the existing value will be
126- retained.
168+ This allows updating the TTS settings, including lang_code, voice_id, and speed.
169+ If any parameter is not provided, the existing value will be retained.
127170
128171 Args:
129- model (TTSModels | str, optional): The Neuphonic model to use .
172+ lang_code (TTSLangCodes | str, optional): The language code for synthesis .
130173 voice_id (str, optional): The voice ID for the desired voice.
131- lang_code (TTSLanguages | str, optional): The language code for synthesis..
132- encoding (TTSEncodings | str, optional): The audio encoding format.
133174 speed (float, optional): The audio playback speed.
134- sample_rate (int, optional): The audio sample rate in Hz.
135- """ # noqa: E501
136- if is_given (voice_id ):
137- self ._opts .voice_id = voice_id
175+ """
138176 if is_given (lang_code ):
139177 self ._opts .lang_code = lang_code
178+ if is_given (voice_id ):
179+ self ._opts .voice_id = voice_id
140180 if is_given (speed ):
141181 self ._opts .speed = speed
142- if is_given (sample_rate ):
143- self ._opts .sample_rate = sample_rate
144182
145183 def synthesize (
146- self , text : str , * , conn_options : APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
184+ self ,
185+ text : str ,
186+ * ,
187+ conn_options : APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS ,
147188 ) -> ChunkedStream :
148189 return ChunkedStream (tts = self , input_text = text , conn_options = conn_options )
149190
191+ def stream (
192+ self , * , conn_options : APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
193+ ) -> SynthesizeStream :
194+ stream = SynthesizeStream (tts = self , conn_options = conn_options )
195+ self ._streams .add (stream )
196+ return stream
197+
198+ async def aclose (self ) -> None :
199+ for stream in list (self ._streams ):
200+ await stream .aclose ()
201+
202+ self ._streams .clear ()
203+ await self ._pool .aclose ()
204+
150205
151206class ChunkedStream (tts .ChunkedStream ):
152207 """Synthesize chunked text using the SSE endpoint"""
@@ -165,8 +220,8 @@ def __init__(
165220 async def _run (self , output_emitter : tts .AudioEmitter ) -> None :
166221 try :
167222 async with self ._tts ._ensure_session ().post (
168- f"https:// { self ._opts .base_url } /sse/speak/{ self ._opts .lang_code } " ,
169- headers = {AUTHORIZATION_HEADER : self ._opts .api_key },
223+ f"{ self ._opts .base_url } /sse/speak/{ self ._opts .lang_code } " ,
224+ headers = {API_AUTH_HEADER : self ._opts .api_key },
170225 json = {
171226 "text" : self ._input_text ,
172227 "voice_id" : self ._opts .voice_id ,
@@ -235,7 +290,130 @@ def _parse_sse_message(message: str) -> dict | None:
235290
236291 if message_dict .get ("errors" ) is not None :
237292 raise Exception (
238- f"received error status { message_dict ['status_code' ]} : { message_dict ['errors' ]} "
293+ f"received error status { message_dict ['status_code' ]} :{ message_dict ['errors' ]} "
239294 )
240295
241296 return message_dict
297+
298+
299+ class SynthesizeStream (tts .SynthesizeStream ):
300+ def __init__ (self , * , tts : TTS , conn_options : APIConnectOptions ):
301+ super ().__init__ (tts = tts , conn_options = conn_options )
302+ self ._tts : TTS = tts
303+ self ._opts = replace (tts ._opts )
304+ self ._segments_ch = utils .aio .Chan [tokenize .WordStream ]()
305+
306+ async def _run (self , output_emitter : tts .AudioEmitter ) -> None :
307+ request_id = utils .shortuuid ()
308+ output_emitter .initialize (
309+ request_id = request_id ,
310+ sample_rate = self ._opts .sample_rate ,
311+ num_channels = 1 ,
312+ mime_type = "audio/pcm" ,
313+ stream = True ,
314+ )
315+
316+ async def _tokenize_input () -> None :
317+ word_stream = None
318+ async for input in self ._input_ch :
319+ if isinstance (input , str ):
320+ if word_stream is None :
321+ word_stream = self ._opts .word_tokenizer .stream ()
322+ self ._segments_ch .send_nowait (word_stream )
323+ word_stream .push_text (input )
324+ elif isinstance (input , self ._FlushSentinel ):
325+ if word_stream :
326+ word_stream .end_input ()
327+ word_stream = None
328+
329+ self ._segments_ch .close ()
330+
331+ async def _run_segments () -> None :
332+ async for word_stream in self ._segments_ch :
333+ await self ._run_ws (word_stream , output_emitter )
334+
335+ tasks = [
336+ asyncio .create_task (_tokenize_input ()),
337+ asyncio .create_task (_run_segments ()),
338+ ]
339+ try :
340+ await asyncio .gather (* tasks )
341+ except asyncio .TimeoutError :
342+ raise APITimeoutError () from None
343+ except aiohttp .ClientResponseError as e :
344+ raise APIStatusError (
345+ message = e .message ,
346+ status_code = e .status ,
347+ request_id = request_id ,
348+ body = None ,
349+ ) from None
350+ except Exception as e :
351+ raise APIConnectionError () from e
352+ finally :
353+ await utils .aio .gracefully_cancel (* tasks )
354+
355+ async def _run_ws (
356+ self , word_stream : tokenize .WordStream , output_emitter : tts .AudioEmitter
357+ ) -> None :
358+ segment_id = utils .shortuuid ()
359+ output_emitter .start_segment (segment_id = segment_id )
360+
361+ async def send_task (ws : aiohttp .ClientWebSocketResponse ) -> None :
362+ async for word in word_stream :
363+ text_msg = {"text" : f"{ word .token } " }
364+ self ._mark_started ()
365+ await ws .send_str (json .dumps (text_msg ))
366+
367+ stop_msg = {"text" : "<STOP>" }
368+ await ws .send_str (json .dumps (stop_msg ))
369+
370+ async def recv_task (ws : aiohttp .ClientWebSocketResponse ) -> None :
371+ while True :
372+ msg = await ws .receive ()
373+
374+ if msg .type in (
375+ aiohttp .WSMsgType .CLOSE ,
376+ aiohttp .WSMsgType .CLOSED ,
377+ aiohttp .WSMsgType .CLOSING ,
378+ ):
379+ raise APIStatusError ("NeuPhonic websocket connection closed unexpectedly" )
380+
381+ if msg .type == aiohttp .WSMsgType .TEXT :
382+ try :
383+ resp = json .loads (msg .data )
384+ except json .JSONDecodeError :
385+ logger .warning ("Invalid JSON from NeuPhonic" )
386+ continue
387+
388+ if resp .get ("type" ) == "error" :
389+ raise APIError (f"NeuPhonic returned error: { resp } " )
390+
391+ data = resp .get ("data" , {})
392+ audio_data = data .get ("audio" )
393+ if audio_data and audio_data != "" :
394+ try :
395+ b64data = base64 .b64decode (audio_data )
396+ if b64data :
397+ output_emitter .push (b64data )
398+ except Exception as e :
399+ logger .warning ("Failed to decode NeuPhonic audio data: %s" , e )
400+
401+ if data .get ("stop" ):
402+ output_emitter .end_segment ()
403+ break
404+
405+ elif msg .type == aiohttp .WSMsgType .BINARY :
406+ pass
407+ else :
408+ logger .warning ("Unexpected NeuPhonic message type: %s" , msg .type )
409+
410+ async with self ._tts ._pool .connection (timeout = self ._conn_options .timeout ) as ws :
411+ tasks = [
412+ asyncio .create_task (send_task (ws )),
413+ asyncio .create_task (recv_task (ws )),
414+ ]
415+
416+ try :
417+ await asyncio .gather (* tasks )
418+ finally :
419+ await utils .aio .gracefully_cancel (* tasks )
0 commit comments