Commit 1527bbd

🌱 [MRG] [exp] Fix seed
Merge pull request #46 from MLSysOps/fix-seed
2 parents: 0d7db27 + 897d8aa · commit: 1527bbd

File tree

4 files changed: +53 -40 lines changed

exps/abr_sim.py
exps/load_balance_sim.py
infragym/abr_sim/abr_llm_gym.py
infragym/load_balance/load_balance_llm_gym.py

exps/abr_sim.py

Lines changed: 9 additions & 5 deletions

@@ -322,14 +322,13 @@ def train_agent(
     memory,
     max_messages,
     summary_every: int = 100,
-    base_seed: int | None = None,
 ):
     ep = 0
     global_step = 0
     episode_reward = 0.0
     step_logs = []
     # Single-stream seeding: seed only the first reset, then let RNG continue
-    obs, info = env.reset(seed=base_seed)
+    obs, info = env.reset()

     system_prompt = f"{obs['task_description']}\n\n{memory.as_context_block()}\n\n"
     chat_history = [{"role": "system", "content": system_prompt}]
@@ -543,12 +542,14 @@ def run_llm_agent(
         max_episode_steps=1000_000,
         reward_scale=reward_scale,
         dummy_mode=not real_env,
+        seed=seed,
     )
     test_env = infragym.make(
         'abr_sim',
         max_episode_steps=test_steps,
         reward_scale=reward_scale,
         dummy_mode=not real_env,
+        seed=seed + 10_000 if seed is not None else None,  # Stable test seed
     )

     # Environment summary table
@@ -568,6 +569,7 @@ def run_llm_agent(
             "test_steps": test_steps,
             "summary_every": summary_every,
             "reward_scale": reward_scale,
+            "seed": seed,
             "real_env": real_env,
             "model_name": model_name,
             "thinking_mode": thinking_mode,
@@ -590,7 +592,6 @@ def run_llm_agent(
         memory=memory,
         max_messages=max_messages,
         summary_every=summary_every,  # summarize every 100 steps
-        base_seed=seed,  # seed only the first reset
     ):
         # ---- Structured Episode Summary (printed + stored) ----
         ep_summary = result["ep_summary"]
@@ -679,6 +680,7 @@ def run_rule_policy(
         max_episode_steps=test_steps,
         reward_scale=reward_scale,
         dummy_mode=not real_env,
+        seed=seed,
     )

     # Environment summary table
@@ -702,7 +704,7 @@ def run_rule_policy(
     )
     console.print(Panel(f"Trace Space: {trace_space.location}", title="Tracing", border_style="blue"))

-    obs, info = env.reset(seed=seed)
+    obs, info = env.reset()
     ep = 0  # No episodes in rule policy, just a single test run
     test_steps = 0
     test_total_reward = 0.0
@@ -806,6 +808,7 @@ def run_llm(
         max_episode_steps=test_steps,
         reward_scale=reward_scale,
         dummy_mode=not real_env,
+        seed=seed,
     )

     # Environment summary table
@@ -824,14 +827,15 @@ def run_llm(
             "test_steps": test_steps,
             "reward_scale": reward_scale,
             "real_env": real_env,
+            "seed": seed,
             "model_name": model_name,
             "thinking_mode": thinking_mode,
             "max_messages": max_messages,
         }
     )
     console.print(Panel(f"Trace Space: {trace_space.location}", title="Tracing", border_style="blue"))

-    obs, info = env.reset(seed=seed)
+    obs, info = env.reset()
     ep = 0  # No episodes in rule policy, just a single test run
     test_steps = 1
     test_total_reward = 0.0
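
Taken together, the driver-side change moves seeding from env.reset() to environment construction: each infragym.make(...) call now receives the seed once, and subsequent resets simply continue the same RNG stream; the test environment gets an offset seed so evaluation traces stay stable but disjoint from training. The load-balance driver in the next file follows the same pattern. A minimal sketch of the resulting usage (argument values other than those visible in this diff are illustrative):

import infragym

seed = 42

# Train env: seeded once at construction; later env.reset() calls continue
# the same RNG stream instead of re-seeding every episode.
env = infragym.make(
    'abr_sim',
    max_episode_steps=1_000_000,
    reward_scale=1.0,      # illustrative value
    dummy_mode=True,       # illustrative value
    seed=seed,
)

# Test env: the +10_000 offset mirrors the "Stable test seed" line above.
test_env = infragym.make(
    'abr_sim',
    max_episode_steps=1000,  # illustrative value
    reward_scale=1.0,
    dummy_mode=True,
    seed=seed + 10_000 if seed is not None else None,
)

obs, info = env.reset()  # no per-reset seed any more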

exps/load_balance_sim.py

Lines changed: 6 additions & 3 deletions

@@ -412,6 +412,7 @@ def run_llm_agent(
         num_servers=num_servers,
         max_episode_steps=max_steps,
         reward_scale=reward_scale,
+        seed=seed,
     )

     # Environment summary
@@ -446,7 +447,7 @@ def run_llm_agent(
     for ep in range(episodes):
         console.rule(f"Episode {ep + 1}")
         # Single-stream seeding: only the FIRST reset gets the seed
-        obs, info = env.reset(seed=seed if ep == 0 else None)
+        obs, info = env.reset()

         # Build system prompt with task + memory
         system_prompt = (
@@ -603,6 +604,7 @@ def run_rule_policy(
         max_episode_steps=test_steps,
         reward_scale=reward_scale,
         num_servers=num_servers,
+        seed=seed,
     )

     # Environment summary
@@ -628,7 +630,7 @@ def run_rule_policy(
     )
     console.print(Panel(f"Trace Space: {trace_space.location}", title="Tracing", border_style="blue"))

-    obs, info = env.reset(seed=seed)
+    obs, info = env.reset()
     ep = 0  # No episodes in rule policy, just a single test run
     test_steps = 0
     test_total_reward = 0.0
@@ -785,6 +787,7 @@ def run_llm(
         max_episode_steps=test_steps,
         reward_scale=reward_scale,
         num_servers=num_servers,
+        seed=seed,
     )

     # Environment summary
@@ -814,7 +817,7 @@ def run_llm(
     # Tracing space for this run
     console.print(Panel(f"Trace Space: {trace_space.location}", title="Tracing", border_style="blue"))

-    obs, info = env.reset(seed=seed)
+    obs, info = env.reset()
     ep = 0
     chat_history = [{"role": "system", "content": obs["task_description"]}]
     steps = 1

infragym/abr_sim/abr_llm_gym.py

Lines changed: 12 additions & 9 deletions

@@ -32,9 +32,6 @@ def __init__(
         seed: Optional[int] = 42,
     ):
         super().__init__()
-        self.seed = seed
-        if self.seed is not None:
-            np.random.seed(self.seed)

         self.max_episode_steps = max_episode_steps
         self.enable_llm_friendly_obs = enable_llm_friendly_obs
@@ -57,11 +54,11 @@ def __init__(
         self.past_download_times = deque(maxlen=self.past_chunk_window)

         # Simulated network conditions (simplified for demo)
-        self.network_conditions = self._generate_network_traces()
+        self.network_conditions: np.ndarray = None  # Generated / loaded on reset
         self.current_network_idx = 0

         # Video chunk sizes (simplified)
-        self.chunk_sizes = self._generate_chunk_sizes()
+        self.chunk_sizes: np.ndarray = None  # Generated / loaded on reset

         # Setup spaces
         self.action_space = gym.spaces.Discrete(len(self.bitrate_levels))
@@ -87,7 +84,7 @@ def __init__(
         Each chunk is 4 seconds of video content.
         """

-        self.reset()
+        self.reset(seed=seed)

     def _generate_network_traces(self) -> np.ndarray:
         """Generate simplified network bandwidth traces for simulation."""
@@ -97,7 +94,7 @@ def _generate_network_traces(self) -> np.ndarray:
             sigma = 0.6  # multiplicative variability

             mu = np.log(median_mbps)
-            x = np.random.normal(mu, sigma, size=2000)
+            x = self.np_random.normal(mu, sigma, size=2000)
             bw = np.exp(x)
             # AR(1) smoothing
             for t in range(1, 2000):
@@ -107,7 +104,7 @@ def _generate_network_traces(self) -> np.ndarray:
             return traces
         else:
             all_traces = np.load(TRACE_PATH)
-            traces = all_traces[np.random.choice(len(all_traces))]
+            traces = all_traces[self.np_random.choice(len(all_traces))]
             return traces[1]

     def _generate_chunk_sizes(self) -> np.ndarray:
@@ -119,7 +116,7 @@ def _generate_chunk_sizes(self) -> np.ndarray:
             # Approximate: bitrate * chunk_duration / 8 (bytes)
             base_size = bitrate * self.chunk_duration * 1e6 / 8
             # Add some variation
-            variations = np.random.normal(1.0, 0.1, 100)
+            variations = self.np_random.normal(1.0, 0.1, 100)
             sizes = base_size * variations
             chunk_sizes.append(sizes)
         return np.array(chunk_sizes)
@@ -202,6 +199,12 @@ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> T
         self.past_download_times.clear()
         self.current_network_idx = 0

+        # Generate network conditions and chunk sizes
+        if self.network_conditions is None:
+            self.network_conditions = self._generate_network_traces()
+        if self.chunk_sizes is None:
+            self.chunk_sizes = self._generate_chunk_sizes()
+
         # Initialize past observations
         for _ in range(self.past_chunk_window):
             self.past_throughputs.append(0.0)
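
Within the environment itself, the removed np.random.seed call is replaced by Gymnasium's per-environment generator: super().reset(seed=...) seeds self.np_random, and every later draw (network traces, chunk sizes) comes from that generator, which is why trace generation is deferred to the first reset. A stripped-down sketch of this seed-on-first-reset pattern, using a toy class rather than the actual InfraGym environment:

from typing import Optional

import gymnasium as gym
import numpy as np


class TinyTraceEnv(gym.Env):
    """Toy env illustrating seed-on-first-reset with lazy trace generation."""

    def __init__(self, seed: Optional[int] = 42):
        super().__init__()
        self.traces: Optional[np.ndarray] = None  # generated on first reset
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(1,))
        self.action_space = gym.spaces.Discrete(2)
        self.reset(seed=seed)  # seeds self.np_random

    def _generate_traces(self) -> np.ndarray:
        # Draw from the env's own generator, not the global np.random state,
        # so the result is reproducible under the seed passed to reset().
        return self.np_random.normal(0.0, 1.0, size=10)

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)  # (re)seeds self.np_random when a seed is given
        if self.traces is None:   # generate once, reuse across episodes
            self.traces = self._generate_traces()
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        return np.zeros(1, dtype=np.float32), 0.0, True, False, {}


# Two instances built with the same seed produce identical traces.
assert np.allclose(TinyTraceEnv(seed=0).traces, TinyTraceEnv(seed=0).traces)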

infragym/load_balance/load_balance_llm_gym.py

Lines changed: 26 additions & 23 deletions

@@ -35,43 +35,31 @@ def __init__(
         arrival_rate: Optional[float] = None,
     ):
         super().__init__()
-        self.seed = seed
-        if self.seed is not None:
-            np.random.seed(self.seed)
+        self.utilization_target = utilization_target
+        self.auto_scale_arrivals = auto_scale_arrivals
+        self._arrival_rate_arg = arrival_rate

         self.num_servers = num_servers
         self.max_episode_steps = max_episode_steps
         self.enable_llm_friendly_obs = enable_llm_friendly_obs
         self.reward_scale = reward_scale

         # Load balancing parameters
-        self.service_rates = self._generate_service_rates()
+        self.service_rates = None
+        self.arrival_rate = None
         self.job_size_range = (1, 10)  # Job sizes in arbitrary units (smaller for faster processing)

-        # compute E[size] for your job distribution (Pareto a=2, xm=1, truncated at b)
-        def expected_job_size(xm=1.0, b=10.0) -> float:
-            # closed form for α=2, truncated at b: E[X | X≤b] = 2*xm / (1 + xm/b)
-            return 2.0 * xm / (1.0 + xm / b)
-
-        if auto_scale_arrivals and arrival_rate is None:
-            e_s = expected_job_size(*self.job_size_range)
-            cap = float(sum(self.service_rates))  # size units / time
-            self.arrival_rate = max(1e-6, utilization_target * cap / e_s)
-        else:
-            # Jobs per time unit (higher rate for more jobs)
-            self.arrival_rate = arrival_rate if arrival_rate is not None else 2.0
-
         # Environment state
         self.current_step = 0
         self.current_time = 0.0
-        self.servers = self._initialize_servers()
+        self.servers = []  # Init later in reset
         self.job_queue = deque()
         self.finished_jobs = []
         self.total_waiting_time = 0.0
         self.total_processing_time = 0.0

         # Job generation
-        self.next_job_arrival = self._generate_next_arrival()
+        self.next_job_arrival = None  # Set in reset

         # Setup spaces
         self.action_space = gym.spaces.Discrete(num_servers)
@@ -97,7 +85,7 @@ def expected_job_size(xm=1.0, b=10.0) -> float:
         Each server has different processing capabilities (service rates).
         """

-        self.reset()
+        self.reset(seed=seed)

     def _generate_service_rates(self) -> List[float]:
         """Generate service rates for servers (jobs per time unit)."""
@@ -107,7 +95,7 @@ def _generate_service_rates(self) -> List[float]:
         for i in range(self.num_servers):
             # Add some variation based on server ID
             variation = 0.3 * np.sin(i * np.pi / self.num_servers)
-            rate = base_rate + variation + np.random.normal(0, 0.2)
+            rate = base_rate + variation + self.np_random.normal(0, 0.2)
             rates.append(max(0.5, rate))  # Ensure minimum rate
         return rates

@@ -129,11 +117,11 @@ def _initialize_servers(self) -> List[Dict]:

     def _generate_next_arrival(self) -> float:
         """Generate next job arrival time using exponential distribution."""
-        return self.current_time + np.random.exponential(1.0 / self.arrival_rate)
+        return self.current_time + self.np_random.exponential(1.0 / self.arrival_rate)

     def _generate_job(self) -> Dict:
         """Generate a new job with size and arrival time."""
-        size = np.random.pareto(2.0) + 1  # Pareto distribution for job sizes
+        size = self.np_random.pareto(2.0) + 1  # Pareto distribution for job sizes
         size = min(size, self.job_size_range[1])  # Cap maximum size

         return {
@@ -249,6 +237,21 @@ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> T
         """Reset the environment."""
         super().reset(seed=seed)

+        # Generate service rates ONCE (first episode) after seeding → fixed across episodes
+        if self.service_rates is None:
+            self.service_rates = self._generate_service_rates()
+
+        # arrival rate can be recomputed each reset (deterministic given service_rates)
+        def expected_job_size(xm=1.0, b=10.0) -> float:
+            return 2.0 * xm / (1.0 + xm / b)
+
+        if self.auto_scale_arrivals and self._arrival_rate_arg is None:
+            e_s = expected_job_size(1.0, 10.0)
+            cap = float(sum(self.service_rates))
+            self.arrival_rate = max(1e-6, self.utilization_target * cap / e_s)
+        else:
+            self.arrival_rate = self._arrival_rate_arg if self._arrival_rate_arg is not None else 2.0
+
         self.current_step = 0
         self.current_time = 0.0
         self.servers = self._initialize_servers()
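
The arrival-rate block relocated into reset() is deterministic once the service rates are fixed, so recomputing it on every reset does not reintroduce randomness; the only random draw in that initialization is _generate_service_rates(), which now runs exactly once, on the first (seeded) reset. For reference, the closed form it uses gives a mean job size of 2·xm/(1 + xm/b) ≈ 1.818 for xm=1, b=10, so the targeted arrival rate is utilization_target · sum(service_rates) / 1.818. A small numeric check (the utilization target and capacity below are assumed for illustration, not taken from this diff):

# Sanity check of the arrival-rate targeting relocated into reset()
xm, b = 1.0, 10.0
e_s = 2.0 * xm / (1.0 + xm / b)   # ~1.818, closed-form mean job size used by the code
utilization_target = 0.8          # assumed value for illustration
cap = 12.0                        # assumed total service capacity (sum of seeded rates)
arrival_rate = max(1e-6, utilization_target * cap / e_s)
print(round(e_s, 3), round(arrival_rate, 2))  # -> 1.818 5.28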
