Commit 75ccb42

Update FM streaming RAG demo to use new NIMs (NVIDIA#234)
* Update FM streaming RAG demo for DC AI Summit
  - Uses new embedding and reranking NIMs
  - Uses Parakeet ASR NIM
  - Includes GNU Radio container
  - Updates deployment scripts for deploying across multiple machines
  - Various bug fixes and QoL updates
* Add diagrams for README
* Add README
* By default, use build.nvidia.com endpoints for non-ASR NIMs
* Restart Riva gRPC service on exceptions
* Add GNU Radio folder
* Update README to include frontend URL
* Fix LLM model options
* Fix bug when not using local Embedding / Reranking models
1 parent 63bb57f commit 75ccb42


51 files changed: +914, -2131 lines changed
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+deploy/.keys
+milvus/volumes

community/fm-asr-streaming-rag/README.md

Lines changed: 85 additions & 84 deletions
Large diffs are not rendered by default.

community/fm-asr-streaming-rag/chain-server/accumulator.py

Lines changed: 15 additions & 0 deletions
@@ -13,12 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import datetime
+import requests
 from common import get_logger
 from database import TimestampDatabase
 from langchain.text_splitter import RecursiveCharacterTextSplitter

 logger = get_logger(__name__)

+FRONTEND_URI = os.environ.get('FRONTEND_URI', None)
+
 #todo: Multi-thread to handle multiple concurrent streams
 #todo: Add time-triggered embedding (i.e. embed after N seconds if no updates)
 class TextAccumulator:
@@ -45,4 +50,14 @@ def update(self, source_id, text):
         self.timestamp_db.insert_docs(new_docs, source_id)
         self.db_interface.add_docs(new_docs, source_id)

+        for doc in new_docs:
+            endpoint = f"http://{FRONTEND_URI}/app/update_finalized_transcript"
+            time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            try:
+                client_response = requests.post(endpoint, json={'transcript': f"[{time}] {doc}"})
+                logger.debug(f'Posted update_finalized_transcript: {client_response._content}')
+                logger.debug("--------------------------")
+            except requests.exceptions.ConnectionError:
+                logger.error(f"Failed to connect to the '{endpoint}' endpoint")
+
         return {"status": f"Added {len(new_docs)} entries"}
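Each finalized chunk is now also pushed to the frontend at FRONTEND_URI over HTTP, wrapped in a try/except so transcript ingestion keeps working even if the frontend is unreachable. The frontend itself is not part of this excerpt; the sketch below is a hypothetical FastAPI receiver that mirrors only the contract visible above (the route path and the 'transcript' JSON field), not the actual implementation.

```python
# Hypothetical receiver sketch -- the real frontend is not shown in this diff.
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class TranscriptUpdate(BaseModel):
    transcript: str  # e.g. "[2024-06-01 12:00:00] <finalized chunk text>"

@app.post("/app/update_finalized_transcript")
def update_finalized_transcript(update: TranscriptUpdate):
    # A real frontend would append this to the displayed transcript view.
    print(update.transcript)
    return {"status": "ok"}
```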

community/fm-asr-streaming-rag/chain-server/chains.py

Lines changed: 2 additions & 3 deletions
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Union
 from copy import copy
 from datetime import datetime, timedelta
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.docstore.document import Document

 from accumulator import TextAccumulator
-from retriever import NemoRetrieverInterface, NvidiaApiInterface
+from retriever import NVRetriever
 from common import get_logger, LLMConfig, TimeResponse, UserIntent
 from utils import get_llm, classify
 from prompts import RAG_PROMPT, INTENT_PROMPT, RECENCY_PROMPT, SUMMARIZATION_PROMPT
@@ -36,7 +35,7 @@ def __init__(
         self,
         config: LLMConfig,
         text_accumulator: TextAccumulator,
-        retv_interface: Union[NemoRetrieverInterface, NvidiaApiInterface]
+        retv_interface: NVRetriever
     ):
         self.config = config
         self.text_accumulator = text_accumulator
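With the new embedding and reranking NIMs, the two previous retriever back ends are consolidated into a single NVRetriever class, so the Union type hint is no longer needed. retriever.py is not included in this excerpt; purely for orientation, a hypothetical protocol capturing the calls visible in this commit might look like the sketch below (the search method and its signature are assumptions).

```python
# Hypothetical sketch only -- retriever.py is not part of this excerpt, so the
# real NVRetriever API may differ. This just illustrates the shape the chain
# and accumulator code appear to rely on in this commit.
from typing import List, Protocol

from langchain.docstore.document import Document

class NVRetrieverLike(Protocol):
    def add_docs(self, docs: List[str], source_id: str) -> None:
        """Embed and store finalized transcript chunks (see accumulator.py)."""
        ...

    def search(self, query: str, max_docs: int = 8) -> List[Document]:
        """Assumed retrieval entry point; name and signature are illustrative."""
        ...
```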

community/fm-asr-streaming-rag/chain-server/common.py

Lines changed: 10 additions & 12 deletions
@@ -24,8 +24,13 @@
 from typing import Literal
 from langchain_community.utils.math import cosine_similarity

-USE_NEMO_RETRIEVER = os.environ.get('USE_NEMO_RETRIEVER', 'False').lower() in ('true', '1')
-NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
+NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
+LLM_URI = os.environ.get('LLM_URI', None)
+RERANKING_MODEL = os.environ.get('RERANK_MODEL', None)
+RERANKING_URI = os.environ.get('RERANK_URI', None)
+EMBEDDING_MODEL = os.environ.get('EMBED_MODEL', None)
+EMBEDDING_URI = os.environ.get('EMBED_URI', None)
+MAX_DOCS = int(os.environ.get('MAX_DOCS_RETR', 8))

 def get_logger(name):
     LOG_LEVEL = logging.getLevelName(os.environ.get('CHAIN_LOG_LEVEL', 'WARN').upper())
@@ -51,7 +56,7 @@ class LLMConfig(BaseModel):
     )
     # Model choice
     name: str = Field("Name of LLM instance to use")
-    engine: str = Field("Name of engine ['nvai-api-endpoint', 'triton-trt-llm']")
+    engine: str = Field("Name of engine ['nvai-api-endpoint', 'local-nim']")
     # Chain parameters
     use_knowledge_base: bool = Field(
         description="Whether to use a knowledge base", default=True
@@ -95,11 +100,7 @@ def nvapi_embedding(text):
     return embeddings

 VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
-TIME_VECTORS = None # Lazy loading in 'sanitize_time_unit'
-if USE_NEMO_RETRIEVER:
-    embedding_service = nemo_embedding
-else:
-    embedding_service = nvapi_embedding
+TIME_VECTORS = nvapi_embedding(VALID_TIME_UNITS)
@@ -111,10 +112,7 @@ def sanitize_time_unit(time_unit):
     if time_unit in VALID_TIME_UNITS:
        return time_unit

-    if TIME_VECTORS is None:
-        TIME_VECTORS = embedding_service(VALID_TIME_UNITS)
-
-    unit_embedding = embedding_service([time_unit])
+    unit_embedding = nvapi_embedding([time_unit])
     similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
     return VALID_TIME_UNITS[np.argmax(similarity)]

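The USE_NEMO_RETRIEVER switch is gone: embedding now always goes through nvapi_embedding, the time-unit vectors are embedded eagerly at import, and the per-service model/URI environment variables select between build.nvidia.com and locally deployed NIMs. The code that consumes EMBED_MODEL / EMBED_URI is not shown in this diff; the snippet below is only a plausible sketch using the langchain_nvidia_ai_endpoints client pinned in requirements.txt, with the default model name and URL scheme as assumptions.

```python
# Hypothetical sketch of how EMBED_MODEL / EMBED_URI might be consumed; the
# actual wiring lives in files not shown in this excerpt.
import os
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

EMBEDDING_MODEL = os.environ.get('EMBED_MODEL', 'nvidia/nv-embedqa-e5-v5')  # assumed default
EMBEDDING_URI = os.environ.get('EMBED_URI', None)

if EMBEDDING_URI:
    # Locally deployed embedding NIM, e.g. EMBED_URI="localhost:8001"
    embedder = NVIDIAEmbeddings(model=EMBEDDING_MODEL,
                                base_url=f"http://{EMBEDDING_URI}/v1")
else:
    # Default: hosted endpoint on build.nvidia.com (requires NVIDIA_API_KEY)
    embedder = NVIDIAEmbeddings(model=EMBEDDING_MODEL)

# e.g. eager embedding of the valid time units, as common.py now does at import
vectors = embedder.embed_documents(["seconds", "minutes", "hours", "days"])
```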
community/fm-asr-streaming-rag/chain-server/database.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@
 import numpy as np

 from common import get_logger
-from datetime import datetime
+from datetime import datetime, timedelta
 from langchain.docstore.document import Document

 logger = get_logger(__name__)
@@ -87,8 +87,8 @@ def recent(self, tstamp):
     def past(self, tstamp, window=90):
         """ Return entries within 'window' seconds of tstamp
         """
-        tstart = tstamp - datetime.timedelta(seconds=window)
-        tend = tstamp + datetime.timedelta(seconds=window)
+        tstart = tstamp - timedelta(seconds=window)
+        tend = tstamp + timedelta(seconds=window)
         self.cursor.execute('SELECT * FROM messages WHERE timestamp BETWEEN ? AND ?', (tstart, tend))
         docs = self.cursor.fetchall()
         return [self.reformat(doc) for doc in docs]
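This is a straightforward bug fix: `from datetime import datetime` binds the datetime class, not the module, so `datetime.timedelta` raises an AttributeError at runtime. Importing timedelta directly resolves it, as the short demonstration below shows.

```python
from datetime import datetime

try:
    datetime.timedelta(seconds=90)   # the old code path
except AttributeError as err:
    # type object 'datetime.datetime' has no attribute 'timedelta'
    print(err)

from datetime import timedelta
print(timedelta(seconds=90))         # the fixed code path: 0:01:30
```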

community/fm-asr-streaming-rag/chain-server/prompts.py

Lines changed: 8 additions & 2 deletions
@@ -49,7 +49,7 @@ def format_json(text: str):
 recent timeframe. Examples: "Can you summarize the last hour of content?", "What have the main \
 topics been over the last 5 minutes?", "Tell me the main stories of the past 2 hours.".
 - 'TimeWindow': If the user is asking about the focus of the conversation from a specified time in \
-the past. Examples: "What were they talking about 15 minutes ago?", "What was the focus an hour ago?".
+the past. Examples: "What were they talking about 15 minutes ago?", "What was the focus an hour ago?", "What are they talking about right now?".
 - If the user's intent is not clear, or if the intent cannot be confidently determined, classify \
 this as 'Unknown'.

@@ -101,11 +101,17 @@ def format_json(text: str):
 f"'{recency_examples[1]}' --> '{format_json(recency_examples_obj[1].model_dump_json())}'\n" +
 f"'{recency_examples[2]}' --> '{format_json(recency_examples_obj[2].model_dump_json())}'\n" + """

-Convert the user input below into this JSON format.
+Convert the user input below into this JSON format. Make sure you use valid JSON and \
+don't worry about escaping quotes, just give valid JSON blobs.
 """)

 SUMMARIZATION_PROMPT = """\
 You are a sophisticated summarization tool designed to condense large blocks \
 of text into a concise summary. Given the user text, reduce the character \
 count by distilling into only the most important information.
+
+Do not say you are summarizing, i.e. do not say "Here's a summary...", just \
+condense the text to the best of your abilities.
+
+Summary:
 """
Lines changed: 3 additions & 18 deletions
@@ -1,21 +1,6 @@
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
-python-multipart==0.0.6
-unstructured[all-docs]==0.11.2
-sentence-transformers==2.2.2
-llama-index==0.9.22
 pymilvus==2.3.5
-dataclass-wizard==0.22.2
-opencv-python==4.8.0.74
-minio==7.2.0
-asyncpg==0.29.0
-psycopg2-binary==2.9.9
-pgvector==0.2.4
-langchain==0.1.14
-langchain-core==0.1.40
-langchain-nvidia-ai-endpoints==0.0.12
-langchain-nvidia-trt==0.0.1rc0
-nemollm==0.3.4
-opentelemetry-sdk==1.21.0
-opentelemetry-api==1.21.0
-opentelemetry-exporter-otlp-proto-grpc==1.21.0
+langchain==0.2.6
+langchain_core==0.2.25
+langchain_nvidia_ai_endpoints==0.2.0
