From 2c9cc608e3301753c3d06a616111fabae40fb590 Mon Sep 17 00:00:00 2001 From: ssanjeev2016 <32891905+ssanjeev2016@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:09:24 -0700 Subject: [PATCH 1/2] First pass at challenge --- README.md | 5 +- placement.py | 400 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 319 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index df0c441..f6682cd 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ We will review submissions on a rolling basis. | Rank | Name | Overlap | Wirelength (um) | Runtime (s) | Notes | |------|-----------------|-------------|-----------------|-------------|----------------------| -| 1 | example | 0.5000 | 0.5 | 10 | example submission | -| 2 | Add Yours! | | | | | +| 1 | Sanjeev Sinha | 0.0000 | 0.1493 | 4057.78 | KD-Tree usage for large designs with convergence checks | +| 2 | example | 0.5000 | 0.5 | 10 | example submission | +| 3 | Add Yours! | | | | | diff --git a/placement.py b/placement.py index d70412d..d0133cb 100644 --- a/placement.py +++ b/placement.py @@ -44,6 +44,9 @@ import torch import torch.optim as optim +import time +from scipy.spatial import KDTree +import numpy as np # Feature index enums for cleaner code access class CellFeatureIdx(IntEnum): @@ -296,49 +299,60 @@ def wirelength_attraction_loss(cell_features, pin_features, edge_list): # Total wirelength total_wirelength = torch.sum(smooth_manhattan) - return total_wirelength / edge_list.shape[0] # Normalize by number of edges + # Use a loss function which has a steady gradient as x -> 0. + # Use a multiplier to make the gradient flatter as we don't want + # this value to have as much of an effect as overlap initially. + loss = torch.log1p(1/100 * total_wirelength**2) + return loss +# TODO: Split the batches across threads and serially sum once all threads are done. +def overlap_repulsion_loss_batched(cell_features, batch_size=1000): + """Calculate overlap loss in a batched manner due to memory constraints. -def overlap_repulsion_loss(cell_features, pin_features, edge_list): - """Calculate loss to prevent cell overlaps. + Args: + cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height] + batch_size: Number of cells to compute overlap loss for at a time. + + Returns: + Scalar loss value + """ + N = cell_features.shape[0] # Total number of cells (macro + std) + pos = cell_features[:, 2:4] # x,y center positions of each cell + shapes = cell_features[:, 4:] # width, height of each cell + + batch_totals = [] + + with torch.no_grad(): + for start in range(0, N, batch_size): + end = min(start + batch_size, N) + + pos_i = pos[start:end].unsqueeze(1) # [batch, 1, 2] + pos_j = pos.unsqueeze(0) # [1, N, 2] + diff = torch.abs(pos_i - pos_j) # [batch, N, 2] - TODO: IMPLEMENT THIS FUNCTION - - This is the main challenge. You need to implement a differentiable loss function - that penalizes overlapping cells. The loss should: - - 1. Be zero when no cells overlap - 2. Increase as overlap area increases - 3. Use only differentiable PyTorch operations (no if statements on tensors) - 4. Work efficiently with vectorized operations - - HINTS: - - Two axis-aligned rectangles overlap if they overlap in BOTH x and y dimensions - - For rectangles centered at (x1, y1) and (x2, y2) with widths (w1, w2) and heights (h1, h2): - * x-overlap occurs when |x1 - x2| < (w1 + w2) / 2 - * y-overlap occurs when |y1 - y2| < (h1 + h2) / 2 - - Use torch.relu() to compute positive overlaps: overlap_x = relu((w1+w2)/2 - |x1-x2|) - - Overlap area = overlap_x * overlap_y - - Consider all pairs of cells: use broadcasting with unsqueeze - - Use torch.triu() to avoid counting each pair twice (only consider i < j) - - Normalize the loss appropriately (by number of pairs or total area) - - RECOMMENDED APPROACH: - 1. Extract positions, widths, heights from cell_features - 2. Compute all pairwise distances using broadcasting: - positions_i = positions.unsqueeze(1) # [N, 1, 2] - positions_j = positions.unsqueeze(0) # [1, N, 2] - distances = positions_i - positions_j # [N, N, 2] - 3. Calculate minimum separation distances for each pair - 4. Use relu to get positive overlap amounts - 5. Multiply overlaps in x and y to get overlap areas - 6. Mask to only consider upper triangle (i < j) - 7. Sum and normalize + min_sep = (shapes[start:end].unsqueeze(1) + shapes.unsqueeze(0)) / 2 # [batch, N, 2] + overlap = torch.relu(min_sep - diff) # [batch, N, 2] + overlap_areas = overlap[:, :, 0] * overlap[:, :, 1] # [batch, N] + + batch_indices = torch.arange(start, end).unsqueeze(1) # [batch, 1] + col_indices = torch.arange(N).unsqueeze(0) # [1, N] + mask = col_indices > batch_indices # [batch, N] + + # Keep as tensor — don't convert to float yet + batch_totals.append((overlap_areas * mask).sum()) + + # Sum tensors, then convert once — preserves precision + return torch.stack(batch_totals).sum() + +def overlap_repulsion_loss(cell_features, pin_features, edge_list, pairs=None): + """Calculate loss to prevent cell overlaps. Args: cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height] pin_features: [P, 7] tensor with pin information (not used here) edge_list: [E, 2] tensor with edges (not used here) + pairs: List of candidate cell pairs which are overlapping depending on a predefined radius. + See get_nearby_pairs_kdtree for more details Returns: Scalar loss value (should be 0 when no overlaps exist) @@ -347,25 +361,148 @@ def overlap_repulsion_loss(cell_features, pin_features, edge_list): if N <= 1: return torch.tensor(0.0, requires_grad=True) - # TODO: Implement overlap detection and loss calculation here - # - # Your implementation should: - # 1. Extract cell positions, widths, and heights - # 2. Compute pairwise overlaps using vectorized operations - # 3. Return a scalar loss that is zero when no overlaps exist - # - # Delete this placeholder and add your implementation: + # Used in "dense" overlap check when there are a large amount of cells + # aka when we need to verify there is truly no overlap across all cells + if very_large_design(N) and pairs is None: + return overlap_repulsion_loss_batched(cell_features) - # Placeholder - returns a constant loss (REPLACE THIS!) - return torch.tensor(1.0, requires_grad=True) + pos = cell_features[:, 2:4] # x, y + shape = cell_features[:, 4:] # width, height + if pairs is None or not use_pairs(N): + pos_i = pos.unsqueeze(1) # [N, 1, 2] + pos_j = pos.unsqueeze(0) # [1, N, 2] + diff = torch.abs(pos_i - pos_j) # [N, N, 2] + + min_sep = (shape.unsqueeze(1) + shape.unsqueeze(0)) / 2 # [N, N, 2] + overlap = torch.relu(min_sep - diff) # [N, N, 2] + + overlap_areas = overlap[:, :, 0] * overlap[:, :, 1] # [N, N] + mask = torch.triu(torch.ones(N, N, device=cell_features.device), diagonal=1) + overlap_areas = overlap_areas * mask + else: + i_idx = pairs[:, 0] + j_idx = pairs[:, 1] + diff = torch.abs(pos[i_idx] - pos[j_idx]) + min_sep = (shape[i_idx] + shape[j_idx]) / 2 + overlap = torch.relu(min_sep - diff) + overlap_areas = overlap[:, 0] * overlap[:, 1] + + total_overlap = torch.sum(overlap_areas) + + # Use a loss function which has a steady gradient as x -> 0. + # Use a multiplier to make the gradient more steep as we want + # this value to have more influence initially during optimization. + # Add a linear term to take care of the vanishing gradient as x gets closer to 0. + loss = torch.log1p(100 * total_overlap**2) + torch.tensor(0.25 * total_overlap) + return loss + +def get_nearby_pairs_kdtree(cell_features): + """Find candidate overlapping cell pairs using a KD-tree spatial index. + + Cells are split into macros (height > 1.5) and standard cells (height <= 1.5) + and queried separately to avoid a single global radius that would return O(N²) + pairs at high packing density. Three pair types are handled: + + std-std: + Builds a KD-tree over standard cell positions and calls query_pairs once + with Chebyshev distance (p=inf) and radius = max std cell dimension. + Chebyshev distance checks a square window rather than a circle, matching + the rectangular geometry of std cells and producing far fewer false + positives than Euclidean distance. Two std cells can only overlap if their + centers are within (wi+wj)/2 in x AND (hi+hj)/2 in y simultaneously, so + this radius is tight. + + macro-macro: + With at most ~10 macros there are at most 45 pairs — brute forced directly + with no tree needed. + + macro-std: + For each macro, queries the std cell tree with query_ball_point using + radius = macro half-dimension + max std half-dimension. This guarantees + no overlapping macro-std pair is missed regardless of macro size. + + Args: + cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height] + + Returns: + [M, 2] long tensor of candidate pair indices (i, j) with i < j. + M << N² since only spatially nearby pairs are returned. False positives + are harmless — torch.relu in the loss filters non-overlapping pairs to + zero contribution. False negatives would cause overlaps to be missed, so + radii are chosen conservatively to guarantee completeness. + """ + positions = cell_features[:, 2:4].detach().numpy() + widths = cell_features[:, 4].detach().numpy() + heights = cell_features[:, 5].detach().numpy() + + is_macro = heights > 1.5 + macro_idx = np.where(is_macro)[0] + std_idx = np.where(~is_macro)[0] + + macro_pos = positions[macro_idx] + std_pos = positions[std_idx] + + pair_set = set() + + # --- std-std: single batched query, small radius --- + if len(std_idx) > 1: + std_pos = positions[std_idx] + max_std_w = float(widths[std_idx].max()) + max_std_h = float(heights[std_idx].max()) # always 1.0 + + # Chebyshev radius = max of x-separation and y-separation thresholds + # Two std cells overlap only if BOTH x-dist < (wi+wj)/2 AND y-dist < 1.0 + # Chebyshev(p=inf) with radius = max(max_w, 1.0) catches all such pairs + # with far fewer false positives than Euclidean + std_radius = max(max_std_w, max_std_h) + std_tree = KDTree(std_pos) + + # p=inf uses Chebyshev distance — equivalent to checking a square + # window around each cell rather than a circle, much tighter for + # rectangular cells + std_pairs = std_tree.query_pairs(std_radius, p=np.inf, output_type='ndarray') + for i, j in std_pairs: + pair_set.add((int(std_idx[i]), int(std_idx[j]))) + + # --- macro-macro: few macros, brute force all pairs --- + for a in range(len(macro_idx)): + for b in range(a + 1, len(macro_idx)): + i, j = int(macro_idx[a]), int(macro_idx[b]) + pair_set.add((i, j)) + + # --- macro-std: one query per macro against std tree --- + if len(macro_idx) > 0 and len(std_idx) > 0: + std_tree = KDTree(std_pos) if 'std_tree' not in dir() else std_tree + for a, i in enumerate(macro_idx): + radius = (max(widths[i], heights[i]) / 2 + + float(widths[std_idx].max() + heights[std_idx].max()) / 2) + neighbors = std_tree.query_ball_point(macro_pos[a], r=radius) + for b in neighbors: + j = int(std_idx[b]) + pair_set.add((min(i, j), max(i, j))) + + if not pair_set: + return torch.zeros((0, 2), dtype=torch.long) + + return torch.tensor(list(pair_set), dtype=torch.long) + +# Specific flags to use for optimizations, due to memory and runtime issues. Runtime still WIP. +# Use KD-Tree for overlap instead of naive/expensive overlap +def use_pairs(N): return bool(N > 1000) + +# For optimizer initialization +def large_design(N): return bool(N > 10000) + +# For hyperparameter initialization +def very_large_design(N): return bool(N > 100000) def train_placement( cell_features, pin_features, edge_list, - num_epochs=1000, - lr=0.01, + num_epochs=20000, + lr=1.0, lambda_wirelength=1.0, lambda_overlap=10.0, verbose=True, @@ -378,7 +515,7 @@ def train_placement( pin_features: [P, 7] tensor with pin properties edge_list: [E, 2] tensor with edge connectivity num_epochs: Number of optimization iterations - lr: Learning rate for Adam optimizer + lr: Learning rate for optimizer lambda_wirelength: Weight for wirelength loss lambda_overlap: Weight for overlap loss verbose: Whether to print progress @@ -390,6 +527,15 @@ def train_placement( - initial_cell_features: Original cell positions (for comparison) - loss_history: Loss values over time """ + + N = cell_features.shape[0] + max_norm = 5.0 + if (very_large_design(N)): + num_epochs = 22000 + lr = 5.0 # Aggresive learning rate, but scheduler below will help with this in later epochs. + lambda_wirelength = 0.0 # TODO: Include wirelength loss for this relatively packed design. + max_norm = 50.0 # Allow larger gradient flows if necessary. + # Clone features and create learnable positions cell_features = cell_features.clone() initial_cell_features = cell_features.clone() @@ -398,8 +544,17 @@ def train_placement( cell_positions = cell_features[:, 2:4].clone().detach() cell_positions.requires_grad_(True) - # Create optimizer - optimizer = optim.Adam([cell_positions], lr=lr) + # Create SGD optimizer + # We are not using Adam here as there were issues with noisy loss in much later epochs. + # Along with a smoother loss function, SGD is more stable. + if (not large_design(N)): + optimizer = optim.SGD([cell_positions], lr=lr, momentum=0.9) + else: + optimizer = optim.SGD([cell_positions], lr=lr, momentum=0.9, nesterov=True) + lambda_overlap = 50.0 # Apply larger penalty to overlap loss to speed up optimization. + + # Use cosine to adjust learning rate, especially as we start to reach a minimum. + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, 0.01**2) # Track loss history loss_history = { @@ -409,20 +564,39 @@ def train_placement( } # Training loop + # Initialize training state. We need to keep track of whether overlap loss + # ever comes back if it had ever been resolved. We will also refresh the + # tree indicating which cells have overlapping neighbors every so often, + # as refreshing this tree every epoch does not scale well with increasing N. + overlap_resolved = False + pair_refresh_interval = 50 + pairs = None for epoch in range(num_epochs): + if use_pairs(N) and epoch % pair_refresh_interval == 0: + with torch.no_grad(): + cell_features_temp = cell_features.clone() + cell_features_temp[:, 2:4] = cell_positions + + t0 = time.time() + pairs = get_nearby_pairs_kdtree(cell_features_temp) + # print(f"KDTree: {time.time()-t0:.3f}s, pairs={len(pairs)}") + optimizer.zero_grad() - # Create cell_features with current positions cell_features_current = cell_features.clone() cell_features_current[:, 2:4] = cell_positions - # Calculate losses + t0 = time.time() wl_loss = wirelength_attraction_loss( - cell_features_current, pin_features, edge_list - ) + cell_features_current, pin_features, edge_list + ) + # print(f"Wirelength: {time.time()-t0:.3f}s") + + t0 = time.time() overlap_loss = overlap_repulsion_loss( - cell_features_current, pin_features, edge_list + cell_features_current, pin_features, edge_list, pairs ) + # print(f"Overlap: {time.time()-t0:.3f}s") # Combined loss total_loss = lambda_wirelength * wl_loss + lambda_overlap * overlap_loss @@ -431,11 +605,37 @@ def train_placement( total_loss.backward() # Gradient clipping to prevent extreme updates - torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) + torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=max_norm) - # Update positions + # Update positions and learning rate optimizer.step() + scheduler.step() + + # Convergence check — only when sparse loss is near zero + sparse_loss = overlap_loss.item() + if sparse_loss < 1e-8: + with torch.no_grad(): + cell_features_current[:, 2:4] = cell_positions + dense_loss = overlap_repulsion_loss( + cell_features_current, pin_features, edge_list, pairs=None + ) + if dense_loss.item() <= 1e-6: + if calculate_cells_with_overlaps(cell_features_current) == 0: + print(f"Converged at epoch {epoch}") + break + else: + # Pairs were stale — refresh and continue, don't touch LR + pairs = get_nearby_pairs_kdtree(cell_features_current) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, num_epochs - epoch, 0.01**2 + ) + else: + if overlap_resolved: + # Overlap came back — restore LR to fight it + overlap_resolved = False + for param_group in optimizer.param_groups: + param_group['lr'] = lr # Record losses loss_history["total_loss"].append(total_loss.item()) loss_history["wirelength_loss"].append(wl_loss.item()) @@ -531,6 +731,43 @@ def calculate_overlap_metrics(cell_features): "overlap_percentage": overlap_percentage, } +def calculate_cells_with_overlaps_batched(cell_features, batch_size=1000): + N = cell_features.shape[0] + if N <= 1: + return 0 + + pos = cell_features[:, 2:4].detach() + shapes = cell_features[:, 4:].detach() + + # Track which cells are involved in any overlap + # Use a 1D bool tensor instead of N×N matrix + has_overlap = torch.zeros(N, dtype=torch.bool) + + for start in range(0, N, batch_size): + end = min(start + batch_size, N) + + pos_i = pos[start:end].unsqueeze(1) # [batch, 1, 2] + pos_j = pos.unsqueeze(0) # [1, N, 2] + diff = torch.abs(pos_i - pos_j) # [batch, N, 2] + + min_sep = (shapes[start:end].unsqueeze(1) + shapes.unsqueeze(0)) / 2 # [batch, N, 2] + + overlap = diff < min_sep # [batch, N, 2] + overlap_present = overlap[:, :, 0] & overlap[:, :, 1] # [batch, N] + + # Upper triangle only — j > i + batch_indices = torch.arange(start, end).unsqueeze(1) # [batch, 1] + col_indices = torch.arange(N).unsqueeze(0) # [1, N] + upper_tri = col_indices > batch_indices # [batch, N] + overlap_present = overlap_present & upper_tri # [batch, N] + + # Mark row cells (i) that overlap with any j + has_overlap[start:end] |= overlap_present.any(dim=1) + + # Mark col cells (j) that overlap with any i in this batch + has_overlap |= overlap_present.any(dim=0) + + return has_overlap.sum().item() def calculate_cells_with_overlaps(cell_features): """Calculate number of cells involved in at least one overlap. @@ -541,40 +778,36 @@ def calculate_cells_with_overlaps(cell_features): cell_features: [N, 6] tensor with cell properties Returns: - Set of cell indices that have overlaps with other cells + Number of cell indices that have overlaps with other cells """ N = cell_features.shape[0] if N <= 1: - return set() - - # Extract cell properties - positions = cell_features[:, 2:4].detach().numpy() - widths = cell_features[:, 4].detach().numpy() - heights = cell_features[:, 5].detach().numpy() + return 0 - cells_with_overlaps = set() + # Used in "dense" overlap check when there are a large amount of cells + # aka when we need to verify there is truly no overlap across all cells + if very_large_design(N): + return calculate_cells_with_overlaps_batched(cell_features) - # Check all pairs - for i in range(N): - for j in range(i + 1, N): - # Calculate center-to-center distances - dx = abs(positions[i, 0] - positions[j, 0]) - dy = abs(positions[i, 1] - positions[j, 1]) + pos = cell_features[:, 2:4] + shape = cell_features[:, 4:] - # Minimum separation for non-overlap - min_sep_x = (widths[i] + widths[j]) / 2 - min_sep_y = (heights[i] + heights[j]) / 2 + pos_i = pos.unsqueeze(1) + pos_j = pos.unsqueeze(0) + diff = torch.abs(pos_i - pos_j) - # Calculate overlap amounts - overlap_x = max(0, min_sep_x - dx) - overlap_y = max(0, min_sep_y - dy) + min_sep = (shape.unsqueeze(1) + shape.unsqueeze(0)) / 2 + overlap = (min_sep - diff) > 0 + overlap_present = overlap[:, :, 0] & overlap[:, :, 1] - # Overlap occurs only if both x and y overlap - if overlap_x > 0 and overlap_y > 0: - cells_with_overlaps.add(i) - cells_with_overlaps.add(j) + # Upper triangle only — unique pairs (i < j) + mask = torch.triu(torch.ones(N, N, device=cell_features.device, dtype=torch.bool), diagonal=1) + overlap_present = overlap_present & mask - return cells_with_overlaps + # For each overlapping pair (i,j), both i and j are involved + # i appears as a row, j appears as a column + has_overlap = overlap_present.any(dim=1) | overlap_present.any(dim=0) + return has_overlap.sum().item() def calculate_normalized_metrics(cell_features, pin_features, edge_list): @@ -598,8 +831,7 @@ def calculate_normalized_metrics(cell_features, pin_features, edge_list): N = cell_features.shape[0] # Calculate overlap metric: num cells with overlaps / total cells - cells_with_overlaps = calculate_cells_with_overlaps(cell_features) - num_cells_with_overlaps = len(cells_with_overlaps) + num_cells_with_overlaps = calculate_cells_with_overlaps(cell_features) overlap_ratio = num_cells_with_overlaps / N if N > 0 else 0.0 # Calculate wirelength metric: (wirelength / num nets) / sqrt(total area) From ca676968b3216531f244db9a5ea4c4433248b2eb Mon Sep 17 00:00:00 2001 From: ssanjeev2016 <32891905+ssanjeev2016@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:17:16 -0700 Subject: [PATCH 2/2] Provide more notes --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f6682cd..57f314a 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ We will review submissions on a rolling basis. | Rank | Name | Overlap | Wirelength (um) | Runtime (s) | Notes | |------|-----------------|-------------|-----------------|-------------|----------------------| -| 1 | Sanjeev Sinha | 0.0000 | 0.1493 | 4057.78 | KD-Tree usage for large designs with convergence checks | +| 1 | Sanjeev Sinha | 0.0000 | 0.1493 | 4057.78 | KD-Tree usage for large designs with convergence checks. Test Case 12 takes 3662.55s, runtime is WIP. | | 2 | example | 0.5000 | 0.5 | 10 | example submission | | 3 | Add Yours! | | | | |