99import sys
1010import time
1111import logging
12+ import argparse
13+ import subprocess
1214from datetime import datetime
1315from pathlib import Path
1416from typing import Optional , Dict , Any
15- import subprocess
17+
18+ # Matplotlib imports - set backend early for headless environments
19+ try :
20+ import matplotlib
21+ matplotlib .use ('Agg' ) # Use non-interactive backend
22+ import matplotlib .pyplot as plt
23+ MATPLOTLIB_AVAILABLE = True
24+ except ImportError :
25+ MATPLOTLIB_AVAILABLE = False
26+ plt = None
1627
1728# Add harness directory to path
1829sys .path .insert (0 , os .path .dirname (os .path .abspath (__file__ )))
1930
2031# Import harness components
2132try :
22- from backendserver import VLLMServer , create_server , start_server_from_config
33+ from backendserver import VLLMServer , create_server , start_server_from_config , load_server_config
2334 from Client import LoadGenOfflineClient , LoadGenServerClient , create_loadgen_client
2435 from data .dataset_processor import DatasetProcessor
2536except ImportError :
2637 # Try relative imports
2738 sys .path .insert (0 , os .path .dirname (os .path .abspath (__file__ )))
28- from backendserver import VLLMServer , create_server , start_server_from_config
39+ from backendserver import VLLMServer , create_server , start_server_from_config , load_server_config
2940 from Client import LoadGenOfflineClient , LoadGenServerClient , create_loadgen_client
3041 from data .dataset_processor import DatasetProcessor
3142
4859 METRICS_AVAILABLE = False
4960 logging .warning ("Metrics collection not available" )
5061
62+ # Import environment info collector
63+ try :
64+ from environment .environment_info import EnvironmentInfoCollector
65+ ENVIRONMENT_INFO_AVAILABLE = True
66+ except ImportError :
67+ ENVIRONMENT_INFO_AVAILABLE = False
68+
69+ # Import metrics CSVStorage if available
70+ try :
71+ from metrics .vllm_metrics_collector import CSVStorage
72+ CSV_STORAGE_AVAILABLE = True
73+ except ImportError :
74+ CSV_STORAGE_AVAILABLE = False
75+
76+ # Import MLflow client if available
77+ try :
78+ sys .path .insert (0 , os .path .join (os .path .dirname (__file__ ), 'mlflow_tools' ))
79+ from mlflow_client import MLflowClient
80+ MLFLOW_AVAILABLE = True
81+ except ImportError :
82+ MLFLOW_AVAILABLE = False
83+
5184
5285class Llama31_8BHarness :
5386 """
@@ -58,6 +91,7 @@ class Llama31_8BHarness:
5891 - LoadGen client (Offline/Server scenarios)
5992 - Dataset processing
6093 - Metrics collection and visualization
94+ - MLflow tracking and artifact upload
6195 """
6296
6397 def __init__ (self ,
@@ -71,7 +105,10 @@ def __init__(self,
71105 num_samples : int = 13368 ,
72106 output_dir : str = "./harness_output" ,
73107 enable_metrics : bool = False ,
74- metrics_interval : int = 10 ):
108+ metrics_interval : int = 10 ,
109+ mlflow_tracking_uri : Optional [str ] = None ,
110+ mlflow_experiment_name : Optional [str ] = None ,
111+ mlflow_output_dir : Optional [str ] = None ):
75112 """
76113 Initialize harness.
77114
@@ -87,6 +124,9 @@ def __init__(self,
87124 output_dir: Output directory for logs and results
88125 enable_metrics: Enable metrics collection
89126 metrics_interval: Metrics collection interval (seconds)
127+ mlflow_tracking_uri: MLflow tracking server URI (e.g., http://localhost:5000)
128+ mlflow_experiment_name: MLflow experiment name
129+ mlflow_output_dir: Output directory to upload to MLflow (defaults to output_dir)
90130 """
91131 self .model_name = model_name
92132 self .dataset_path = dataset_path
@@ -100,6 +140,12 @@ def __init__(self,
100140 self .enable_metrics = enable_metrics
101141 self .metrics_interval = metrics_interval
102142
143+ # MLflow configuration
144+ self .mlflow_tracking_uri = mlflow_tracking_uri
145+ self .mlflow_experiment_name = mlflow_experiment_name
146+ self .mlflow_output_dir = Path (mlflow_output_dir ) if mlflow_output_dir else self .output_dir
147+ self .mlflow_client = None
148+
103149 # Setup logging
104150 self .logger = logging .getLogger (self .__class__ .__name__ )
105151
@@ -130,6 +176,10 @@ def __init__(self,
130176 self .logger .info (f" - MLPerf logs: { self .mlperf_output_dir } " )
131177 self .logger .info (f" - Environment info: { self .environment_output_dir } " )
132178
179+ # Initialize MLflow if configured
180+ if self .mlflow_tracking_uri and self .mlflow_experiment_name :
181+ self ._initialize_mlflow ()
182+
133183 # Setup stdout redirection to harness_output
134184 self ._setup_stdout_redirection ()
135185
@@ -151,6 +201,25 @@ def __init__(self,
151201 self .original_stdout = None
152202 self .original_stderr = None
153203
204+ def _initialize_mlflow (self ):
205+ """Initialize MLflow client if configured."""
206+ if not MLFLOW_AVAILABLE :
207+ self .logger .warning ("MLflow is not available. MLflow tracking will be disabled." )
208+ return
209+
210+ try :
211+ self .mlflow_client = MLflowClient (
212+ tracking_uri = self .mlflow_tracking_uri ,
213+ experiment_name = self .mlflow_experiment_name ,
214+ client_type = "loadgen" ,
215+ output_dir = str (self .mlflow_output_dir )
216+ )
217+ self .logger .info (f"MLflow client initialized: { self .mlflow_tracking_uri } " )
218+ self .logger .info (f"MLflow experiment: { self .mlflow_experiment_name } " )
219+ except Exception as e :
220+ self .logger .warning (f"Failed to initialize MLflow client: { e } " )
221+ self .mlflow_client = None
222+
154223 def _setup_stdout_redirection (self ):
155224 """Setup stdout and stderr redirection to harness_output directory."""
156225 try :
@@ -193,18 +262,18 @@ def _restore_stdout_redirection(self):
193262
194263 def _collect_environment_info (self ):
195264 """Collect environment information."""
265+ if not ENVIRONMENT_INFO_AVAILABLE :
266+ self .logger .debug ("Environment info collector not available" )
267+ return
268+
196269 try :
197- from environment .environment_info import EnvironmentInfoCollector
198-
199270 collector = EnvironmentInfoCollector (self .environment_output_dir )
200271 results = collector .collect_all ()
201272
202273 self .logger .info ("Environment information collected:" )
203274 self .logger .info (f" - Successfully collected: { list (results .get ('success' , {}).keys ())} " )
204275 if results .get ('errors' ):
205276 self .logger .warning (f" - Errors: { list (results ['errors' ].keys ())} " )
206- except ImportError as e :
207- self .logger .warning (f"Could not import environment info collector: { e } " )
208277 except Exception as e :
209278 self .logger .warning (f"Error collecting environment info: { e } " )
210279
@@ -309,8 +378,12 @@ def initialize_metrics(self):
309378 # Use metrics subdirectory - default to CSV format
310379 timestamp = datetime .now ().strftime ('%Y%m%d_%H%M%S' )
311380 metrics_file = self .metrics_output_dir / f"metrics_{ timestamp } .csv"
312- from metrics .vllm_metrics_collector import CSVStorage
313- storage = CSVStorage (str (metrics_file ))
381+
382+ if CSV_STORAGE_AVAILABLE :
383+ storage = CSVStorage (str (metrics_file ))
384+ else :
385+ # Fallback to JSON storage
386+ storage = JSONStorage (str (metrics_file ).replace ('.csv' , '.json' ))
314387
315388 self .metrics_collector = VLLMMetricsCollector (
316389 metrics_endpoint = f"{ self .api_server_url } /metrics" ,
@@ -356,8 +429,28 @@ def run(self, user_conf: str = "user.conf", lg_model_name: str = "llama3_1-8b")
356429 server_started_here = False
357430 client_initialized = False
358431 metrics_started = False
432+ mlflow_run_started = False
359433
360434 try :
435+ # Start MLflow run if configured
436+ if self .mlflow_client :
437+ try :
438+ self .mlflow_client .start_run ()
439+ mlflow_run_started = True
440+
441+ # Log parameters
442+ params = {
443+ 'model_name' : self .model_name ,
444+ 'scenario' : self .scenario ,
445+ 'test_mode' : self .test_mode ,
446+ 'batch_size' : str (self .batch_size ),
447+ 'num_samples' : str (self .num_samples )
448+ }
449+ self .mlflow_client .log_parameters (params )
450+ except Exception as e :
451+ self .logger .warning (f"Failed to start MLflow run: { e } " )
452+ mlflow_run_started = False
453+
361454 # Start server if needed
362455 if not self .api_server_url :
363456 self .start_server ()
@@ -459,25 +552,50 @@ def run(self, user_conf: str = "user.conf", lg_model_name: str = "llama3_1-8b")
459552 if self .enable_metrics and self .metrics_collector :
460553 self ._generate_metrics_visualizations ()
461554
462- return {
555+ test_results = {
463556 'status' : 'success' ,
464557 'duration' : test_duration ,
465558 'scenario' : self .scenario ,
466559 'test_mode' : self .test_mode ,
467560 'num_samples' : self .num_samples
468561 }
469562
563+ # Upload to MLflow if configured
564+ if mlflow_run_started and self .mlflow_client :
565+ try :
566+ self ._upload_to_mlflow (test_results )
567+ except Exception as e :
568+ self .logger .warning (f"Failed to upload to MLflow: { e } " )
569+
570+ return test_results
571+
470572 except Exception as e :
471573 self .logger .error (f"Test failed: { e } " , exc_info = True )
472- return {
574+ test_results = {
473575 'status' : 'failed' ,
474576 'error' : str (e )
475577 }
578+
579+ # Upload failure to MLflow if configured
580+ if mlflow_run_started and self .mlflow_client :
581+ try :
582+ self ._upload_to_mlflow (test_results )
583+ except Exception as e :
584+ self .logger .warning (f"Failed to upload failure to MLflow: { e } " )
585+
586+ return test_results
476587
477588 finally :
478589 # Cleanup in reverse order of initialization
479590 self .logger .info ("Performing cleanup..." )
480591
592+ # End MLflow run if started
593+ if mlflow_run_started and self .mlflow_client :
594+ try :
595+ self .mlflow_client .end_run ()
596+ except Exception as e :
597+ self .logger .warning (f"Error ending MLflow run: { e } " )
598+
481599 # Stop metrics collector if not already stopped (in case of exception)
482600 if metrics_started and self .metrics_collector :
483601 try :
@@ -510,6 +628,30 @@ def run(self, user_conf: str = "user.conf", lg_model_name: str = "llama3_1-8b")
510628
511629 self .logger .info ("Cleanup completed" )
512630
631+ def _upload_to_mlflow (self , test_results : Dict [str , Any ]):
632+ """Upload test results and artifacts to MLflow."""
633+ if not self .mlflow_client :
634+ return
635+
636+ try :
637+ # Log client-specific metrics
638+ self .mlflow_client .log_client_metrics (test_results )
639+
640+ # Generate and log description
641+ description = self .mlflow_client .get_client_description (test_results )
642+ self .mlflow_client .log_description (description )
643+
644+ # Upload artifacts - upload entire output directory
645+ self .mlflow_client .upload_artifacts (
646+ output_dir = str (self .mlflow_output_dir ),
647+ include_subdirs = True
648+ )
649+
650+ self .logger .info ("Successfully uploaded to MLflow" )
651+ except Exception as e :
652+ self .logger .error (f"Failed to upload to MLflow: { e } " , exc_info = True )
653+ raise
654+
513655 def _load_samples_to_ram (self , query_samples ):
514656 """LoadGen callback - samples are pre-loaded in Dataset."""
515657 pass
@@ -524,11 +666,11 @@ def _generate_metrics_visualizations(self):
524666 self .logger .warning ("Metrics collector or visualizer not available" )
525667 return
526668
669+ if not MATPLOTLIB_AVAILABLE :
670+ self .logger .warning ("Matplotlib not available. Visualizations will not be generated." )
671+ return
672+
527673 try :
528- # Set matplotlib backend for headless environments
529- import matplotlib
530- matplotlib .use ('Agg' ) # Use non-interactive backend
531-
532674 storage_file = self .metrics_collector ._get_storage_file_path ()
533675 if not storage_file :
534676 self .logger .warning ("No storage file path available from metrics collector" )
@@ -615,8 +757,7 @@ def _generate_metrics_visualizations(self):
615757 save_path = str (save_path ),
616758 show_labels = False # Disable label grouping to avoid dict unhashable error
617759 )
618- # Close plot to free memory
619- import matplotlib .pyplot as plt
760+ # Close all figures to free memory and ensure clean state for next plot
620761 plt .close ('all' )
621762 self .logger .info (f"✓ Generated visualization: { save_path } " )
622763 successful_viz += 1
@@ -627,7 +768,6 @@ def _generate_metrics_visualizations(self):
627768 self .logger .error (f"Failed to generate visualization { viz ['filename' ]} : { e } " , exc_info = True )
628769 # Close plot even on error
629770 try :
630- import matplotlib .pyplot as plt
631771 plt .close ('all' )
632772 except :
633773 pass
@@ -643,8 +783,6 @@ def _generate_metrics_visualizations(self):
643783
644784def main ():
645785 """Main entry point for harness."""
646- import argparse
647-
648786 parser = argparse .ArgumentParser (description = "MLPerf Harness for Llama 3.1 8B" )
649787
650788 parser .add_argument ("--model" , type = str , required = True , help = "Model name or path" )
@@ -663,6 +801,16 @@ def main():
663801 parser .add_argument ("--lg-model-name" , type = str , default = "llama3_1-8b" , help = "LoadGen model name" )
664802 parser .add_argument ("--log-level" , type = str , default = "INFO" , help = "Logging level" )
665803
804+ # MLflow arguments
805+ parser .add_argument ("--mlflow-experiment-name" , type = str , default = None ,
806+ help = "MLflow experiment name (enables MLflow tracking)" )
807+ parser .add_argument ("--mlflow-output-dir" , type = str , default = None ,
808+ help = "Output directory to upload to MLflow (defaults to --output-dir)" )
809+ parser .add_argument ("--mlflow-host" , type = str , default = "localhost" ,
810+ help = "MLflow tracking server hostname" )
811+ parser .add_argument ("--mlflow-port" , type = int , default = 5000 ,
812+ help = "MLflow tracking server port" )
813+
666814 args = parser .parse_args ()
667815
668816 # Configure logging
@@ -675,10 +823,14 @@ def main():
675823 # Load server config if provided
676824 server_config = {}
677825 if args .server_config :
678- from backendserver import load_server_config
679826 server_config = load_server_config (args .server_config )
680827 server_config ['config_file' ] = args .server_config
681828
829+ # Construct MLflow tracking URI if experiment name is provided
830+ mlflow_tracking_uri = None
831+ if args .mlflow_experiment_name :
832+ mlflow_tracking_uri = f"http://{ args .mlflow_host } :{ args .mlflow_port } "
833+
682834 # Create and run harness
683835 harness = Llama31_8BHarness (
684836 model_name = args .model ,
@@ -690,7 +842,10 @@ def main():
690842 batch_size = args .batch_size ,
691843 num_samples = args .num_samples ,
692844 output_dir = args .output_dir ,
693- enable_metrics = args .enable_metrics
845+ enable_metrics = args .enable_metrics ,
846+ mlflow_tracking_uri = mlflow_tracking_uri ,
847+ mlflow_experiment_name = args .mlflow_experiment_name ,
848+ mlflow_output_dir = args .mlflow_output_dir
694849 )
695850
696851 results = harness .run (user_conf = args .user_conf , lg_model_name = args .lg_model_name )
@@ -705,4 +860,3 @@ def main():
705860
706861if __name__ == "__main__" :
707862 sys .exit (main ())
708-
0 commit comments