from __future__ import annotations

import inspect
import typing as t

import torch
import gradio as gr
import visual_chatgpt as vc
from langchain.llms.openai import OpenAI
from langchain.agents.tools import Tool
from langchain.chains.conversation.memory import ConversationBufferMemory

import bentoml
from bentoml.io import JSON


# In local mode, the ChatBot passes images to models by file path. In
# distributed mode, the ChatBot needs to send the contents of image files
# over the network to the models/runners.

def path_to_tuple(path: str) -> tuple[str, bytes]:
    with open(path, "rb") as f:
        bs = f.read()
    return (path, bs)


def tuple_to_path(data: tuple[str, bytes]) -> str:
    path, bs = data
    with open(path, "wb") as f:
        f.write(bs)
    return path


def path_and_text_to_tuple(path_and_text: str) -> tuple[str, bytes, str]:
    path, _, text = path_and_text.partition(",")
    img_tuple = path_to_tuple(path)
    return img_tuple + (text,)


def tuple_to_path_and_text(data: tuple[str, bytes, str]) -> str:
    path, bs, text = data
    path = tuple_to_path((path, bs))
    return ",".join([path, text])
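

# A rough sketch of how these helpers are meant to round-trip an image between
# the API process and a runner process (the file name below is only an
# illustrative placeholder, not a path used by the project):
#
#   original = path_to_tuple("image/example.png")   # -> ("image/example.png", b"...")
#   restored = tuple_to_path(original)               # writes the bytes back, returns the path
#   assert restored == "image/example.png"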


TOOL_DIST_PROCESSORS = {
    # image input, text out
    "ImageCaptioning.inference": {
        "runner_out": lambda captions: captions,
        "api_in": lambda captions: captions,
    },

    # text input, image out
    "Text2Image.inference": {
        "api_out": lambda text: text,
        "runner_in": lambda text: text,
    },

    # image and text input, image out
    "InstructPix2Pix.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },
    "PoseText2Image.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },
    "SegText2Image.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },
    "DepthText2Image.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },
    "NormalText2Image.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },
    "Text2Box.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
    },

    # image and text input, text out
    "VisualQuestionAnswering.inference": {
        "api_out": path_and_text_to_tuple,
        "runner_in": tuple_to_path_and_text,
        "runner_out": lambda text: text,
        "api_in": lambda text: text,
    },
}
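
# The four hooks above are applied around each tool's inference call when
# running distributed (in local mode no conversion is applied at all):
#   - "api_out":    transforms the tool input on the API/proxy side before it
#                   is sent to the runner (defaults to path_to_tuple)
#   - "runner_in":  transforms the received input on the runner side before the
#                   model is called (defaults to tuple_to_path)
#   - "runner_out": transforms the model output on the runner side before it is
#                   sent back (defaults to path_to_tuple)
#   - "api_in":     transforms the received output on the API side before it is
#                   handed back to the agent (defaults to tuple_to_path)
# The defaults come from make_tool_runnable_method and make_tool_proxy_method
# below, so tools not listed here exchange images as (path, bytes) tuples in
# both directions.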


class BaseToolRunnable(bentoml.Runnable):
    pass


# A proxy base class that wraps a runner and adapts model method calls into
# runner method calls.
class BaseToolProxy:
    TOOL_NAME: str
    RUNNABLE_CLASS: type[BaseToolRunnable]


def make_tool_runnable_method(
    method_name: str,
    processors: dict[str, t.Callable[[t.Any], t.Any]] | None = None,
) -> t.Callable[[BaseToolRunnable, t.Any], t.Any]:

    if processors is None:

        def _run(self: BaseToolRunnable, inputs: t.Any):
            method = getattr(self.model, method_name)
            return method(inputs)

        return _run

    preprocessor = processors.get("runner_in", tuple_to_path)
    postprocessor = processors.get("runner_out", path_to_tuple)

    def _run(self: BaseToolRunnable, inputs: t.Any) -> t.Any:
        method = getattr(self.model, method_name)
        processed_inputs = preprocessor(inputs)
        output = method(processed_inputs)
        processed_output = postprocessor(output)
        return processed_output

    return _run


def make_tool_proxy_method(
    method_name: str,
    processors: dict[str, t.Callable[[t.Any], t.Any]] | None = None,
) -> t.Callable[[BaseToolProxy, t.Any], t.Any]:

    if processors is None:

        def _run(self: BaseToolProxy, inputs: t.Any):
            runner_method = getattr(self.runner, method_name)
            return runner_method.run(inputs)

        return _run

    # on the API side the processing order is reversed
    preprocessor = processors.get("api_out", path_to_tuple)
    postprocessor = processors.get("api_in", tuple_to_path)

    def _run(self: BaseToolProxy, inputs: t.Any) -> t.Any:
        runner_method = getattr(self.runner, method_name)
        processed_inputs = preprocessor(inputs)
        output = runner_method.run(processed_inputs)
        processed_output = postprocessor(output)
        return processed_output

    return _run


def create_proxy_class(tool_class: type[object], local: bool = False, gpu: bool = False) -> type[BaseToolProxy]:
    class ToolRunnable(BaseToolRunnable):

        SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu") if gpu else ("cpu",)
        SUPPORTS_CPU_MULTI_THREADING = True

        def __init__(self):
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = tool_class(self.device)

    class ToolProxy(BaseToolProxy):

        TOOL_NAME = tool_class.__name__
        RUNNABLE_CLASS: type[BaseToolRunnable] = ToolRunnable

        def __init__(self, runner_name: str | None = None):
            if not runner_name:
                runner_name = f"{tool_class.__name__}_runner".lower()
            self.runner = bentoml.Runner(self.RUNNABLE_CLASS, name=runner_name)

    # add methods to the runnable and proxy the model's method calls to the
    # corresponding runner methods
    for e in dir(tool_class):
        if e.startswith("inference"):

            method = getattr(tool_class, e)

            if local:
                processors = None
            else:
                full_name = f"{tool_class.__name__}.{e}"
                processors = TOOL_DIST_PROCESSORS.get(full_name, dict())

            ToolRunnable.add_method(
                make_tool_runnable_method(e, processors=processors),
                name=e,
                batchable=False,
            )

            model_method = make_tool_proxy_method(e, processors=processors)
            model_method.name = method.name
            model_method.description = method.description
            setattr(ToolProxy, e, model_method)

    return ToolProxy
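
# A hedged usage sketch: ImageCaptioning is one of the tool classes
# visual_chatgpt exposes (the bot below requires it); the variable names are
# illustrative only.
#
#   CaptioningProxy = create_proxy_class(vc.ImageCaptioning, local=False, gpu=True)
#   captioner = CaptioningProxy()   # has a .runner plus the wrapped inference methods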


# Helper to convert an environment variable or CLI argument string into a
# load_dict mapping tool class names to devices.
def parse_load_dict(s: str) -> dict[str, str]:
    return {
        e.split('_')[0].strip(): e.split('_')[1].strip()
        for e in s.split(',')
    }
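
# For example:
#   parse_load_dict("ImageCaptioning_cuda:0,Text2Image_cpu")
#   # -> {"ImageCaptioning": "cuda:0", "Text2Image": "cpu"}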


class BentoMLConversationBot(vc.ConversationBot):
    def __init__(self, load_dict: dict[str, str], local: bool = False):

        print(f"Initializing VisualChatGPT, load_dict={load_dict}")

        if 'ImageCaptioning' not in load_dict:
            raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")

        self.models = {}
        # Load Basic Foundation Models
        for class_name, resource in load_dict.items():
            gpu = resource.startswith("cuda")
            tool_class = getattr(vc, class_name)
            proxy_class = create_proxy_class(tool_class, local=local, gpu=gpu)
            self.models[proxy_class.TOOL_NAME] = proxy_class()

        # Load Template Foundation Models
        # for class_name, module in vc.__dict__.items():
        #     if getattr(module, 'template_model', False):
        #         template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
        #         loaded_names = set([type(e).TOOL_NAME for e in self.models.values()
        #                             if not e.template_model])
        #         if template_required_names.issubset(loaded_names):
        #             template_class = getattr(vc, class_name)
        #             self.models[class_name] = template_class(
        #                 **{name: self.models[name] for name in template_required_names})

        print(f"All the Available Functions: {self.models}")

        self.tools = []
        for instance in self.models.values():
            for e in dir(instance):
                if e.startswith("inference"):
                    func = getattr(instance, e)
                    self.tools.append(
                        Tool(name=func.name, description=func.description, func=func)
                    )
        self.llm = OpenAI(temperature=0)
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", output_key="output"
        )


def create_gradio_blocks(bot):
    with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
        lang = gr.Radio(choices=["Chinese", "English"], value=None, label="Language")
        chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
        state = gr.State([])
        with gr.Row(visible=False) as input_raws:
            with gr.Column(scale=0.7):
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or upload an image",
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0):
                clear = gr.Button("Clear")
            with gr.Column(scale=0.15, min_width=0):
                btn = gr.UploadButton(label="🖼️", file_types=["image"])

        lang.change(bot.init_agent, [lang], [input_raws, lang, txt, clear])
        txt.submit(bot.run_text, [txt, state], [chatbot, state])
        txt.submit(lambda: "", None, txt)
        btn.upload(bot.run_image, [btn, state, txt, lang], [chatbot, state, txt])
        clear.click(bot.memory.clear)
        clear.click(lambda: [], None, chatbot)
        clear.click(lambda: [], None, state)
    return demo


def create_bentoml_service(bot, name="bentoml-visual-chatgpt", gradio_blocks=None):
    runners = [model.runner for model in bot.models.values()]
    svc = bentoml.Service(name, runners=runners)

    # Dummy API endpoint
    @svc.api(input=JSON(), output=JSON())
    def echo(d):
        return d

    if gradio_blocks:
        svc.mount_asgi_app(gradio_blocks.app, path="/ui")

    return svc
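

# A minimal wiring sketch (hedged; defining the service at module level and the
# "VISUAL_CHATGPT_LOAD" environment variable name are assumptions, not part of
# this module):
#
#   import os
#
#   load_dict = parse_load_dict(os.environ.get("VISUAL_CHATGPT_LOAD", "ImageCaptioning_cuda:0"))
#   bot = BentoMLConversationBot(load_dict, local=False)
#   demo = create_gradio_blocks(bot)
#   svc = create_bentoml_service(bot, gradio_blocks=demo)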