|
8 | 8 | from typing import Annotated, List, Set |
9 | 9 | from datetime import datetime, timezone |
10 | 10 |
|
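| | +# These new imports assume the faiss-cpu (or faiss-gpu), numpy, and sentence-transformers packages are installed. |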
| 11 | +import faiss |
| 12 | +import numpy as np |
| 13 | +from sentence_transformers import SentenceTransformer |
| 14 | + |
11 | 15 | from fastapi import ( |
12 | 16 | FastAPI, |
13 | 17 | Request, |
|
68 | 72 | # LOG_FILE_PATH = BASE_DIR / "registry.log" |
69 | 73 | LOG_FILE_PATH = CONTAINER_LOG_DIR / "registry.log" |
70 | 74 |
|
| 75 | +# --- FAISS Vector DB Configuration --- START |
| 76 | +FAISS_INDEX_PATH = SERVERS_DIR / "service_index.faiss" |
| 77 | +FAISS_METADATA_PATH = SERVERS_DIR / "service_index_metadata.json" |
| 78 | +EMBEDDING_DIMENSION = 384 # For all-MiniLM-L6-v2 |
| 79 | +embedding_model = None # Will be loaded in lifespan |
| 80 | +faiss_index = None # Will be loaded/created in lifespan |
| 81 | +# Stores: { service_path: {"id": faiss_internal_id, "text_for_embedding": "...", "full_server_info": { ... }} } |
| 82 | +# faiss_internal_id is the ID used with faiss_index.add_with_ids() |
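| | +# A hypothetical entry, for illustration: |
| | +#   { "/weather": {"id": 0, |
| | +#                  "text_for_embedding": "Name: Weather\nDescription: Forecast API\nTags: weather, api", |
| | +#                  "full_server_info": {"server_name": "Weather", "description": "Forecast API", "tags": ["weather", "api"], "is_enabled": True}} } |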
| 83 | +faiss_metadata_store = {} |
| 84 | +next_faiss_id_counter = 0 |
| 85 | +# --- FAISS Vector DB Configuration --- END |
| 86 | + |
71 | 87 | # --- REMOVE Logging Setup from here --- START |
72 | 88 | # # Ensure log directory exists |
73 | 89 | # CONTAINER_LOG_DIR.mkdir(parents=True, exist_ok=True) |
|
107 | 123 | # --- WebSocket Connection Management --- |
108 | 124 | active_connections: Set[WebSocket] = set() |
109 | 125 |
|
| 126 | +# --- FAISS Helper Functions --- START |
| 127 | + |
| 128 | +def _get_text_for_embedding(server_info: dict) -> str: |
| 129 | + """Prepares a consistent text string from server info for embedding.""" |
| 130 | + name = server_info.get("server_name", "") |
| 131 | + description = server_info.get("description", "") |
| 132 | + tags = server_info.get("tags", []) |
| 133 | + tag_string = ", ".join(tags) |
| 134 | +    return f"Name: {name}\nDescription: {description}\nTags: {tag_string}" |
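| | +# For a hypothetical server_info of {"server_name": "Weather", "description": "Forecast API", "tags": ["weather", "api"]}, |
| | +# this returns "Name: Weather\nDescription: Forecast API\nTags: weather, api". |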
| 135 | + |
| 136 | +def load_faiss_data(): |
| 137 | + global faiss_index, faiss_metadata_store, embedding_model, next_faiss_id_counter, CONTAINER_REGISTRY_DIR, SERVERS_DIR |
| 138 | + logger.info("Loading FAISS data and embedding model...") |
| 139 | + |
| 140 | + SERVERS_DIR.mkdir(parents=True, exist_ok=True) |
| 141 | + |
| 142 | + try: |
| 143 | + model_cache_path = CONTAINER_REGISTRY_DIR / ".cache" |
| 144 | + model_cache_path.mkdir(parents=True, exist_ok=True) |
| 145 | +        # Point SENTENCE_TRANSFORMERS_HOME at our cache path so the model is downloaded there. |
| 146 | +        # The variable only needs to be set before the model is instantiated, so setting it |
| 147 | +        # here, immediately before loading, is sufficient. |
| 148 | + original_st_home = os.environ.get('SENTENCE_TRANSFORMERS_HOME') |
| 149 | + os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_cache_path) |
| 150 | + logger.info(f"Attempting to load SentenceTransformer model 'all-MiniLM-L6-v2'. Cache: {model_cache_path}") |
| 151 | + |
| 152 | + embedding_model = SentenceTransformer('all-MiniLM-L6-v2') |
| 153 | + |
| 154 | + # Restore original environment variable if it was set |
| 155 | + if original_st_home: |
| 156 | + os.environ['SENTENCE_TRANSFORMERS_HOME'] = original_st_home |
| 157 | + else: |
| 158 | + del os.environ['SENTENCE_TRANSFORMERS_HOME'] # Remove if not originally set |
| 159 | + |
| 160 | + logger.info("SentenceTransformer model 'all-MiniLM-L6-v2' loaded successfully.") |
| 161 | + except Exception as e: |
| 162 | + logger.error(f"Failed to load SentenceTransformer model: {e}", exc_info=True) |
| 163 | + embedding_model = None |
| 164 | + |
| 165 | + if FAISS_INDEX_PATH.exists() and FAISS_METADATA_PATH.exists(): |
| 166 | + try: |
| 167 | + logger.info(f"Loading FAISS index from {FAISS_INDEX_PATH}") |
| 168 | + faiss_index = faiss.read_index(str(FAISS_INDEX_PATH)) |
| 169 | + logger.info(f"Loading FAISS metadata from {FAISS_METADATA_PATH}") |
| 170 | + with open(FAISS_METADATA_PATH, "r") as f: |
| 171 | + loaded_metadata = json.load(f) |
| 172 | + faiss_metadata_store = loaded_metadata.get("metadata", {}) |
| 173 | + next_faiss_id_counter = loaded_metadata.get("next_id", 0) |
| 174 | + logger.info(f"FAISS data loaded. Index size: {faiss_index.ntotal if faiss_index else 0}. Next ID: {next_faiss_id_counter}") |
| 175 | + if faiss_index and faiss_index.d != EMBEDDING_DIMENSION: |
| 176 | + logger.warning(f"Loaded FAISS index dimension ({faiss_index.d}) differs from expected ({EMBEDDING_DIMENSION}). Re-initializing.") |
| 177 | + faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION)) |
| 178 | + faiss_metadata_store = {} |
| 179 | + next_faiss_id_counter = 0 |
| 180 | + except Exception as e: |
| 181 | + logger.error(f"Error loading FAISS data: {e}. Re-initializing.", exc_info=True) |
| 182 | + faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION)) |
| 183 | + faiss_metadata_store = {} |
| 184 | + next_faiss_id_counter = 0 |
| 185 | + else: |
| 186 | + logger.info("FAISS index or metadata not found. Initializing new.") |
| 187 | + faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIMENSION)) |
| 188 | + faiss_metadata_store = {} |
| 189 | + next_faiss_id_counter = 0 |
| 190 | + |
| 191 | +def save_faiss_data(): |
| 192 | + global faiss_index, faiss_metadata_store, next_faiss_id_counter |
| 193 | + if faiss_index is None: |
| 194 | + logger.error("FAISS index is not initialized. Cannot save.") |
| 195 | + return |
| 196 | + try: |
| 197 | + SERVERS_DIR.mkdir(parents=True, exist_ok=True) # Ensure directory exists |
| 198 | + logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH} (Size: {faiss_index.ntotal})") |
| 199 | + faiss.write_index(faiss_index, str(FAISS_INDEX_PATH)) |
| 200 | + logger.info(f"Saving FAISS metadata to {FAISS_METADATA_PATH}") |
| 201 | + with open(FAISS_METADATA_PATH, "w") as f: |
| 202 | + json.dump({"metadata": faiss_metadata_store, "next_id": next_faiss_id_counter}, f, indent=2) |
| 203 | + logger.info("FAISS data saved successfully.") |
| 204 | + except Exception as e: |
| 205 | + logger.error(f"Error saving FAISS data: {e}", exc_info=True) |
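| | +# For reference, FAISS_METADATA_PATH ends up holding a single JSON object of the shape |
| | +# {"metadata": {<service_path>: {...}}, "next_id": <int>}, mirroring the dump above. |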
| 206 | + |
| 207 | +async def add_or_update_service_in_faiss(service_path: str, server_info: dict): |
| 208 | + global faiss_index, faiss_metadata_store, embedding_model, next_faiss_id_counter |
| 209 | + |
| 210 | + if embedding_model is None or faiss_index is None: |
| 211 | + logger.error("Embedding model or FAISS index not initialized. Cannot add/update service in FAISS.") |
| 212 | + return |
| 213 | + |
| 214 | + logger.info(f"Attempting to add/update service '{service_path}' in FAISS.") |
| 215 | + text_to_embed = _get_text_for_embedding(server_info) |
| 216 | + |
| 217 | + current_faiss_id = -1 |
| 218 | + needs_new_embedding = True # Assume new embedding is needed |
| 219 | + |
| 220 | + existing_entry = faiss_metadata_store.get(service_path) |
| 221 | + |
| 222 | + if existing_entry: |
| 223 | + current_faiss_id = existing_entry["id"] |
| 224 | + if existing_entry.get("text_for_embedding") == text_to_embed: |
| 225 | + needs_new_embedding = False |
| 226 | + logger.info(f"Text for embedding for '{service_path}' has not changed. Will update metadata store only if server_info differs.") |
| 227 | + else: |
| 228 | + logger.info(f"Text for embedding for '{service_path}' has changed. Re-embedding required.") |
| 229 | + else: # New service |
| 230 | + current_faiss_id = next_faiss_id_counter |
| 231 | + next_faiss_id_counter += 1 |
| 232 | + logger.info(f"New service '{service_path}'. Assigning new FAISS ID: {current_faiss_id}.") |
| 233 | + needs_new_embedding = True # Definitely needs embedding |
| 234 | + |
| 235 | + if needs_new_embedding: |
| 236 | + try: |
| 237 | + # Run model encoding in a separate thread to avoid blocking asyncio event loop |
| 238 | + embedding = await asyncio.to_thread(embedding_model.encode, [text_to_embed]) |
| 239 | + embedding_np = np.array([embedding[0]], dtype=np.float32) |
| 240 | + |
| 241 | +            ids_to_remove = np.array([current_faiss_id], dtype=np.int64)  # FAISS expects int64 IDs |
| 242 | + if existing_entry: # Only attempt removal if it was an existing entry |
| 243 | + try: |
| 244 | + # remove_ids returns number of vectors removed. |
| 245 | + # It's okay if the ID isn't found (returns 0). |
| 246 | + num_removed = faiss_index.remove_ids(ids_to_remove) |
| 247 | + if num_removed > 0: |
| 248 | + logger.info(f"Removed {num_removed} old vector(s) for FAISS ID {current_faiss_id} ({service_path}).") |
| 249 | + else: |
| 250 | + logger.info(f"No old vector found for FAISS ID {current_faiss_id} ({service_path}) during update, or ID not in index.") |
| 251 | + except Exception as e_remove: # Should be rare with IndexIDMap if ID was valid type |
| 252 | + logger.warning(f"Issue removing FAISS ID {current_faiss_id} for {service_path}: {e_remove}. Proceeding to add.") |
| 253 | + |
| 254 | +            faiss_index.add_with_ids(embedding_np, np.array([current_faiss_id], dtype=np.int64)) |
| 255 | + logger.info(f"Added/Updated vector for '{service_path}' with FAISS ID {current_faiss_id}.") |
| 256 | + except Exception as e: |
| 257 | + logger.error(f"Error encoding or adding embedding for '{service_path}': {e}", exc_info=True) |
| 258 | + return # Don't update metadata or save if embedding failed |
| 259 | + |
| 260 | + # Update metadata store if new, or if text changed, or if full_server_info changed |
| 261 | + # --- Enrich server_info with is_enabled status before storing --- START |
| 262 | + enriched_server_info = server_info.copy() |
| 263 | + enriched_server_info["is_enabled"] = MOCK_SERVICE_STATE.get(service_path, False) # Default to False if not found |
| 264 | + # --- Enrich server_info with is_enabled status before storing --- END |
| 265 | + |
| 266 | + if existing_entry is None or needs_new_embedding or existing_entry.get("full_server_info") != enriched_server_info: |
| 267 | + faiss_metadata_store[service_path] = { |
| 268 | + "id": current_faiss_id, |
| 269 | + "text_for_embedding": text_to_embed, |
| 270 | + "full_server_info": enriched_server_info # Store the enriched server_info |
| 271 | + } |
| 272 | + logger.debug(f"Updated faiss_metadata_store for '{service_path}'.") |
| 273 | + await asyncio.to_thread(save_faiss_data) # Persist changes in a thread |
| 274 | + else: |
| 275 | + logger.debug(f"No changes to FAISS vector or enriched full_server_info for '{service_path}'. Skipping save.") |
| 276 | + |
| 277 | +# --- FAISS Helper Functions --- END |
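| | +# A minimal sketch of how the index built above could be queried. The name |
| | +# search_services_in_faiss and the top_k parameter are illustrative (no such helper is |
| | +# defined in this change); only the globals declared above are assumed. |
| | +async def search_services_in_faiss(query: str, top_k: int = 5) -> list: |
| | +    """Returns the full_server_info of the services semantically closest to the query.""" |
| | +    if embedding_model is None or faiss_index is None or faiss_index.ntotal == 0: |
| | +        return [] |
| | +    # Encode the query off the event loop, mirroring add_or_update_service_in_faiss. |
| | +    query_embedding = await asyncio.to_thread(embedding_model.encode, [query]) |
| | +    query_np = np.array(query_embedding, dtype=np.float32) |
| | +    _distances, ids = faiss_index.search(query_np, min(top_k, faiss_index.ntotal)) |
| | +    # Map FAISS IDs back to service paths via the metadata store. |
| | +    id_to_path = {entry["id"]: path for path, entry in faiss_metadata_store.items()} |
| | +    results = [] |
| | +    for faiss_id in ids[0]: |
| | +        path = id_to_path.get(int(faiss_id)) |
| | +        if path is not None: |
| | +            results.append(faiss_metadata_store[path]["full_server_info"]) |
| | +    return results |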
| 278 | + |
110 | 279 | async def broadcast_health_status(): |
111 | 280 | """Sends the current health status to all connected WebSocket clients.""" |
112 | 281 | if active_connections: |
@@ -662,14 +831,23 @@ async def perform_single_health_check(path: str) -> tuple[str, datetime | None]: |
662 | 831 | # Save the updated server info to its file |
663 | 832 | if not save_server_to_file(REGISTERED_SERVERS[path]): |
664 | 833 | logger.error(f"ERROR: Failed to save updated tool list/count for {path} to file.") |
| 834 | + # --- Update FAISS after tool list/count change --- START |
| 835 | + # No explicit call here, will be handled by the one at the end of perform_single_health_check |
| 836 | + # logger.info(f"Updating FAISS metadata for '{path}' after tool list/count update.") |
| 837 | + # await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) # Moved to end |
| 838 | + # --- Update FAISS after tool list/count change --- END |
665 | 839 | else: |
666 | 840 | logger.info(f"Tool list for {path} remains unchanged. No update needed.") |
667 | 841 | else: |
668 | 842 | logger.info(f"Failed to retrieve tool list for healthy service {path}. List/Count remains unchanged.") |
| 843 | + # Even if tool list fetch failed, server is healthy. |
| 844 | + # FAISS update will occur at the end of this function with current REGISTERED_SERVERS[path]. |
669 | 845 | else: |
670 | 846 | # This case should technically not be reachable due to earlier url check |
671 | 847 | logger.info(f"Cannot fetch tool list for {path}: proxy_pass_url is missing.") |
672 | 848 | # --- Check for transition to healthy state --- END |
| 849 | + # If it was already healthy, and tools changed, the above block (current_tool_list_str != new_tool_list_str) handles it. |
| 850 | + # The FAISS update with the latest REGISTERED_SERVERS[path] will happen at the end of this function. |
673 | 851 |
|
674 | 852 | elif proc.returncode == 28: |
675 | 853 | current_status = f"error: timeout ({HEALTH_CHECK_TIMEOUT_SECONDS}s)" |
@@ -703,6 +881,11 @@ async def perform_single_health_check(path: str) -> tuple[str, datetime | None]: |
703 | 881 | SERVER_HEALTH_STATUS[path] = current_status |
704 | 882 | logger.info(f"Final health status for {path}: {current_status}") |
705 | 883 |
|
| 884 | + # --- Update FAISS with final server_info state after health check attempt --- |
| 885 | + if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None: |
| 886 | + logger.info(f"Updating FAISS metadata for '{path}' post health check (status: {current_status}).") |
| 887 | + await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) |
| 888 | + |
706 | 889 | # --- Regenerate Nginx if status affecting it changed --- START |
707 | 890 | # Check if the service is enabled AND its Nginx-relevant status changed |
708 | 891 | if is_enabled: |
@@ -801,8 +984,25 @@ async def lifespan(app: FastAPI): |
801 | 984 | logger.info("Logging configured. Running startup tasks...") # Now logger is configured |
802 | 985 | # --- Configure Logging INSIDE lifespan --- END |
803 | 986 |
|
| 987 | + # 0. Load FAISS data and embedding model |
| 988 | +    load_faiss_data() # Loads the embedding model and an existing (or new, empty) index. Synchronous. |
| 989 | + |
804 | 990 | # 1. Load server definitions and persisted enabled/disabled state |
805 | | - load_registered_servers_and_state() |
| 991 | + load_registered_servers_and_state() # This populates REGISTERED_SERVERS. Synchronous. |
| 992 | + |
| 993 | + # 1.5 Sync FAISS with loaded servers (initial build or update) |
| 994 | + if embedding_model and faiss_index is not None: # Check faiss_index is not None |
| 995 | + logger.info("Performing initial FAISS synchronization with loaded server definitions...") |
| 996 | + sync_tasks = [] |
| 997 | + for path, server_info in REGISTERED_SERVERS.items(): |
| 998 | + # add_or_update_service_in_faiss is async, can be gathered |
| 999 | + sync_tasks.append(add_or_update_service_in_faiss(path, server_info)) |
| 1000 | + |
| 1001 | + if sync_tasks: |
| 1002 | + await asyncio.gather(*sync_tasks) |
| 1003 | + logger.info("Initial FAISS synchronization complete.") |
| 1004 | + else: |
| 1005 | + logger.warning("Skipping initial FAISS synchronization: embedding model or FAISS index not ready.") |
806 | 1006 |
|
807 | 1007 | # 2. Perform initial health checks concurrently for *enabled* services |
808 | 1008 | logger.info("Performing initial health checks for enabled services...") |
@@ -838,6 +1038,10 @@ async def lifespan(app: FastAPI): |
838 | 1038 | else: |
839 | 1039 | status, _ = result # Unpack the result tuple |
840 | 1040 | logger.info(f"Initial health check completed for {path}: Status = {status}") |
| 1041 | + # Update FAISS with potentially changed server_info (e.g., num_tools from health check) |
| 1042 | + if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None: |
| 1043 | + # This runs after each health check result, can be awaited individually |
| 1044 | + await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) |
841 | 1045 | else: |
842 | 1046 | logger.info("No services are initially enabled.") |
843 | 1047 |
|
@@ -1063,6 +1267,14 @@ async def toggle_service_route( |
1063 | 1267 | # Update global state directly when disabling |
1064 | 1268 | SERVER_HEALTH_STATUS[service_path] = new_status |
1065 | 1269 | logger.info(f"Service {service_path} toggled OFF. Status set to disabled.") |
| 1270 | + # --- Update FAISS metadata for disabled service --- START |
| 1271 | + if embedding_model and faiss_index is not None: |
| 1272 | + logger.info(f"Updating FAISS metadata for disabled service {service_path}.") |
| 1273 | + # REGISTERED_SERVERS[service_path] contains the static definition |
| 1274 | + await add_or_update_service_in_faiss(service_path, REGISTERED_SERVERS[service_path]) |
| 1275 | + else: |
| 1276 | + logger.warning(f"Skipped FAISS metadata update for disabled service {service_path}: model or index not ready.") |
| 1277 | + # --- Update FAISS metadata for disabled service --- END |
1066 | 1278 |
|
1067 | 1279 | # --- Send *targeted* update via WebSocket --- START |
1068 | 1280 | # Send immediate feedback for the toggled service only |
@@ -1229,6 +1441,15 @@ async def register_service( |
1229 | 1441 | else: |
1230 | 1442 | logger.info("[DEBUG] Successfully regenerated Nginx configuration") |
1231 | 1443 |
|
| 1444 | + # --- Add to FAISS Index --- START |
| 1445 | + logger.info(f"[DEBUG] Adding/updating service '{path}' in FAISS index after registration...") |
| 1446 | + if embedding_model and faiss_index is not None: |
| 1447 | + await add_or_update_service_in_faiss(path, server_entry) # server_entry is the new service info |
| 1448 | + logger.info(f"[DEBUG] Service '{path}' processed for FAISS index.") |
| 1449 | + else: |
| 1450 | + logger.warning(f"[DEBUG] Skipped FAISS update for '{path}': model or index not ready.") |
| 1451 | + # --- Add to FAISS Index --- END |
| 1452 | + |
1232 | 1453 | logger.info(f"[INFO] New service registered: '{name}' at path '{path}' by user '{username}'") |
1233 | 1454 |
|
1234 | 1455 | # --- Persist the updated state after registration --- START |
@@ -1464,6 +1685,15 @@ async def edit_server_submit( |
1464 | 1685 | if not regenerate_nginx_config(): |
1465 | 1686 | logger.error("ERROR: Failed to update Nginx configuration after edit.") |
1466 | 1687 | # Consider how to notify user - maybe flash message system needed |
| 1688 | + |
| 1689 | + # --- Update FAISS Index --- START |
| 1690 | + logger.info(f"Updating service '{service_path}' in FAISS index after edit.") |
| 1691 | + if embedding_model and faiss_index is not None: |
| 1692 | + await add_or_update_service_in_faiss(service_path, updated_server_entry) |
| 1693 | + logger.info(f"Service '{service_path}' updated in FAISS index.") |
| 1694 | + else: |
| 1695 | + logger.warning(f"Skipped FAISS update for '{service_path}' post-edit: model or index not ready.") |
| 1696 | + # --- Update FAISS Index --- END |
1467 | 1697 |
|
1468 | 1698 | logger.info(f"Server '{name}' ({service_path}) updated by user '{username}'") |
1469 | 1699 |
|