
Commit 7f25f34

add ut

Signed-off-by: Lv, Liang1 <[email protected]>
Parent: 23cc784

4 files changed (+81, -8 lines)


neural_chat/chatbot.py (18 additions, 5 deletions)

@@ -21,7 +21,7 @@
 from .config import FinetuningConfig
 from .pipeline.finetuning.finetuning import Finetuning
 from .config import DeviceOptions, BackendOptions, AudioOptions
-from models.base_model import get_model_adapter
+from .models.base_model import get_model_adapter
 from .utils.common import get_device_type, get_backend_type
 from .pipeline.plugins.caching.cache import init_similar_cache_from_config
 from .pipeline.plugins.audio.asr import AudioSpeechRecognition
@@ -35,8 +35,11 @@ def build_chatbot(config: NeuralChatConfig):
     """Build the chatbot with a given configuration.

     Args:
-        config (NeuralChatConfig): The class of NeuralChatConfig containing model path,
-            device, backend, plugin config etc.
+        config (NeuralChatConfig): Configuration for building the chatbot.
+
+    Returns:
+        adapter: The chatbot model adapter.
+
     Example:
         from neural_chat.config import NeuralChatConfig
         from neural_chat.chatbot import build_chatbot
@@ -45,7 +48,7 @@ def build_chatbot(config: NeuralChatConfig):
         response = chatbot.predict("Tell me about Intel Xeon Scalable Processors.")
     """
     # Validate input parameters
-    if config.device not in [option.name for option in DeviceOptions]:
+    if config.device not in [option.name.lower() for option in DeviceOptions]:
         valid_options = ", ".join([option.name.lower() for option in DeviceOptions])
         raise ValueError(f"Invalid device value '{config.device}'. Must be one of {valid_options}")

@@ -111,10 +114,20 @@ def build_chatbot(config: NeuralChatConfig):
     return adapter

 def finetune_model(config: FinetuningConfig):
+    """Finetune the model based on the provided configuration.
+
+    Args:
+        config (FinetuningConfig): Configuration for finetuning the model.
+    """
+
     assert config is not None, "FinetuningConfig is needed for finetuning."
     finetuning = Finetuning(config)
     finetuning.finetune()

 def optimize_model(config: OptimizationConfig):
-    # Implement the logic to optimize the model
+    """Optimize the model based on the provided configuration.
+
+    Args:
+        config (OptimizationConfig): Configuration for optimizing the model.
+    """
     pass
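
Two things change in chatbot.py: the model-adapter import becomes package-relative, and the device check lowercases the enum member names before comparing. The latter matters because Enum member names are conventionally uppercase while configs pass lowercase strings such as "cpu". A minimal sketch of the check, using a hypothetical stand-in for DeviceOptions (the real enum lives in neural_chat/config.py and its members are not shown in this diff):

from enum import Enum

# Hypothetical stand-in for neural_chat.config.DeviceOptions; the real
# members are defined in neural_chat/config.py and may differ.
class DeviceOptions(Enum):
    CPU = "cpu"
    HPU = "hpu"
    XPU = "xpu"

device = "cpu"  # lowercase, as set by the new unit test

# Old check: uppercase member names never match a lowercase device string.
print(device in [option.name for option in DeviceOptions])           # False
# New check: lowercase the member names first, so "cpu" passes validation.
print(device in [option.name.lower() for option in DeviceOptions])   # True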

neural_chat/requirements.txt (3 additions, 1 deletion)

@@ -1,5 +1,5 @@
 transformers>=4.31.0
-perft
+peft
 fastchat
 torch
 intel_extension_for_pytorch
@@ -8,3 +8,5 @@ speechbrain
 paddlepaddle
 paddlespeech
 shortuuid
+gptcache
+librosa>=0.10.0
New test file (52 additions, 0 deletions)

@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from neural_chat.chatbot import build_chatbot
+from neural_chat.config import NeuralChatConfig
+
+class TestChatbotBuilder(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_build_chatbot_valid_config(self):
+        config = NeuralChatConfig()
+        config.device = "cpu"
+        config.backend = "torch"
+        config.model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
+
+        chatbot = build_chatbot(config)
+        self.assertIsNotNone(chatbot)
+        response = chatbot.predict("Tell me about Intel Xeon Scalable Processors.")
+        print(response)
+
+    def test_build_chatbot_invalid_config(self):
+        # Similar to the previous test, but with an invalid configuration
+        pass
+
+    def test_build_chatbot_retrieval(self):
+        # Test the retrieval logic
+        pass
+
+    def test_build_chatbot_audio(self):
+        # Test the audio logic
+        pass
+
+# Add more tests for other components of the function
+
+if __name__ == '__main__':
+    unittest.main()
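
The invalid-config stub could exercise the validation path added in chatbot.py, which raises ValueError for an unrecognized device. A minimal sketch of how the stub might be filled in, assuming NeuralChatConfig accepts the same attributes used in the valid-config test and that "quantum" is not a supported device:

    def test_build_chatbot_invalid_config(self):
        # An unsupported device string should trip the ValueError raised
        # by the validation block at the top of build_chatbot().
        config = NeuralChatConfig()
        config.device = "quantum"  # hypothetical invalid device value
        with self.assertRaises(ValueError):
            build_chatbot(config)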

neural_chat/utils/common.py (8 additions, 2 deletions)

@@ -16,12 +16,18 @@
 # limitations under the License.

 import torch
-import habana_frameworks.torch.hpu as hthpu
+
+try:
+    import habana_frameworks.torch.hpu as hthpu
+    is_hpu_available = True
+except ImportError:
+    print("Package 'habana_frameworks.torch.hpu' is not installed.")
+    is_hpu_available = False

 def get_device_type():
     if torch.cuda.is_available():
         device = "cuda"
-    elif hthpu.is_available():
+    elif is_hpu_available:
         device = "hpu"
     elif torch.xpu.is_available():
         device = "xpu"
