Commit c81cdc6

experimental: add multi-modal end to end RAG example (NVIDIA#41)
1 parent 3d29acf commit c81cdc6

23 files changed: 1671 additions & 0 deletions
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
*.pyc
__pycache__
vectorstore/image_references
vectorstore/table_references
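The ignored vectorstore/image_references and vectorstore/table_references directories appear to be where the ingestion pipeline writes extracted image and table artifacts (the app below reads image and dataframe paths from document metadata), so they are kept out of version control.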
Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import os
import base64
import datetime
import argparse
import pandas as pd
from PIL import Image
from io import BytesIO

import streamlit as st
import streamlit_analytics
from streamlit_feedback import streamlit_feedback

from bot_config.utils import get_config
from utils.memory import init_memory, get_summary, add_history_to_memory
from guardrails.fact_check import fact_check
from llm.llm_client import LLMClient
from retriever.embedder import NVIDIAEmbedders, HuggingFaceEmbeders
from retriever.vector import MilvusVectorClient, QdrantClient
from retriever.retriever import Retriever
from utils.feedback import feedback_kwargs

from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.messages import HumanMessage

llm_client = LLMClient("mixtral_8x7b")

# Start the analytics service (using browser.usageStats)
streamlit_analytics.start_tracking()

# Get the config from the command line, or set a default
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', help="Provide a chatbot config to run the deployment")

st.set_page_config(
    page_title="Multimodal RAG Assistant",
    page_icon=":speech_balloon:",
    layout="wide",
)

@st.cache_data()
def load_config(cfg_arg):
    try:
        config = get_config(os.path.join("bot_config", cfg_arg + ".config"))
        return config
    except Exception as e:
        print("Error loading config:", e)
        return None

args = vars(parser.parse_args())
cfg_arg = args["config"]

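# Note: load_config returns None on failure, so the fallback below also
# re-loads the default "multimodal" config when an explicitly requested
# -c config could not be loaded.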
# Initialize session state variables if not already present

if 'prompt_value' not in st.session_state:
    st.session_state['prompt_value'] = None

if cfg_arg and "config" not in st.session_state:
    st.session_state.config = load_config(cfg_arg)

if "config" not in st.session_state or st.session_state.config is None:
    st.session_state.config = load_config("multimodal")
    print(st.session_state.config)

if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question!"}
    ]
if "sources" not in st.session_state:
    st.session_state.sources = []

if "image_query" not in st.session_state:
    st.session_state.image_query = ""

if "queried" not in st.session_state:
    st.session_state.queried = False

if "memory" not in st.session_state:
    st.session_state.memory = init_memory(llm_client.llm, st.session_state.config['summary_prompt'])
memory = st.session_state.memory

with st.sidebar:
    prev_cfg = st.session_state.config
    try:
        defaultidx = ["multimodal"].index(st.session_state.config["name"].lower())
    except ValueError:
        defaultidx = 0
    st.header("Bot Configuration")
    cfg_name = st.selectbox("Select a configuration/type of bot.", ["multimodal"], index=defaultidx)
    st.session_state.config = get_config(os.path.join("bot_config", cfg_name + ".config"))
    config = st.session_state.config
    if st.session_state.config != prev_cfg:
        st.experimental_rerun()

    st.success("Select an experience above.")

    st.header("Image Input Query")

    uploaded_file = st.file_uploader("Upload an image (JPG/JPEG/PNG) along with a text input:", accept_multiple_files=False)

    if uploaded_file and st.session_state.image_query == "":
        st.success("Image loaded for multimodal RAG Q&A.")
        st.session_state.image_query = os.path.join("/tmp/", uploaded_file.name)
        with open(st.session_state.image_query, "wb") as f:
            f.write(uploaded_file.read())

        with st.spinner("Getting image description using NeVA"):
            neva = LLMClient("neva_22b")
            image = Image.open(st.session_state.image_query).convert("RGB")
            buffered = BytesIO()
            image.save(buffered, format="JPEG", quality=20)  # quality=20 is a workaround (WAR)
            b64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
            res = neva.multimodal_invoke(b64_string, creativity=0, quality=9, complexity=0, verbosity=9)
            st.session_state.image_query = res.content

    if not uploaded_file:
        st.session_state.image_query = ""

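# At this point st.session_state.image_query holds NeVA's text description of
# the uploaded image rather than the image itself, so the retrieval and chat
# steps below operate on text only.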
# Page title
st.header(config["page_title"])
st.markdown(config["instructions"])

# Init the vector client
if "vector_client" not in st.session_state or st.session_state.vector_client.collection_name != config["core_docs_directory_name"]:
    try:
        st.session_state.vector_client = MilvusVectorClient(hostname="localhost", port="19530", collection_name=config["core_docs_directory_name"])
    except Exception as e:
        st.write(f"Failed to connect to Milvus vector DB, exception: {e}. Please follow the steps to initialize the vector DB, or upload documents to the knowledge base and add them to the vector DB.")
        st.stop()

# Init the embedder
if "query_embedder" not in st.session_state:
    st.session_state.query_embedder = NVIDIAEmbedders(name="nvolveqa_40k", type="query")

# Init the retriever
if "retriever" not in st.session_state:
    st.session_state.retriever = Retriever(embedder=st.session_state.query_embedder, vector_client=st.session_state.vector_client)
retriever = st.session_state.retriever

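# The vector client, embedder, and retriever are created once and cached in
# st.session_state so they persist across Streamlit's script reruns;
# 19530 is the default Milvus gRPC port.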
messages = st.session_state.messages

for n, msg in enumerate(messages):
    st.chat_message(msg["role"]).write(msg["content"])
    if msg["role"] == "assistant" and n > 1:
        with st.chat_message("assistant"):
            ctr = 0
            for key in st.session_state.sources.keys():
                ctr += 1
                with st.expander(os.path.basename(key)):
                    source = st.session_state.sources[key]
                    if "source" in source["doc_metadata"]:
                        source_str = source["doc_metadata"]["source"]
                        if "page" in source_str and "block" in source_str:
                            download_path = source_str.split("page")[0].strip("-") + ".pdf"
                            file_name = os.path.basename(download_path)
                            try:
                                f = open(download_path, 'rb').read()
                                st.download_button("Download now", f, key=download_path + str(n) + str(ctr), file_name=file_name)
                            except Exception:
                                st.write("Failed to provide a download for this file: ", file_name)
                        elif "ppt" in source_str:
                            ppt_path = os.path.basename(source_str).replace('.pptx', '.pdf').replace('.ppt', '.pdf')
                            download_path = os.path.join("vectorstore/ppt_references", ppt_path)
                            file_name = os.path.basename(download_path)
                            f = open(download_path, "rb").read()
                            st.download_button("Download now", f, key=download_path + str(n) + str(ctr), file_name=file_name)
                        else:
                            download_path = source["doc_metadata"]["image"]
                            file_name = os.path.basename(download_path)
                            try:
                                f = open(download_path, 'rb').read()
                                st.download_button("Download now", f, key=download_path + str(n) + str(ctr), file_name=file_name)
                            except Exception as e:
                                print("Failed to provide a download for", file_name)
                                print(f"Exception: {e}")
                        if "type" in source["doc_metadata"]:
                            if source["doc_metadata"]["type"] == "table":
                                # Load the saved pandas table and show it in Streamlit
                                df = pd.read_excel(source["doc_metadata"]["dataframe"])
                                st.write(df)
                                image = Image.open(source["doc_metadata"]["image"])
                                st.image(image, caption=os.path.basename(source["doc_metadata"]["source"]))
                            elif source["doc_metadata"]["type"] == "image":
                                image = Image.open(source["doc_metadata"]["image"])
                                st.image(image, caption=os.path.basename(source["doc_metadata"]["source"]))
                            else:
                                st.write(source["doc_content"])
                        else:
                            st.write(source["doc_content"])

            feedback_key = f"feedback_{int(n/2)}"

            if feedback_key not in st.session_state:
                st.session_state[feedback_key] = None
            col1, col2 = st.columns(2)
            with col1:
                st.write("**Please provide feedback by clicking one of these icons:**")
            with col2:
                streamlit_feedback(**feedback_kwargs, args=[messages[-2]["content"].strip(), messages[-1]["content"].strip()], key=feedback_key, align="flex-start")

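# NOTE: each user/assistant exchange adds two entries to `messages`, so the
# feedback widget above is keyed by the exchange index int(n/2) to keep it
# stable across reruns.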
# Check if the topic has changed
if st.session_state['prompt_value'] is None:
    prompt_value = "Hi, what can you help me with?"
    st.session_state["prompt_value"] = prompt_value

colx, coly = st.columns([1, 20])

placeholder = st.empty()
with placeholder:
    with st.form("chat-form", clear_on_submit=True):
        instr = 'Hi there! Enter what you want to let me know here.'
        col1, col2 = st.columns([20, 2])
        with col1:
            prompt_value = st.session_state["prompt_value"]
            prompt = st.text_input(
                instr,
                value=prompt_value,
                placeholder=instr,
                label_visibility='collapsed'
            )
        with col2:
            submitted = st.form_submit_button("Chat")
        if submitted and len(prompt) > 0:
            placeholder.empty()
            st.session_state['prompt_value'] = None

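# Main turn handler: retrieve context for the (possibly image-augmented)
# query, stream the LLM answer, then append a fact-check verdict from the
# guardrails module before saving the exchange to memory.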
if len(prompt) > 0 and submitted:
    with st.chat_message("user"):
        st.write(prompt)

    if st.session_state.image_query:
        prompt = f"\nI have uploaded an image with the following description: {st.session_state.image_query}\nHere is the question: {prompt}"
    transformed_query = {"text": prompt}
    messages.append({"role": "user", "content": transformed_query["text"]})

    with st.spinner("Obtaining references from documents..."):
        BASE_DIR = os.path.abspath("vectorstore")
        CORE_DIR = os.path.join(BASE_DIR, config["core_docs_directory_name"])
        context, sources = retriever.get_relevant_docs(transformed_query["text"])
        st.session_state.sources = sources
        augmented_prompt = "Relevant documents:" + context + "\n\n[[QUESTION]]\n\n" + transformed_query["text"]  # + "\n" + config["footer"]
        system_prompt = config["header"]

    # Display the assistant response in a chat message container
    with st.chat_message("assistant"):
        response = llm_client.chat_with_prompt(system_prompt, augmented_prompt)
        message_placeholder = st.empty()
        full_response = ""
        for chunk in response:
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    add_history_to_memory(memory, transformed_query["text"], full_response)
    with st.spinner("Running fact checking/guardrails..."):
        full_response += "\n\nFact Check result: "
        res = fact_check(context, transformed_query["text"], full_response)
        for chunk in res:
            full_response += chunk
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    with st.chat_message("assistant"):
        messages.append(
            {"role": "assistant", "content": full_response}
        )
        st.write(full_response)
        st.experimental_rerun()

elif len(messages) > 1:
    summary_placeholder = st.empty()
    summary_button = summary_placeholder.button("Click to see summary")
    if summary_button:
        with st.chat_message("assistant"):
            summary_placeholder.empty()
            st.markdown(get_summary(memory))

streamlit_analytics.stop_tracking()
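For reference, a minimal way to launch this app (assuming the file above is saved as main.py, and that the Milvus server and NVIDIA AI endpoint credentials are already set up) would be:

    streamlit run main.py -- -c multimodal

The bare `--` separator makes Streamlit pass `-c multimodal` through to the script's own argparse parser instead of consuming it itself.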
