Skip to content

Commit 85f2329

Browse files
feat: Add intelligent tool calling
1 parent 7c137bb commit 85f2329

File tree

4 files changed

+632
-101
lines changed

4 files changed

+632
-101
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ dependencies = [
1212
"python-dotenv>=1.1.0",
1313
"python-multipart>=0.0.20",
1414
"uvicorn[standard]>=0.34.2",
15+
"faiss-cpu>=1.7.4",
16+
"sentence-transformers>=2.2.2",
1517
]
1618

1719
[tool.setuptools]

registry/main.py

Lines changed: 231 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from typing import Annotated, List, Set
99
from datetime import datetime, timezone
1010

11+
import faiss
12+
import numpy as np
13+
from sentence_transformers import SentenceTransformer
14+
1115
from fastapi import (
1216
FastAPI,
1317
Request,
@@ -68,6 +72,18 @@
6872
# LOG_FILE_PATH = BASE_DIR / "registry.log"
6973
LOG_FILE_PATH = CONTAINER_LOG_DIR / "registry.log"
7074

75+
# --- FAISS Vector DB Configuration --- START
# On-disk locations for the persisted FAISS index and its companion metadata JSON.
FAISS_INDEX_PATH = SERVERS_DIR / "service_index.faiss"
FAISS_METADATA_PATH = SERVERS_DIR / "service_index_metadata.json"
EMBEDDING_DIMENSION = 384  # For all-MiniLM-L6-v2
embedding_model = None  # Will be loaded in lifespan
faiss_index = None  # Will be loaded/created in lifespan
# Stores: { service_path: {"id": faiss_internal_id, "text_for_embedding": "...", "full_server_info": { ... }} }
# faiss_internal_id is the ID used with faiss_index.add_with_ids()
faiss_metadata_store = {}
# Monotonic counter for assigning the next FAISS internal ID; persisted in the metadata JSON.
next_faiss_id_counter = 0
# --- FAISS Vector DB Configuration --- END
7187
# --- REMOVE Logging Setup from here --- START
7288
# # Ensure log directory exists
7389
# CONTAINER_LOG_DIR.mkdir(parents=True, exist_ok=True)
@@ -107,6 +123,159 @@
107123
# --- WebSocket Connection Management ---
108124
active_connections: Set[WebSocket] = set()
109125

126+
# --- FAISS Helper Functions --- START
127+
128+
def _get_text_for_embedding(server_info: dict) -> str:
129+
"""Prepares a consistent text string from server info for embedding."""
130+
name = server_info.get("server_name", "")
131+
description = server_info.get("description", "")
132+
tags = server_info.get("tags", [])
133+
tag_string = ", ".join(tags)
134+
return f"Name: {name}\\nDescription: {description}\\nTags: {tag_string}"
135+
136+
def load_faiss_data():
    """Load the embedding model and the persisted FAISS index/metadata.

    Populates the module-level globals `embedding_model`, `faiss_index`,
    `faiss_metadata_store` and `next_faiss_id_counter`. If no persisted index
    exists, or loading fails, or the stored index dimension differs from
    EMBEDDING_DIMENSION, a fresh IndexIDMap(IndexFlatL2) is initialized.
    Synchronous; intended to run once during application startup (lifespan).
    """
    global faiss_index, faiss_metadata_store, embedding_model, next_faiss_id_counter, CONTAINER_REGISTRY_DIR, SERVERS_DIR
    logger.info("Loading FAISS data and embedding model...")

    SERVERS_DIR.mkdir(parents=True, exist_ok=True)

    try:
        model_cache_path = CONTAINER_REGISTRY_DIR / ".cache"
        model_cache_path.mkdir(parents=True, exist_ok=True)
        # Redirect the sentence-transformers download cache to a writable
        # container path. Must be set before the model is instantiated.
        original_st_home = os.environ.get('SENTENCE_TRANSFORMERS_HOME')
        os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_cache_path)
        logger.info(f"Attempting to load SentenceTransformer model 'all-MiniLM-L6-v2'. Cache: {model_cache_path}")
        try:
            embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        finally:
            # Fix: restore the environment even when model loading raises, so a
            # failed load cannot leak the temporary cache override to the rest
            # of the process. pop() is also safe if the key vanished meanwhile.
            if original_st_home:
                os.environ['SENTENCE_TRANSFORMERS_HOME'] = original_st_home
            else:
                os.environ.pop('SENTENCE_TRANSFORMERS_HOME', None)

        logger.info("SentenceTransformer model 'all-MiniLM-L6-v2' loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer model: {e}", exc_info=True)
        embedding_model = None

    if FAISS_INDEX_PATH.exists() and FAISS_METADATA_PATH.exists():
        try:
            logger.info(f"Loading FAISS index from {FAISS_INDEX_PATH}")
            faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
            logger.info(f"Loading FAISS metadata from {FAISS_METADATA_PATH}")
            with open(FAISS_METADATA_PATH, "r") as f:
                loaded_metadata = json.load(f)
            faiss_metadata_store = loaded_metadata.get("metadata", {})
            next_faiss_id_counter = loaded_metadata.get("next_id", 0)
            logger.info(f"FAISS data loaded. Index size: {faiss_index.ntotal if faiss_index else 0}. Next ID: {next_faiss_id_counter}")
            if faiss_index and faiss_index.d != EMBEDDING_DIMENSION:
                # A dimension mismatch means the index was built with a
                # different model; the stored vectors are unusable.
                logger.warning(f"Loaded FAISS index dimension ({faiss_index.d}) differs from expected ({EMBEDDING_DIMENSION}). Re-initializing.")
                faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION))
                faiss_metadata_store = {}
                next_faiss_id_counter = 0
        except Exception as e:
            logger.error(f"Error loading FAISS data: {e}. Re-initializing.", exc_info=True)
            faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION))
            faiss_metadata_store = {}
            next_faiss_id_counter = 0
    else:
        logger.info("FAISS index or metadata not found. Initializing new.")
        faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION))
        faiss_metadata_store = {}
        next_faiss_id_counter = 0
def save_faiss_data():
    """Persist the in-memory FAISS index and metadata store to disk.

    Writes the index to FAISS_INDEX_PATH and a JSON document containing the
    metadata store plus the next-ID counter to FAISS_METADATA_PATH. A missing
    index is logged and skipped; all write errors are logged, never raised.
    """
    global faiss_index, faiss_metadata_store, next_faiss_id_counter

    # Guard clause: nothing to persist until load_faiss_data() has run.
    if faiss_index is None:
        logger.error("FAISS index is not initialized. Cannot save.")
        return

    try:
        SERVERS_DIR.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH} (Size: {faiss_index.ntotal})")
        faiss.write_index(faiss_index, str(FAISS_INDEX_PATH))

        logger.info(f"Saving FAISS metadata to {FAISS_METADATA_PATH}")
        payload = {"metadata": faiss_metadata_store, "next_id": next_faiss_id_counter}
        with open(FAISS_METADATA_PATH, "w") as fh:
            json.dump(payload, fh, indent=2)

        logger.info("FAISS data saved successfully.")
    except Exception as e:
        logger.error(f"Error saving FAISS data: {e}", exc_info=True)
async def add_or_update_service_in_faiss(service_path: str, server_info: dict):
    """Insert or refresh a service's embedding and metadata in the FAISS store.

    Re-embeds only when the derived name/description/tags text changed;
    otherwise just reconciles the stored metadata. Persists via
    save_faiss_data() (in a worker thread) whenever anything changed.

    Args:
        service_path: Registry path used as the key in faiss_metadata_store.
        server_info: Static server definition dict for this service.
    """
    global faiss_index, faiss_metadata_store, embedding_model, next_faiss_id_counter

    if embedding_model is None or faiss_index is None:
        logger.error("Embedding model or FAISS index not initialized. Cannot add/update service in FAISS.")
        return

    logger.info(f"Attempting to add/update service '{service_path}' in FAISS.")
    text_to_embed = _get_text_for_embedding(server_info)

    current_faiss_id = -1
    needs_new_embedding = True  # Assume new embedding is needed

    existing_entry = faiss_metadata_store.get(service_path)

    if existing_entry:
        # Known service: reuse its FAISS ID; only re-embed if the text changed.
        current_faiss_id = existing_entry["id"]
        if existing_entry.get("text_for_embedding") == text_to_embed:
            needs_new_embedding = False
            logger.info(f"Text for embedding for '{service_path}' has not changed. Will update metadata store only if server_info differs.")
        else:
            logger.info(f"Text for embedding for '{service_path}' has changed. Re-embedding required.")
    else:  # New service
        # NOTE(review): the counter is incremented before embedding; if the
        # encode/add below fails, this ID is consumed but never used.
        current_faiss_id = next_faiss_id_counter
        next_faiss_id_counter += 1
        logger.info(f"New service '{service_path}'. Assigning new FAISS ID: {current_faiss_id}.")
        needs_new_embedding = True  # Definitely needs embedding

    if needs_new_embedding:
        try:
            # Run model encoding in a separate thread to avoid blocking asyncio event loop
            embedding = await asyncio.to_thread(embedding_model.encode, [text_to_embed])
            embedding_np = np.array([embedding[0]], dtype=np.float32)

            ids_to_remove = np.array([current_faiss_id])
            if existing_entry:  # Only attempt removal if it was an existing entry
                try:
                    # remove_ids returns number of vectors removed.
                    # It's okay if the ID isn't found (returns 0).
                    num_removed = faiss_index.remove_ids(ids_to_remove)
                    if num_removed > 0:
                        logger.info(f"Removed {num_removed} old vector(s) for FAISS ID {current_faiss_id} ({service_path}).")
                    else:
                        logger.info(f"No old vector found for FAISS ID {current_faiss_id} ({service_path}) during update, or ID not in index.")
                except Exception as e_remove:  # Should be rare with IndexIDMap if ID was valid type
                    logger.warning(f"Issue removing FAISS ID {current_faiss_id} for {service_path}: {e_remove}. Proceeding to add.")

            faiss_index.add_with_ids(embedding_np, np.array([current_faiss_id]))
            logger.info(f"Added/Updated vector for '{service_path}' with FAISS ID {current_faiss_id}.")
        except Exception as e:
            logger.error(f"Error encoding or adding embedding for '{service_path}': {e}", exc_info=True)
            return  # Don't update metadata or save if embedding failed

    # Update metadata store if new, or if text changed, or if full_server_info changed
    # --- Enrich server_info with is_enabled status before storing --- START
    # Presumably MOCK_SERVICE_STATE maps service_path -> enabled flag; its
    # definition is outside this view — confirm against the module globals.
    enriched_server_info = server_info.copy()
    enriched_server_info["is_enabled"] = MOCK_SERVICE_STATE.get(service_path, False)  # Default to False if not found
    # --- Enrich server_info with is_enabled status before storing --- END

    if existing_entry is None or needs_new_embedding or existing_entry.get("full_server_info") != enriched_server_info:
        faiss_metadata_store[service_path] = {
            "id": current_faiss_id,
            "text_for_embedding": text_to_embed,
            "full_server_info": enriched_server_info  # Store the enriched server_info
        }
        logger.debug(f"Updated faiss_metadata_store for '{service_path}'.")
        await asyncio.to_thread(save_faiss_data)  # Persist changes in a thread
    else:
        logger.debug(f"No changes to FAISS vector or enriched full_server_info for '{service_path}'. Skipping save.")
# --- FAISS Helper Functions --- END
278+
110279
async def broadcast_health_status():
111280
"""Sends the current health status to all connected WebSocket clients."""
112281
if active_connections:
@@ -662,14 +831,23 @@ async def perform_single_health_check(path: str) -> tuple[str, datetime | None]:
662831
# Save the updated server info to its file
663832
if not save_server_to_file(REGISTERED_SERVERS[path]):
664833
logger.error(f"ERROR: Failed to save updated tool list/count for {path} to file.")
834+
# --- Update FAISS after tool list/count change --- START
835+
# No explicit call here, will be handled by the one at the end of perform_single_health_check
836+
# logger.info(f"Updating FAISS metadata for '{path}' after tool list/count update.")
837+
# await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) # Moved to end
838+
# --- Update FAISS after tool list/count change --- END
665839
else:
666840
logger.info(f"Tool list for {path} remains unchanged. No update needed.")
667841
else:
668842
logger.info(f"Failed to retrieve tool list for healthy service {path}. List/Count remains unchanged.")
843+
# Even if tool list fetch failed, server is healthy.
844+
# FAISS update will occur at the end of this function with current REGISTERED_SERVERS[path].
669845
else:
670846
# This case should technically not be reachable due to earlier url check
671847
logger.info(f"Cannot fetch tool list for {path}: proxy_pass_url is missing.")
672848
# --- Check for transition to healthy state --- END
849+
# If it was already healthy, and tools changed, the above block (current_tool_list_str != new_tool_list_str) handles it.
850+
# The FAISS update with the latest REGISTERED_SERVERS[path] will happen at the end of this function.
673851

674852
elif proc.returncode == 28:
675853
current_status = f"error: timeout ({HEALTH_CHECK_TIMEOUT_SECONDS}s)"
@@ -703,6 +881,11 @@ async def perform_single_health_check(path: str) -> tuple[str, datetime | None]:
703881
SERVER_HEALTH_STATUS[path] = current_status
704882
logger.info(f"Final health status for {path}: {current_status}")
705883

884+
# --- Update FAISS with final server_info state after health check attempt ---
885+
if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None:
886+
logger.info(f"Updating FAISS metadata for '{path}' post health check (status: {current_status}).")
887+
await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path])
888+
706889
# --- Regenerate Nginx if status affecting it changed --- START
707890
# Check if the service is enabled AND its Nginx-relevant status changed
708891
if is_enabled:
@@ -801,8 +984,25 @@ async def lifespan(app: FastAPI):
801984
logger.info("Logging configured. Running startup tasks...") # Now logger is configured
802985
# --- Configure Logging INSIDE lifespan --- END
803986

987+
# 0. Load FAISS data and embedding model
988+
load_faiss_data() # Loads model, empty index or existing index. Synchronous.
989+
804990
# 1. Load server definitions and persisted enabled/disabled state
805-
load_registered_servers_and_state()
991+
load_registered_servers_and_state() # This populates REGISTERED_SERVERS. Synchronous.
992+
993+
# 1.5 Sync FAISS with loaded servers (initial build or update)
994+
if embedding_model and faiss_index is not None: # Check faiss_index is not None
995+
logger.info("Performing initial FAISS synchronization with loaded server definitions...")
996+
sync_tasks = []
997+
for path, server_info in REGISTERED_SERVERS.items():
998+
# add_or_update_service_in_faiss is async, can be gathered
999+
sync_tasks.append(add_or_update_service_in_faiss(path, server_info))
1000+
1001+
if sync_tasks:
1002+
await asyncio.gather(*sync_tasks)
1003+
logger.info("Initial FAISS synchronization complete.")
1004+
else:
1005+
logger.warning("Skipping initial FAISS synchronization: embedding model or FAISS index not ready.")
8061006

8071007
# 2. Perform initial health checks concurrently for *enabled* services
8081008
logger.info("Performing initial health checks for enabled services...")
@@ -838,6 +1038,10 @@ async def lifespan(app: FastAPI):
8381038
else:
8391039
status, _ = result # Unpack the result tuple
8401040
logger.info(f"Initial health check completed for {path}: Status = {status}")
1041+
# Update FAISS with potentially changed server_info (e.g., num_tools from health check)
1042+
if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None:
1043+
# This runs after each health check result, can be awaited individually
1044+
await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path])
8411045
else:
8421046
logger.info("No services are initially enabled.")
8431047

@@ -1063,6 +1267,14 @@ async def toggle_service_route(
10631267
# Update global state directly when disabling
10641268
SERVER_HEALTH_STATUS[service_path] = new_status
10651269
logger.info(f"Service {service_path} toggled OFF. Status set to disabled.")
1270+
# --- Update FAISS metadata for disabled service --- START
1271+
if embedding_model and faiss_index is not None:
1272+
logger.info(f"Updating FAISS metadata for disabled service {service_path}.")
1273+
# REGISTERED_SERVERS[service_path] contains the static definition
1274+
await add_or_update_service_in_faiss(service_path, REGISTERED_SERVERS[service_path])
1275+
else:
1276+
logger.warning(f"Skipped FAISS metadata update for disabled service {service_path}: model or index not ready.")
1277+
# --- Update FAISS metadata for disabled service --- END
10661278

10671279
# --- Send *targeted* update via WebSocket --- START
10681280
# Send immediate feedback for the toggled service only
@@ -1229,6 +1441,15 @@ async def register_service(
12291441
else:
12301442
logger.info("[DEBUG] Successfully regenerated Nginx configuration")
12311443

1444+
# --- Add to FAISS Index --- START
1445+
logger.info(f"[DEBUG] Adding/updating service '{path}' in FAISS index after registration...")
1446+
if embedding_model and faiss_index is not None:
1447+
await add_or_update_service_in_faiss(path, server_entry) # server_entry is the new service info
1448+
logger.info(f"[DEBUG] Service '{path}' processed for FAISS index.")
1449+
else:
1450+
logger.warning(f"[DEBUG] Skipped FAISS update for '{path}': model or index not ready.")
1451+
# --- Add to FAISS Index --- END
1452+
12321453
logger.info(f"[INFO] New service registered: '{name}' at path '{path}' by user '{username}'")
12331454

12341455
# --- Persist the updated state after registration --- START
@@ -1464,6 +1685,15 @@ async def edit_server_submit(
14641685
if not regenerate_nginx_config():
14651686
logger.error("ERROR: Failed to update Nginx configuration after edit.")
14661687
# Consider how to notify user - maybe flash message system needed
1688+
1689+
# --- Update FAISS Index --- START
1690+
logger.info(f"Updating service '{service_path}' in FAISS index after edit.")
1691+
if embedding_model and faiss_index is not None:
1692+
await add_or_update_service_in_faiss(service_path, updated_server_entry)
1693+
logger.info(f"Service '{service_path}' updated in FAISS index.")
1694+
else:
1695+
logger.warning(f"Skipped FAISS update for '{service_path}' post-edit: model or index not ready.")
1696+
# --- Update FAISS Index --- END
14671697

14681698
logger.info(f"Server '{name}' ({service_path}) updated by user '{username}'")
14691699

servers/mcpgw/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ dependencies = [
1010
"httpx>=0.27.0", # Added httpx
1111
"python-dotenv>=1.0.0", # Added dotenv as it's used in server.py
1212
"websockets>=15.0.1",
13+
"faiss-cpu>=1.7.4",
14+
"sentence-transformers>=2.2.2", # For semantic search
15+
"scikit-learn>=1.3.0" # For cosine similarity
1316
]

0 commit comments

Comments
 (0)