Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 56 additions & 47 deletions eval/distribution_comparison_boxplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import data_cache

# Increase all font sizes by 16 points from their defaults
rcParams.update({key: rcParams[key] + 16 for key in rcParams if "size" in key and isinstance(rcParams[key], (int, float))})
rcParams.update({key: rcParams[key] + 32 for key in rcParams if "size" in key and isinstance(rcParams[key], (int, float))})

# ============================================================================
# CONFIGURATION SECTION
Expand Down Expand Up @@ -249,26 +249,25 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o

print(f"Found {len(block_runs)} block runs and {len(zns_runs)} ZNS runs")

# Create figure with 6 subplots (1 row x 6 columns) using GridSpec for custom widths
# Create figure with 6 subplots (2 rows x 3 columns) using GridSpec for custom widths
# Subplots are 2/5 original height, but all spacing preserved
num_subplots = len(RATIOS) * len(CHUNK_SIZES) # 2 ratios * 3 chunk sizes = 6
num_ratios = len(RATIOS)
num_chunk_sizes = len(CHUNK_SIZES)

# Width ratios: 1077MiB subplots (indices 2 and 5) are half as wide since they have 2 boxes instead of 4
width_ratios = [1, 1, 0.5, 1, 1, 0.5]
# Width ratios: 1077MiB subplots (column 2) are half as wide since they have 2 boxes instead of 4
width_ratios = [1, 1, 0.5]

fig = plt.figure(figsize=(5 * num_subplots, 5.38))
gs = GridSpec(1, num_subplots, figure=fig, width_ratios=width_ratios)
axes = [fig.add_subplot(gs[0, i]) for i in range(num_subplots)]
fig = plt.figure(figsize=(8.0 * num_chunk_sizes, 7.0 * num_ratios))
gs = GridSpec(num_ratios, num_chunk_sizes, figure=fig, width_ratios=width_ratios)
axes = [[fig.add_subplot(gs[i, j]) for j in range(num_chunk_sizes)] for i in range(num_ratios)]

# First pass: collect all data to find global maximum for y-axis
all_subplot_data = []
global_max = 0.0

idx = 0

# Iterate through ratios, then chunk sizes
for ratio in RATIOS:
for chunk_size in CHUNK_SIZES:
for ratio_idx, ratio in enumerate(RATIOS):
for chunk_idx, chunk_size in enumerate(CHUNK_SIZES):
# Prepare data for this subplot
current_data = []
labels = []
Expand Down Expand Up @@ -364,7 +363,9 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
'labels': labels,
'colors': colors,
'hatches': hatches,
'chunk_size': chunk_size
'chunk_size': chunk_size,
'ratio_idx': ratio_idx,
'chunk_idx': chunk_idx
})

# Update global maximum if using common scale
Expand All @@ -375,8 +376,6 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
if local_max > global_max:
global_max = local_max

idx += 1

# Add some padding to the global max (10% above highest value)
if common_y_scale:
y_max = global_max * 1.1
Expand All @@ -385,13 +384,14 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
y_max = None

# Second pass: create boxplots with common y-axis
idx = 0
for subplot_info in all_subplot_data:
current_data = subplot_info['data']
labels = subplot_info['labels']
colors = subplot_info['colors']
hatches = subplot_info['hatches']
chunk_size = subplot_info['chunk_size']
ratio_idx = subplot_info['ratio_idx']
chunk_idx = subplot_info['chunk_idx']

# Create boxplot for this subplot
if current_data:
Expand All @@ -404,10 +404,13 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
else:
box_width = 0.8 * (num_boxes / 4) if num_boxes > 0 else 0.8

bp = axes[idx].boxplot(current_data,
bp = axes[ratio_idx][chunk_idx].boxplot(current_data,
showfliers=show_outliers,
widths=box_width,
medianprops=dict(linewidth=2, color='black'),
boxprops=dict(linewidth=3),
whiskerprops=dict(linewidth=3),
capprops=dict(linewidth=3),
medianprops=dict(linewidth=3, color='black'),
patch_artist=True)

# Apply colors and hatches
Expand All @@ -418,46 +421,45 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
box.set_alpha(0.7)

# Set x-axis labels (empty for cleaner look, or could add device labels)
axes[idx].set_xticks(range(1, len(labels) + 1))
axes[idx].set_xticklabels([], rotation=45, fontsize=10)
axes[ratio_idx][chunk_idx].set_xticks(range(1, len(labels) + 1))
axes[ratio_idx][chunk_idx].set_xticklabels([], rotation=45, fontsize=10)

# Add chunk size label below subplot
axes[idx].set_xlabel(CHUNK_SIZE_LABELS[chunk_size], fontsize=28, weight='bold')
axes[ratio_idx][chunk_idx].set_xlabel(CHUNK_SIZE_LABELS[chunk_size], fontsize=58, weight='bold')

# Use scalar formatter without scientific notation
axes[idx].yaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False, useMathText=False))
axes[ratio_idx][chunk_idx].yaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False, useMathText=False))

# Rotate y-axis labels
for label in axes[idx].get_yticklabels():
for label in axes[ratio_idx][chunk_idx].get_yticklabels():
label.set_rotation(45)

# Set y-axis range
if common_y_scale:
# Use common y-axis maximum for all subplots
axes[idx].set_ylim(0, y_max)
axes[ratio_idx][chunk_idx].set_ylim(0, y_max)
else:
# Just set bottom to 0, let matplotlib auto-scale the top
axes[idx].set_ylim(bottom=0)

idx += 1
axes[ratio_idx][chunk_idx].set_ylim(bottom=0)

# Add y-axis label on the far left
fig.text(-0.005, 0.5, metric_label, va='center', rotation='vertical', fontsize=22, weight='bold')
fig.text(-0.065, 0.5, metric_label, va='center', rotation='vertical', fontsize=64, weight='bold')

# Adjust layout (do these BEFORE computing positions) - subplots at 2/5 height with proportional spacing
plt.subplots_adjust(wspace=0.2, hspace=0.0)
plt.tight_layout(pad=0.0)
plt.subplots_adjust(top=0.851, bottom=0.279, left=0.05)
plt.subplots_adjust(top=0.90, bottom=0.15, left=0.08, hspace=1.2, wspace=0.3)

# Make sure layout is finalized
fig.canvas.draw()

# Compute positions for ratio boxes and labels based on actual subplot bounds
axes_bboxes = [ax.get_position().bounds for ax in axes] # (x, y, w, h) per axes
# Flatten the 2D axes array to get all subplot bounds
axes_flat = [axes[i][j] for i in range(num_ratios) for j in range(num_chunk_sizes)]
axes_bboxes = [ax.get_position().bounds for ax in axes_flat] # (x, y, w, h) per axes

# First 3 subplots -> Ratio 1:2, next 3 -> Ratio 1:10
group1 = axes_bboxes[0:3]
group2 = axes_bboxes[3:6]
# First row (3 subplots) -> Ratio 1:2, second row (3 subplots) -> Ratio 1:10
group1 = axes_bboxes[0:3] # First row
group2 = axes_bboxes[3:6] # Second row

# Left/right bounds of each group
g1_left = group1[0][0]
Expand All @@ -468,14 +470,21 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
g2_right = group2[-1][0] + group2[-1][2]
g2_width = g2_right - g2_left

# Vertical placement of the grey boxes in figure coords
box_y = 0.85
box_h = 0.10
# Vertical placement of the grey boxes - position above each row
# Get the top y position of each row's subplots and add some padding
g1_top = group1[0][1] + group1[0][3] # y + height of first row
g2_top = group2[0][1] + group2[0][3] # y + height of second row

box_h = 0.06
box_y_offset = 0.02 # Space above the subplot

g1_box_y = g1_top + box_y_offset
g2_box_y = g2_top + box_y_offset

# Grey box for Ratio 1:2
# Grey box for Ratio 1:2 (first row)
fig.add_artist(
Rectangle(
(g1_left, box_y),
(g1_left, g1_box_y),
g1_width,
box_h,
transform=fig.transFigure,
Expand All @@ -487,10 +496,10 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
)
)

# Grey box for Ratio 1:10
# Grey box for Ratio 1:10 (second row)
fig.add_artist(
Rectangle(
(g2_left, box_y),
(g2_left, g2_box_y),
g2_width,
box_h,
transform=fig.transFigure,
Expand All @@ -505,21 +514,21 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
# Centered text in each box
fig.text(
g1_left + g1_width / 2,
box_y + box_h / 2,
g1_box_y + box_h / 2,
"Ratio: 1:2",
ha='center',
va='center',
fontsize=26,
fontsize=50,
weight='bold',
zorder=2,
)
fig.text(
g2_left + g2_width / 2,
box_y + box_h / 2,
g2_box_y + box_h / 2,
"Ratio: 1:10",
ha='center',
va='center',
fontsize=26,
fontsize=50,
weight='bold',
zorder=2,
)
Expand All @@ -533,9 +542,9 @@ def generate_distribution_comparison(block_dir, zns_dir, distribution, metric, o
]

fig.legend(
ncols=4,
ncols=2,
handles=legend_patches,
bbox_to_anchor=(0.5, 0.08),
bbox_to_anchor=(0.5, -0.08),
loc='center',
fontsize="large",
columnspacing=2.0,
Expand Down
Loading
Loading