Skip to content

Commit e7bcb7c

Browse files
committed
add unit test for textchat & voicechat, update api files
Signed-off-by: LetongHan <[email protected]>
1 parent 6baa80e commit e7bcb7c

File tree

9 files changed

+130
-25
lines changed

9 files changed

+130
-25
lines changed

neural_chat/server/restful/finetune_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from neural_chat.cli.log import logger
2121
from neural_chat.server.restful.request import FinetuneRequest
2222
from neural_chat.server.restful.response import FinetuneResponse
23+
from neural_chat.config import NeuralChatConfig
24+
from neural_chat.chatbot import build_chatbot
2325

2426

2527
def check_finetune_params(request: FinetuneRequest) -> Optional[str]:
@@ -40,7 +42,7 @@ def set_chatbot(self, chatbot) -> None:
4042

4143
def get_chatbot(self):
4244
if self.chatbot is None:
43-
raise RuntimeError("Finetunebot instance has not been set.")
45+
raise RuntimeError("Chatbot instance has not been set.")
4446
return self.chatbot
4547

4648
def handle_finetune_request(self, request: FinetuneRequest) -> FinetuneResponse:
@@ -50,6 +52,9 @@ def handle_finetune_request(self, request: FinetuneRequest) -> FinetuneResponse:
5052

5153

5254
router = FinetuneAPIRouter()
55+
config = NeuralChatConfig()
56+
bot = build_chatbot(config)
57+
router.set_chatbot(bot)
5358

5459

5560
@router.post("/v1/finetune")

neural_chat/server/restful/request.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ class Text2ImageRequest(RequestBaseModel):
4444
sd_inference_token: Optional[str] = None
4545

4646

47-
class VoiceRequest(RequestBaseModel):
48-
voice: str
49-
50-
5147
class TextRequest(RequestBaseModel):
5248
text: str
5349

neural_chat/server/restful/response.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,6 @@ class ImageResponse(ResponseBaseModel):
3333
image: bytes
3434

3535

36-
class TextResponse(ResponseBaseModel):
37-
content: str
38-
39-
40-
class VoiceResponse(ResponseBaseModel):
41-
voice: str
42-
43-
4436
class RetrievalResponse(ResponseBaseModel):
4537
content: str
4638

neural_chat/server/restful/retrieval_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from neural_chat.cli.log import logger
2222
from neural_chat.server.restful.request import RetrievalRequest
2323
from neural_chat.server.restful.response import RetrievalResponse
24+
from neural_chat.config import NeuralChatConfig
25+
from neural_chat.chatbot import build_chatbot
2426

2527

2628
def check_retrieval_params(request: RetrievalRequest) -> Optional[str]:
@@ -52,6 +54,9 @@ def handle_retrieval_request(self, request: RetrievalRequest) -> RetrievalRespon
5254

5355

5456
router = RetrievalAPIRouter()
57+
config = NeuralChatConfig()
58+
bot = build_chatbot(config)
59+
router.set_chatbot(bot)
5560

5661

5762
@router.post("/v1/retrieval")

neural_chat/server/restful/textchat_api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
ChatCompletionRequest, ChatCompletionResponseChoice, ChatCompletionResponse,
2727
UsageInfo, ModelCard, ModelList, ModelPermission, ChatMessage
2828
)
29+
from neural_chat.config import NeuralChatConfig
30+
from neural_chat.chatbot import build_chatbot
2931

3032

3133
# TODO: process request and return params in Dict
@@ -172,18 +174,21 @@ async def handle_chat_completion_request(self, request: ChatCompletionRequest) -
172174
for usage_key, usage_value in task_usage.dict().items():
173175
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
174176

175-
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
177+
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
176178

177179

178180
router = TextChatAPIRouter()
181+
config = NeuralChatConfig()
182+
bot = build_chatbot(config)
183+
router.set_chatbot(bot)
179184

180185

181186
@router.post("/v1/models")
182187
async def models_endpoint() -> ModelList:
183188
return await router.handle_models_request()
184189

185190

186-
@router.post("/v1/completion")
191+
@router.post("/v1/completions")
187192
async def completion_endpoint(request: CompletionRequest) -> CompletionResponse:
188193
ret = check_completion_request()
189194
if ret is not None:

neural_chat/server/restful/voicechat_api.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from typing import Union
18+
from typing import ByteString
1919
from fastapi import APIRouter
2020
from neural_chat.cli.log import logger
21-
from neural_chat.server.restful.request import VoiceRequest, TextRequest
22-
from neural_chat.server.restful.response import VoiceResponse, TextResponse
21+
from neural_chat.config import NeuralChatConfig
22+
from neural_chat.chatbot import build_chatbot
2323

2424

2525
class VoiceChatAPIRouter(APIRouter):
@@ -36,28 +36,33 @@ def get_chatbot(self):
3636
raise RuntimeError("Chatbot instance has not been set.")
3737
return self.chatbot
3838

39-
async def handle_voice2text_request(self, request: VoiceRequest) -> TextResponse:
39+
async def handle_voice2text_request(self, request: ByteString) -> str:
4040
# TODO: implement voice to text
4141
chatbot = self.get_chatbot()
4242
# TODO: chatbot.voice2text()
43-
return TextResponse(content=None)
43+
result = chatbot.predict(request.voice)
44+
return result
4445

45-
async def handle_text2voice_request(self, request: TextRequest) -> VoiceResponse:
46+
async def handle_text2voice_request(self, text: str) -> ByteString:
4647
# TODO: implement text to voice
4748
chatbot = self.get_chatbot()
4849
# TODO: chatbot.text2voice()
49-
return VoiceResponse(content=None)
50+
result = chatbot.predict(text)
51+
return result
5052

5153

5254
router = VoiceChatAPIRouter()
55+
config = NeuralChatConfig()
56+
bot = build_chatbot(config)
57+
router.set_chatbot(bot)
5358

5459
# voice to text
5560
@router.post("/v1/voice/asr")
56-
async def voice2text(requst: VoiceRequest) -> TextResponse:
57-
return await router.handle_voice2text_request(requst)
61+
async def voice2text(request: ByteString) -> str:
62+
return await router.handle_voice2text_request(request)
5863

5964

6065
# text to voice
6166
@router.post("/v1/voice/tts")
62-
async def voice2text(requst: TextRequest) -> VoiceResponse:
67+
async def voice2text(requst: str) -> ByteString:
6368
return await router.handle_text2voice_request(requst)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
3+
# Get the host and port from the environment variables
4+
host = os.environ.get('MY_HOST')
5+
port = os.environ.get('MY_PORT')
6+
7+
# Check if the environment variables are set and not empty
8+
if host and port:
9+
# Combine the host and port to form the full URL
10+
HOST = f"http://{host}:{port}"
11+
API_COMPLETION = '/v1/completions'
12+
API_CHAT_COMPLETION = '/v1/chat/completions'
13+
API_ASR = '/v1/voice/asr'
14+
API_TTS = '/v1/voice/tts'
15+
16+
print("HOST URL:", HOST)
17+
print("Completions Endpoint:", API_COMPLETION)
18+
print("Chat completions Endpoint:", API_CHAT_COMPLETION)
19+
print("Voice ASR Endpoint:", API_ASR)
20+
print("Voice TTS Endpoint:", API_TTS)
21+
else:
22+
raise("Please set the environment variables MY_HOST and MY_PORT.")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env python
2+
import requests
3+
import unittest
4+
from neural_chat.tests.restful.config import HOST, API_COMPLETION, API_CHAT_COMPLETION
5+
from neural_chat.server.restful.openai_protocol import CompletionRequest, ChatCompletionRequest
6+
from neural_chat.cli.log import logger
7+
8+
9+
class UnitTest(unittest.TestCase):
10+
11+
def __init__(self, *args):
12+
super(UnitTest, self).__init__(*args)
13+
self.host = HOST
14+
15+
def test_completions(self):
16+
logger.info(f'Testing POST request: {self.host+API_COMPLETION}')
17+
request = CompletionRequest(
18+
model="mpt-7b-chat",
19+
prompt="This is a test."
20+
)
21+
response = requests.post(self.host+API_COMPLETION, data=request)
22+
logger.info('Response status code: {}'.format(response.status_code))
23+
logger.info('Response text: {}'.format(response.choices.text))
24+
self.assertEqual(response.status_code, 200, msg="Abnormal response status code.")
25+
26+
def test_chat_completions(self):
27+
logger.info(f'Testing POST request: {self.host+API_CHAT_COMPLETION}')
28+
request = ChatCompletionRequest(
29+
model="mpt-7b-chat",
30+
messages=[
31+
{"role": "system","content": "You are a helpful assistant."},
32+
{"role": "user","content": "Hello!"}
33+
]
34+
)
35+
response = requests.post(self.host+API_CHAT_COMPLETION, data=request)
36+
logger.info('Response status code: {}'.format(response.status_code))
37+
logger.info('Response text: {}'.format(response.choices.message.content))
38+
self.assertEqual(response.status_code, 200, msg="Abnormal response status code.")
39+
40+
41+
if __name__ == "__main__":
42+
unittest.main()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/usr/bin/env python
2+
import requests
3+
import unittest
4+
from neural_chat.tests.restful.config import HOST, API_ASR, API_TTS
5+
from neural_chat.cli.log import logger
6+
from neural_chat.pipeline.plugins.audio.asr import AudioSpeechRecognition
7+
8+
9+
class UnitTest(unittest.TestCase):
10+
11+
def __init__(self, *args):
12+
super(UnitTest, self).__init__(*args)
13+
self.host = HOST
14+
15+
def test_asr(self):
16+
logger.info(f'Testing POST request: {self.host+API_ASR}')
17+
request = ""
18+
response = requests.post(self.host+API_ASR, json=request)
19+
logger.info('Response status code: {}'.format(response.status_code))
20+
logger.info('Response text: {}'.format(response.text))
21+
self.assertEqual(response.status_code, 200, msg="Abnormal response status code.")
22+
23+
def test_tts(self):
24+
logger.info(f'Testing POST request: {self.host+API_TTS}')
25+
request = "Hello, nice to meet you!"
26+
response = requests.post(self.host+API_TTS, data=request)
27+
logger.info('Response status code: {}'.format(response.status_code))
28+
logger.info('Response voice: {}'.format(response.voice))
29+
self.assertEqual(response.status_code, 200, msg="Abnormal response status code.")
30+
31+
32+
if __name__ == "__main__":
33+
unittest.main()

0 commit comments

Comments
 (0)