Skip to content

Commit d4be430

Browse files
committed
Make sure that arrays are unaligned for stacking (and fix concatenate -> stack)
1 parent e78d2a6 commit d4be430

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

bench/ndarray/stack.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ def run_benchmark(num_arrays=10, size=500, aligned_chunks=False, axis=0,
4444
chunk_shapes = [(chunks[0], chunks[1]) for shape in shapes]
4545
else:
4646
# Unaligned chunks: not divisors of shape dimensions
47-
chunk_shapes = [(chunks[0] + 1, chunks[1] - 1) for shape in shapes]
47+
chunk_shapes = []
48+
for i in range(len(shapes)):
49+
added_random_size = np.random.randint(1, 10) # Random size to ensure unalignment
50+
chunk_shapes.append((chunks[0] + added_random_size, chunks[1] - added_random_size))
4851

4952
# Create arrays
5053
arrays = []
@@ -108,7 +111,7 @@ def run_numpy_benchmark(num_arrays=10, size=500, axis=0, dtype=np.float64, datad
108111
total_elements = sum(np.prod(shape) for shape in shapes)
109112
data_size_gb = total_elements * 4 / (1024**3) # Convert bytes to GB
110113

111-
# Time the concatenation
114+
# Time the stacking
112115
start_time = time.time()
113116
result = np.stack(numpy_arrays, axis=axis)
114117
duration = time.time() - start_time
@@ -190,11 +193,11 @@ def autolabel(rects, ax):
190193

191194
# Save the plot
192195
plt.tight_layout()
193-
plt.savefig(os.path.join(output_dir, 'concatenate_benchmark_combined.png'), dpi=100)
196+
plt.savefig(os.path.join(output_dir, 'stack_benchmark_combined.png'), dpi=100)
194197
plt.show()
195198
plt.close()
196199

197-
print(f"Combined plot saved to {os.path.join(output_dir, 'concatenate_benchmark_combined.png')}")
200+
print(f"Combined plot saved to {os.path.join(output_dir, 'stack_benchmark_combined.png')}")
198201

199202

200203
def main():
@@ -257,9 +260,9 @@ def main():
257260

258261
# Quick verification of result shape
259262
if axis == 0:
260-
expected_shape = (10, size, size // num_arrays) # After concatenation along axis 0
263+
expected_shape = (10, size, size // num_arrays) # After stacking along axis 0
261264
else:
262-
expected_shape = (size, size // num_arrays, 10) # After concatenation along axis - 1
265+
expected_shape = (size, size // num_arrays, 10) # After stacking along axis - 1
263266

264267
# Verify shapes match
265268
shapes = [numpy_shape, shape1, shape2]

0 commit comments

Comments
 (0)