Commit cd3ae2c

Added mlflow experiment capability
1 parent 8b0492e commit cd3ae2c

3 files changed: +550 -24 lines changed


harness/harness_llama3.1_8b.py

Lines changed: 178 additions & 24 deletions
@@ -9,23 +9,34 @@
 import sys
 import time
 import logging
+import argparse
+import subprocess
 from datetime import datetime
 from pathlib import Path
 from typing import Optional, Dict, Any
-import subprocess
+
+# Matplotlib imports - set backend early for headless environments
+try:
+    import matplotlib
+    matplotlib.use('Agg') # Use non-interactive backend
+    import matplotlib.pyplot as plt
+    MATPLOTLIB_AVAILABLE = True
+except ImportError:
+    MATPLOTLIB_AVAILABLE = False
+    plt = None
 
 # Add harness directory to path
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
 # Import harness components
 try:
-    from backendserver import VLLMServer, create_server, start_server_from_config
+    from backendserver import VLLMServer, create_server, start_server_from_config, load_server_config
     from Client import LoadGenOfflineClient, LoadGenServerClient, create_loadgen_client
     from data.dataset_processor import DatasetProcessor
 except ImportError:
     # Try relative imports
     sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-    from backendserver import VLLMServer, create_server, start_server_from_config
+    from backendserver import VLLMServer, create_server, start_server_from_config, load_server_config
     from Client import LoadGenOfflineClient, LoadGenServerClient, create_loadgen_client
     from data.dataset_processor import DatasetProcessor

@@ -48,6 +59,28 @@
     METRICS_AVAILABLE = False
     logging.warning("Metrics collection not available")
 
+# Import environment info collector
+try:
+    from environment.environment_info import EnvironmentInfoCollector
+    ENVIRONMENT_INFO_AVAILABLE = True
+except ImportError:
+    ENVIRONMENT_INFO_AVAILABLE = False
+
+# Import metrics CSVStorage if available
+try:
+    from metrics.vllm_metrics_collector import CSVStorage
+    CSV_STORAGE_AVAILABLE = True
+except ImportError:
+    CSV_STORAGE_AVAILABLE = False
+
+# Import MLflow client if available
+try:
+    sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'mlflow_tools'))
+    from mlflow_client import MLflowClient
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    MLFLOW_AVAILABLE = False
+
 
 class Llama31_8BHarness:
     """

@@ -58,6 +91,7 @@ class Llama31_8BHarness:
    - LoadGen client (Offline/Server scenarios)
    - Dataset processing
    - Metrics collection and visualization
+   - MLflow tracking and artifact upload
    """
 
    def __init__(self,

@@ -71,7 +105,10 @@ def __init__(self,
                 num_samples: int = 13368,
                 output_dir: str = "./harness_output",
                 enable_metrics: bool = False,
-                metrics_interval: int = 10):
+                metrics_interval: int = 10,
+                mlflow_tracking_uri: Optional[str] = None,
+                mlflow_experiment_name: Optional[str] = None,
+                mlflow_output_dir: Optional[str] = None):
        """
        Initialize harness.
 
@@ -87,6 +124,9 @@ def __init__(self,
            output_dir: Output directory for logs and results
            enable_metrics: Enable metrics collection
            metrics_interval: Metrics collection interval (seconds)
+           mlflow_tracking_uri: MLflow tracking server URI (e.g., http://localhost:5000)
+           mlflow_experiment_name: MLflow experiment name
+           mlflow_output_dir: Output directory to upload to MLflow (defaults to output_dir)
        """
        self.model_name = model_name
        self.dataset_path = dataset_path

@@ -100,6 +140,12 @@ def __init__(self,
        self.enable_metrics = enable_metrics
        self.metrics_interval = metrics_interval
 
+       # MLflow configuration
+       self.mlflow_tracking_uri = mlflow_tracking_uri
+       self.mlflow_experiment_name = mlflow_experiment_name
+       self.mlflow_output_dir = Path(mlflow_output_dir) if mlflow_output_dir else self.output_dir
+       self.mlflow_client = None
+
        # Setup logging
        self.logger = logging.getLogger(self.__class__.__name__)
 
@@ -130,6 +176,10 @@ def __init__(self,
        self.logger.info(f" - MLPerf logs: {self.mlperf_output_dir}")
        self.logger.info(f" - Environment info: {self.environment_output_dir}")
 
+       # Initialize MLflow if configured
+       if self.mlflow_tracking_uri and self.mlflow_experiment_name:
+           self._initialize_mlflow()
+
        # Setup stdout redirection to harness_output
        self._setup_stdout_redirection()
 
@@ -151,6 +201,25 @@ def __init__(self,
        self.original_stdout = None
        self.original_stderr = None
 
+   def _initialize_mlflow(self):
+       """Initialize MLflow client if configured."""
+       if not MLFLOW_AVAILABLE:
+           self.logger.warning("MLflow is not available. MLflow tracking will be disabled.")
+           return
+
+       try:
+           self.mlflow_client = MLflowClient(
+               tracking_uri=self.mlflow_tracking_uri,
+               experiment_name=self.mlflow_experiment_name,
+               client_type="loadgen",
+               output_dir=str(self.mlflow_output_dir)
+           )
+           self.logger.info(f"MLflow client initialized: {self.mlflow_tracking_uri}")
+           self.logger.info(f"MLflow experiment: {self.mlflow_experiment_name}")
+       except Exception as e:
+           self.logger.warning(f"Failed to initialize MLflow client: {e}")
+           self.mlflow_client = None
+
    def _setup_stdout_redirection(self):
        """Setup stdout and stderr redirection to harness_output directory."""
        try:

@@ -193,18 +262,18 @@ def _restore_stdout_redirection(self):
 
    def _collect_environment_info(self):
        """Collect environment information."""
+       if not ENVIRONMENT_INFO_AVAILABLE:
+           self.logger.debug("Environment info collector not available")
+           return
+
        try:
-           from environment.environment_info import EnvironmentInfoCollector
-
            collector = EnvironmentInfoCollector(self.environment_output_dir)
            results = collector.collect_all()
 
            self.logger.info("Environment information collected:")
            self.logger.info(f" - Successfully collected: {list(results.get('success', {}).keys())}")
            if results.get('errors'):
                self.logger.warning(f" - Errors: {list(results['errors'].keys())}")
-       except ImportError as e:
-           self.logger.warning(f"Could not import environment info collector: {e}")
        except Exception as e:
            self.logger.warning(f"Error collecting environment info: {e}")
 
@@ -309,8 +378,12 @@ def initialize_metrics(self):
        # Use metrics subdirectory - default to CSV format
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        metrics_file = self.metrics_output_dir / f"metrics_{timestamp}.csv"
-       from metrics.vllm_metrics_collector import CSVStorage
-       storage = CSVStorage(str(metrics_file))
+
+       if CSV_STORAGE_AVAILABLE:
+           storage = CSVStorage(str(metrics_file))
+       else:
+           # Fallback to JSON storage
+           storage = JSONStorage(str(metrics_file).replace('.csv', '.json'))
 
        self.metrics_collector = VLLMMetricsCollector(
            metrics_endpoint=f"{self.api_server_url}/metrics",

@@ -356,8 +429,28 @@ def run(self, user_conf: str = "user.conf", lg_model_name: str = "llama3_1-8b")
        server_started_here = False
        client_initialized = False
        metrics_started = False
+       mlflow_run_started = False
 
        try:
+           # Start MLflow run if configured
+           if self.mlflow_client:
+               try:
+                   self.mlflow_client.start_run()
+                   mlflow_run_started = True
+
+                   # Log parameters
+                   params = {
+                       'model_name': self.model_name,
+                       'scenario': self.scenario,
+                       'test_mode': self.test_mode,
+                       'batch_size': str(self.batch_size),
+                       'num_samples': str(self.num_samples)
+                   }
+                   self.mlflow_client.log_parameters(params)
+               except Exception as e:
+                   self.logger.warning(f"Failed to start MLflow run: {e}")
+                   mlflow_run_started = False
+
            # Start server if needed
            if not self.api_server_url:
                self.start_server()
459552
if self.enable_metrics and self.metrics_collector:
460553
self._generate_metrics_visualizations()
461554

462-
return {
555+
test_results = {
463556
'status': 'success',
464557
'duration': test_duration,
465558
'scenario': self.scenario,
466559
'test_mode': self.test_mode,
467560
'num_samples': self.num_samples
468561
}
469562

563+
# Upload to MLflow if configured
564+
if mlflow_run_started and self.mlflow_client:
565+
try:
566+
self._upload_to_mlflow(test_results)
567+
except Exception as e:
568+
self.logger.warning(f"Failed to upload to MLflow: {e}")
569+
570+
return test_results
571+
470572
except Exception as e:
471573
self.logger.error(f"Test failed: {e}", exc_info=True)
472-
return {
574+
test_results = {
473575
'status': 'failed',
474576
'error': str(e)
475577
}
578+
579+
# Upload failure to MLflow if configured
580+
if mlflow_run_started and self.mlflow_client:
581+
try:
582+
self._upload_to_mlflow(test_results)
583+
except Exception as e:
584+
self.logger.warning(f"Failed to upload failure to MLflow: {e}")
585+
586+
return test_results
476587

477588
finally:
478589
# Cleanup in reverse order of initialization
479590
self.logger.info("Performing cleanup...")
480591

592+
# End MLflow run if started
593+
if mlflow_run_started and self.mlflow_client:
594+
try:
595+
self.mlflow_client.end_run()
596+
except Exception as e:
597+
self.logger.warning(f"Error ending MLflow run: {e}")
598+
481599
# Stop metrics collector if not already stopped (in case of exception)
482600
if metrics_started and self.metrics_collector:
483601
try:
@@ -510,6 +628,30 @@ def run(self, user_conf: str = "user.conf", lg_model_name: str = "llama3_1-8b")
510628

511629
self.logger.info("Cleanup completed")
512630

631+
def _upload_to_mlflow(self, test_results: Dict[str, Any]):
632+
"""Upload test results and artifacts to MLflow."""
633+
if not self.mlflow_client:
634+
return
635+
636+
try:
637+
# Log client-specific metrics
638+
self.mlflow_client.log_client_metrics(test_results)
639+
640+
# Generate and log description
641+
description = self.mlflow_client.get_client_description(test_results)
642+
self.mlflow_client.log_description(description)
643+
644+
# Upload artifacts - upload entire output directory
645+
self.mlflow_client.upload_artifacts(
646+
output_dir=str(self.mlflow_output_dir),
647+
include_subdirs=True
648+
)
649+
650+
self.logger.info("Successfully uploaded to MLflow")
651+
except Exception as e:
652+
self.logger.error(f"Failed to upload to MLflow: {e}", exc_info=True)
653+
raise
654+
513655
def _load_samples_to_ram(self, query_samples):
514656
"""LoadGen callback - samples are pre-loaded in Dataset."""
515657
pass
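
The new _upload_to_mlflow() helper pushes the whole output directory (metrics files, MLPerf logs, generated plots, environment info) to the tracking server through upload_artifacts(). In the stock mlflow API, a whole-directory upload is a single call; the snippet below is a hedged illustration of that side of the wrapper, with placeholder paths and file names, not the committed implementation.

# Hypothetical illustration of directory-level artifact upload with stock mlflow.
# The run, directory, and file names are placeholders.
import mlflow

output_dir = "./harness_output"  # placeholder; mirrors the harness default

with mlflow.start_run():
    # Recursively uploads every file under output_dir, preserving the
    # sub-directory layout in the run's artifact store.
    mlflow.log_artifacts(output_dir)

    # A single file can also be filed under a named artifact folder.
    mlflow.log_artifact(f"{output_dir}/summary.json",  # placeholder file
                        artifact_path="reports")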
@@ -524,11 +666,11 @@ def _generate_metrics_visualizations(self):
            self.logger.warning("Metrics collector or visualizer not available")
            return
 
+       if not MATPLOTLIB_AVAILABLE:
+           self.logger.warning("Matplotlib not available. Visualizations will not be generated.")
+           return
+
        try:
-           # Set matplotlib backend for headless environments
-           import matplotlib
-           matplotlib.use('Agg') # Use non-interactive backend
-
            storage_file = self.metrics_collector._get_storage_file_path()
            if not storage_file:
                self.logger.warning("No storage file path available from metrics collector")

@@ -615,8 +757,7 @@ def _generate_metrics_visualizations(self):
                    save_path=str(save_path),
                    show_labels=False # Disable label grouping to avoid dict unhashable error
                )
-               # Close plot to free memory
-               import matplotlib.pyplot as plt
+               # Close all figures to free memory and ensure clean state for next plot
                plt.close('all')
                self.logger.info(f"✓ Generated visualization: {save_path}")
                successful_viz += 1

@@ -627,7 +768,6 @@ def _generate_metrics_visualizations(self):
                self.logger.error(f"Failed to generate visualization {viz['filename']}: {e}", exc_info=True)
                # Close plot even on error
                try:
-                   import matplotlib.pyplot as plt
                    plt.close('all')
                except:
                    pass

@@ -643,8 +783,6 @@ def _generate_metrics_visualizations(self):
 
 def main():
    """Main entry point for harness."""
-   import argparse
-
    parser = argparse.ArgumentParser(description="MLPerf Harness for Llama 3.1 8B")
 
    parser.add_argument("--model", type=str, required=True, help="Model name or path")

@@ -663,6 +801,16 @@ def main():
    parser.add_argument("--lg-model-name", type=str, default="llama3_1-8b", help="LoadGen model name")
    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
 
+   # MLflow arguments
+   parser.add_argument("--mlflow-experiment-name", type=str, default=None,
+                       help="MLflow experiment name (enables MLflow tracking)")
+   parser.add_argument("--mlflow-output-dir", type=str, default=None,
+                       help="Output directory to upload to MLflow (defaults to --output-dir)")
+   parser.add_argument("--mlflow-host", type=str, default="localhost",
+                       help="MLflow tracking server hostname")
+   parser.add_argument("--mlflow-port", type=int, default=5000,
+                       help="MLflow tracking server port")
+
    args = parser.parse_args()
 
    # Configure logging

@@ -675,10 +823,14 @@ def main():
    # Load server config if provided
    server_config = {}
    if args.server_config:
-       from backendserver import load_server_config
        server_config = load_server_config(args.server_config)
        server_config['config_file'] = args.server_config
 
+   # Construct MLflow tracking URI if experiment name is provided
+   mlflow_tracking_uri = None
+   if args.mlflow_experiment_name:
+       mlflow_tracking_uri = f"http://{args.mlflow_host}:{args.mlflow_port}"
+
    # Create and run harness
    harness = Llama31_8BHarness(
        model_name=args.model,

@@ -690,7 +842,10 @@ def main():
        batch_size=args.batch_size,
        num_samples=args.num_samples,
        output_dir=args.output_dir,
-       enable_metrics=args.enable_metrics
+       enable_metrics=args.enable_metrics,
+       mlflow_tracking_uri=mlflow_tracking_uri,
+       mlflow_experiment_name=args.mlflow_experiment_name,
+       mlflow_output_dir=args.mlflow_output_dir
    )
 
    results = harness.run(user_conf=args.user_conf, lg_model_name=args.lg_model_name)

@@ -705,4 +860,3 @@
 
 if __name__ == "__main__":
    sys.exit(main())
-
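
With these changes, MLflow tracking is opt-in: passing --mlflow-experiment-name (optionally with --mlflow-host, --mlflow-port, and --mlflow-output-dir) makes main() build the tracking URI and thread it into the constructor. The sketch below drives the harness programmatically instead; it loads the module by path because the file name contains dots, and every path, model name, and experiment name is a placeholder. Constructor arguments not visible in this diff are left at their defaults.

# Hypothetical programmatic invocation; all paths and names are placeholders.
import importlib.util

spec = importlib.util.spec_from_file_location(
    "llama31_harness", "harness/harness_llama3.1_8b.py")
harness_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(harness_mod)

harness = harness_mod.Llama31_8BHarness(
    model_name="meta-llama/Llama-3.1-8B-Instruct",        # placeholder
    dataset_path="./data/cnn_dailymail_validation.json",  # placeholder
    output_dir="./harness_output",
    enable_metrics=True,
    mlflow_tracking_uri="http://localhost:5000",
    mlflow_experiment_name="llama3.1-8b-mlperf",          # placeholder
    # mlflow_output_dir defaults to output_dir when omitted
)
results = harness.run(user_conf="user.conf", lg_model_name="llama3_1-8b")

A local tracking server can be brought up beforehand with, for example, "mlflow server --host 127.0.0.1 --port 5000", which matches the default host and port of the new CLI flags.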

harness/mlflow_tools/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# MLflow client module
+try:
+    from .mlflow_client import MLflowClient
+    __all__ = ['MLflowClient']
+except ImportError:
+    __all__ = []
+
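
The harness only relies on a small surface of the wrapper: start_run, log_parameters, log_client_metrics, get_client_description, log_description, upload_artifacts, and end_run. The implementation lives in harness/mlflow_tools/mlflow_client.py, the third file changed in this commit, which is not expanded in this view. Purely as a hedged sketch, a wrapper with that interface could look like the following on top of the stock mlflow package; the method bodies and tag names below are assumptions, not the committed code.

# Hypothetical sketch only; the committed mlflow_client.py may differ.
from numbers import Number
from typing import Any, Dict, Optional
import mlflow


class MLflowClient:
    """Thin wrapper over the mlflow tracking API (assumed interface)."""

    def __init__(self, tracking_uri: str, experiment_name: str,
                 client_type: str = "loadgen", output_dir: Optional[str] = None):
        self.client_type = client_type
        self.output_dir = output_dir
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name)

    def start_run(self) -> None:
        mlflow.start_run(run_name=f"{self.client_type}-run")
        mlflow.set_tag("client_type", self.client_type)

    def log_parameters(self, params: Dict[str, str]) -> None:
        mlflow.log_params(params)

    def log_client_metrics(self, results: Dict[str, Any]) -> None:
        # MLflow metrics must be numeric; everything else is recorded as tags.
        numeric = {k: float(v) for k, v in results.items() if isinstance(v, Number)}
        other = {k: str(v) for k, v in results.items() if not isinstance(v, Number)}
        if numeric:
            mlflow.log_metrics(numeric)
        if other:
            mlflow.set_tags(other)

    def get_client_description(self, results: Dict[str, Any]) -> str:
        return "\n".join(f"{k}: {v}" for k, v in sorted(results.items()))

    def log_description(self, description: str) -> None:
        # 'mlflow.note.content' is the reserved tag shown as the run description in the UI.
        mlflow.set_tag("mlflow.note.content", description)

    def upload_artifacts(self, output_dir: str, include_subdirs: bool = True) -> None:
        # log_artifacts always recurses; include_subdirs is kept for interface parity.
        mlflow.log_artifacts(output_dir)

    def end_run(self) -> None:
        mlflow.end_run()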
