diff --git a/placement.py b/placement.py index d70412d..005c6cb 100644 --- a/placement.py +++ b/placement.py @@ -44,6 +44,7 @@ import torch import torch.optim as optim +import math # Feature index enums for cleaner code access class CellFeatureIdx(IntEnum): @@ -298,80 +299,365 @@ def wirelength_attraction_loss(cell_features, pin_features, edge_list): return total_wirelength / edge_list.shape[0] # Normalize by number of edges +def build_nets_from_edges(edge_list, num_pins): + """ + Reconstruct nets (signal groups) from pairwise edges using Union-Find. + + Algorithm: Union-Find + + Args: + edge_list: [E, 2] tensor of pin-index pairs. + num_pins: Total number of pins. + + Returns: + List of LongTensors, where each tensor contains the pin indices + belonging to one net. Single-pin "nets" are filtered out (no wire). + """ + if edge_list.shape[0] == 0: + return [] + + # Each pin starts as its own group. parent[x] points to x's parent + parent = list(range(num_pins)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a, b): + ra, rb = find(a), find(b) + if ra != rb: + parent[ra] = rb + + # Two pins on an edge belong to the same net by definition. + for e in edge_list.tolist(): + union(e[0], e[1]) + + # Bucket all pins by their root => connected components => nets. + groups = {} + for p in range(num_pins): + groups.setdefault(find(p), []).append(p) + + # Drop "nets" with only one pin — they have no wire to compute. + return [ + torch.tensor(g, dtype=torch.long) + for g in groups.values() + if len(g) >= 2 + ] + -def overlap_repulsion_loss(cell_features, pin_features, edge_list): +def hpwl_loss(cell_features, pin_features, nets, alpha=0.5): + """ + Differentiable Half-Perimeter Wirelength loss for training. + + For each net, HPWL = (max_x - min_x) + (max_y - min_y) over its pins. + This is the standard wirelength model used by all academic placers + and is closer to true routing length than per-edge Manhattan. + + Args: + cell_features: [N, 6] tensor + pin_features: [P, 7] tensor + nets: List of LongTensors (output of build_nets_from_edges) + alpha: Smoothing parameter (smaller = sharper, less smooth) + + Returns: + Scalar tensor: average HPWL per net. + """ + if not nets: + return torch.tensor(0.0, requires_grad=True) + + # Compute absolute pin positions once. + cell_positions = cell_features[:, 2:4] + cell_indices = pin_features[:, 0].long() + pin_x = cell_positions[cell_indices, 0] + pin_features[:, 1] + pin_y = cell_positions[cell_indices, 1] + pin_features[:, 2] + + # Sum per-net half-perimeters. + total = 0.0 + for net in nets: + px = pin_x[net] + py = pin_y[net] + # Smooth max: alpha * log(sum(exp(x/alpha))) ~ max(x) + max_x = alpha * torch.logsumexp(px / alpha, dim=0) + # Smooth min via min(x) = -max(-x) + min_x = -alpha * torch.logsumexp(-px / alpha, dim=0) + max_y = alpha * torch.logsumexp(py / alpha, dim=0) + min_y = -alpha * torch.logsumexp(-py / alpha, dim=0) + # Half-perimeter of the bounding box. + total = total + (max_x - min_x) + (max_y - min_y) + + return total / len(nets) + + +def overlap_repulsion_loss(cell_features, pin_features, edge_list, margin=0.1): """Calculate loss to prevent cell overlaps. - TODO: IMPLEMENT THIS FUNCTION - - This is the main challenge. You need to implement a differentiable loss function - that penalizes overlapping cells. The loss should: - - 1. Be zero when no cells overlap - 2. Increase as overlap area increases - 3. Use only differentiable PyTorch operations (no if statements on tensors) - 4. Work efficiently with vectorized operations - - HINTS: - - Two axis-aligned rectangles overlap if they overlap in BOTH x and y dimensions - - For rectangles centered at (x1, y1) and (x2, y2) with widths (w1, w2) and heights (h1, h2): - * x-overlap occurs when |x1 - x2| < (w1 + w2) / 2 - * y-overlap occurs when |y1 - y2| < (h1 + h2) / 2 - - Use torch.relu() to compute positive overlaps: overlap_x = relu((w1+w2)/2 - |x1-x2|) - - Overlap area = overlap_x * overlap_y - - Consider all pairs of cells: use broadcasting with unsqueeze - - Use torch.triu() to avoid counting each pair twice (only consider i < j) - - Normalize the loss appropriately (by number of pairs or total area) - - RECOMMENDED APPROACH: - 1. Extract positions, widths, heights from cell_features - 2. Compute all pairwise distances using broadcasting: - positions_i = positions.unsqueeze(1) # [N, 1, 2] - positions_j = positions.unsqueeze(0) # [1, N, 2] - distances = positions_i - positions_j # [N, N, 2] - 3. Calculate minimum separation distances for each pair - 4. Use relu to get positive overlap amounts - 5. Multiply overlaps in x and y to get overlap areas - 6. Mask to only consider upper triangle (i < j) - 7. Sum and normalize + Approach: Differentiable pairwise overlap loss with linear + cubic penalty. + + Two axis-aligned rectangles overlap by: + ox = relu(sep_x - |dx|), oy = relu(sep_y - |dy|) + where sep is the minimum center-to-center distance for non-overlap + plus a small safety margin. Overlap area = ox * oy. Args: cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height] pin_features: [P, 7] tensor with pin information (not used here) edge_list: [E, 2] tensor with edges (not used here) + margin: Safety buffer added to required separation Returns: Scalar loss value (should be 0 when no overlaps exist) """ + N = cell_features.shape[0] if N <= 1: return torch.tensor(0.0, requires_grad=True) + + pos = cell_features[:, 2:4] # [N, 2] + w = cell_features[:, 4] # [N] + h = cell_features[:, 5] # [N] + + # Pairwise absolute differences via broadcasting. + dx = torch.abs(pos[:, 0].unsqueeze(1) - pos[:, 0].unsqueeze(0)) + dy = torch.abs(pos[:, 1].unsqueeze(1) - pos[:, 1].unsqueeze(0)) + + sep_x = (w.unsqueeze(1) + w.unsqueeze(0)) / 2 + margin + sep_y = (h.unsqueeze(1) + h.unsqueeze(0)) / 2 + margin + + # Positive overlap amounts + ox = torch.relu(sep_x - dx) + oy = torch.relu(sep_y - dy) + + area = ox * oy + + # Upper-triangular mask: pair (i, j) and (j, i) are the same pair, and we ignore self-overlap (i, i). + mask = torch.triu(torch.ones(N, N, device=area.device), diagonal=1) + area = area * mask + + # Combined linear + cubic penalty, normalized by N (not num_pairs). + return (area.sum() + (area ** 3).sum()) / N + + +def _legalize(cell_features, margin=0.1, max_iters=400): + """ + Deterministic force-based legalizer: post-training safety net to ensure no overlapping cells. + + Algorithm: minimum-displacement-vector separation. For each + overlapping pair, push along whichever axis has the smaller overlap + (the cheapest separation direction). All pairs are processed + simultaneously via vectorized [N, N] tensor ops. + + Args: + cell_features: [N, 6] tensor with current placement + margin: Safety buffer (matches overlap_repulsion_loss) + max_iters: Maximum separation iterations + + Returns: + New cell_features with overlaps physically resolved. + """ + N = cell_features.shape[0] + pos = cell_features[:, 2:4].clone().detach() + w = cell_features[:, 4].detach() + h = cell_features[:, 5].detach() + device = pos.device + + # Precompute upper-triangle + tri = torch.triu(torch.ones(N, N, dtype=torch.bool, device=device), diagonal=1) + + for _ in range(max_iters): + + dxs = pos[:, 0].unsqueeze(1) - pos[:, 0].unsqueeze(0) + dys = pos[:, 1].unsqueeze(1) - pos[:, 1].unsqueeze(0) + sep_x = (w.unsqueeze(1) + w.unsqueeze(0)) / 2 + margin + sep_y = (h.unsqueeze(1) + h.unsqueeze(0)) / 2 + margin + ox = sep_x - torch.abs(dxs) + oy = sep_y - torch.abs(dys) + + # Pairs that are actually overlapping (positive on BOTH axes and only counted in the upper triangle). + active = (ox > 0) & (oy > 0) & tri + if not active.any(): + break # Ideal state: no overlaps remain. + + # Push along whichever axis has the smaller overlap + push_x_axis = ox <= oy + + sx = torch.sign(dxs) + sx = torch.where(sx == 0, torch.ones_like(sx), sx) + sy = torch.sign(dys) + sy = torch.where(sy == 0, torch.ones_like(sy), sy) + + fx = torch.where(active & push_x_axis, (ox * 0.5 + 0.02) * sx, torch.zeros_like(ox)) + fy = torch.where(active & ~push_x_axis, (oy * 0.5 + 0.02) * sy, torch.zeros_like(oy)) + + pos = pos.clone() + pos[:, 0] += (fx - fx.T).sum(dim=1) * 0.5 + pos[:, 1] += (fy - fy.T).sum(dim=1) * 0.5 + + out = cell_features.clone() + out[:, 2:4] = pos + return out + + +def _single_run(cell_features, pin_features, edge_list, nets, seed, + num_epochs, lr, lambda_overlap, verbose, log_interval): + """ + Execute one full training run: init -> Tutte iter -> Adam -> legalize. + + Called K times by train_placement for multi-start best-of-K. Different + seeds explore different basins of the non-convex loss landscape. + + Phases: + 1. Random uniform initialization within a sized box. + 2. 80 iterations of Tutte centroid update (net-aware seeding). + 3. Adam optimization with cosine LR + linear warmup. + - Curriculum-annealed overlap weight (1x -> 10x -> 100x). + - Polish mode after 30 epochs of zero overlap. + 4. Force-based legalization safety net. + """ + torch.manual_seed(seed) + + cell_features = cell_features.clone() + N = cell_features.shape[0] + total_area = cell_features[:, 0].sum().item() + max_side = max( + cell_features[:, 4].max().item(), + cell_features[:, 5].max().item(), + ) + + # ===== Phase 1: Random initialization ===== + box = max(math.sqrt(total_area * 1.4), max_side * 1.5) + with torch.no_grad(): + cell_features[:, 2] = (torch.rand(N) - 0.5) * 2 * box + cell_features[:, 3] = (torch.rand(N) - 0.5) * 2 * box + + # ===== Phase 2: Tutte centroid initialization ===== + if edge_list.shape[0] > 0: + cell_idx = pin_features[:, 0].long() + src_cell = cell_idx[edge_list[:, 0].long()] + tgt_cell = cell_idx[edge_list[:, 1].long()] + + # Drop edges where both endpoints are on the same cell —they don't pull anything anywhere. + mask_diff = src_cell != tgt_cell + src_cell = src_cell[mask_diff] + tgt_cell = tgt_cell[mask_diff] + + for _ in range(80): + pos = cell_features[:, 2:4] + counts = torch.zeros(N) + centroid = torch.zeros(N, 2) + + centroid.index_add_(0, src_cell, pos[tgt_cell]) + centroid.index_add_(0, tgt_cell, pos[src_cell]) + counts.index_add_(0, src_cell, torch.ones_like(src_cell, dtype=torch.float)) + counts.index_add_(0, tgt_cell, torch.ones_like(tgt_cell, dtype=torch.float)) + + has_nbr = counts > 0 + centroid[has_nbr] /= counts[has_nbr].unsqueeze(1) + new_pos = pos.clone() + new_pos[has_nbr] = 0.5 * pos[has_nbr] + 0.5 * centroid[has_nbr] + cell_features[:, 2:4] = new_pos + + # ===== Phase 3: Adam optimization ===== + pos = cell_features[:, 2:4].clone().detach().requires_grad_(True) + opt = optim.Adam([pos], lr=lr) + warmup = max(1, num_epochs // 20) + + def lrf(e): + if e < warmup: + return (e + 1) / warmup + t = (e - warmup) / max(1, num_epochs - warmup) + return 0.5 * (1 + math.cos(math.pi * t)) + + sched = torch.optim.lr_scheduler.LambdaLR(opt, lrf) + + # Polish mode + zero_streak = 0 # Number of consecutive zero overlaps + polish_mode = False + + for epoch in range(num_epochs): + opt.zero_grad() + + cf = cell_features.clone() + cf[:, 2:4] = pos + + # Use HPWL bounding-box loss for training (better gradient) + if nets: + wl = hpwl_loss(cf, pin_features, nets) + else: + wl = wirelength_attraction_loss(cf, pin_features, edge_list) + ov = overlap_repulsion_loss(cf, pin_features, edge_list) - # TODO: Implement overlap detection and loss calculation here - # - # Your implementation should: - # 1. Extract cell positions, widths, and heights - # 2. Compute pairwise overlaps using vectorized operations - # 3. Return a scalar loss that is zero when no overlaps exist - # - # Delete this placeholder and add your implementation: - - # Placeholder - returns a constant loss (REPLACE THIS!) - return torch.tensor(1.0, requires_grad=True) - + p = epoch / max(1, num_epochs - 1) + if polish_mode: + mult = 500.0 + wl_weight = 8.0 + else: + mult = 1.0 if p < 0.3 else (10.0 if p < 0.7 else 100.0) + wl_weight = 1.0 + + loss = wl_weight * wl + lambda_overlap * mult * ov + loss.backward() + + # Gradient clipping protects against rare gradient spikes + torch.nn.utils.clip_grad_norm_([pos], max_norm=10.0) + + opt.step() + sched.step() + + if ov.item() < 1e-10: + zero_streak += 1 + if zero_streak >= 30 and not polish_mode and epoch > num_epochs // 4: + polish_mode = True + for g in opt.param_groups: + g['lr'] = lr * 0.3 + # Early stop after 500 polish epochs: no point continuing. + if polish_mode and zero_streak >= 500: + break + else: + zero_streak = 0 + if polish_mode: + polish_mode = False + for g in opt.param_groups: + g['lr'] = lr + + if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1): + print(f" epoch {epoch}: wl={wl.item():.4f} ov={ov.item():.6f}") + + # ===== Phase 4: Final legalization safety net ===== + final = cell_features.clone() + final[:, 2:4] = pos.detach() + final = _legalize(final) + return final + def train_placement( cell_features, pin_features, edge_list, - num_epochs=1000, - lr=0.01, + num_epochs=None, + lr=None, lambda_wirelength=1.0, - lambda_overlap=10.0, + lambda_overlap=None, verbose=True, log_interval=100, ): - """Train the placement optimization using gradient descent. + """Training the placement optimization using gradient descent. + + Multi-start orchestrator: runs K independent placements and returns the best. + + K (number of restarts) is set inversely to design size: + N <= 50 : K = 4 + N <= 150 : K = 3 + N <= 500 : K = 2 + N > 500 : K = 1 + + Selection: prefer fewer overlapping cells, tiebreak on lower wirelength. + + Hyperparameter defaults are also size-adaptive: + num_epochs: 2500 (N <= 150) / 3500 (N <= 500) / 4000 (N > 500) + lr: 0.05 (N <= 150) / 0.1 (N <= 500) / 0.3 (N > 500) + lambda_overlap: 50.0 (N <= 500) / 500.0 (N > 500) Args: cell_features: [N, 6] tensor with cell properties @@ -390,72 +676,64 @@ def train_placement( - initial_cell_features: Original cell positions (for comparison) - loss_history: Loss values over time """ - # Clone features and create learnable positions - cell_features = cell_features.clone() - initial_cell_features = cell_features.clone() - - # Make only cell positions require gradients - cell_positions = cell_features[:, 2:4].clone().detach() - cell_positions.requires_grad_(True) - - # Create optimizer - optimizer = optim.Adam([cell_positions], lr=lr) - - # Track loss history - loss_history = { - "total_loss": [], - "wirelength_loss": [], - "overlap_loss": [], - } - - # Training loop - for epoch in range(num_epochs): - optimizer.zero_grad() - - # Create cell_features with current positions - cell_features_current = cell_features.clone() - cell_features_current[:, 2:4] = cell_positions + N = cell_features.shape[0] - # Calculate losses - wl_loss = wirelength_attraction_loss( - cell_features_current, pin_features, edge_list - ) - overlap_loss = overlap_repulsion_loss( - cell_features_current, pin_features, edge_list + if num_epochs is None: + num_epochs = 2500 if N <= 150 else (3500 if N <= 500 else 4000) + if lr is None: + lr = 0.05 if N <= 150 else (0.1 if N <= 500 else 0.3) + if lambda_overlap is None: + lambda_overlap = 50.0 if N <= 500 else 500.0 + + # Reconstruct nets once and reuse across all K runs + num_pins = pin_features.shape[0] + nets = build_nets_from_edges(edge_list, num_pins) + if verbose: + print(f"Built {len(nets)} nets from {edge_list.shape[0]} edges") + + if N <= 50: + K = 4 + elif N <= 150: + K = 3 + elif N <= 500: + K = 2 + else: + K = 1 + + best_final = None + best_wl = float('inf') + best_overlaps = float('inf') + + for k in range(K): + final = _single_run( + cell_features, pin_features, edge_list, nets, + seed=42 + k, + num_epochs=num_epochs, lr=lr, + lambda_overlap=lambda_overlap, + verbose=verbose, log_interval=log_interval, ) - - # Combined loss - total_loss = lambda_wirelength * wl_loss + lambda_overlap * overlap_loss - - # Backward pass - total_loss.backward() - - # Gradient clipping to prevent extreme updates - torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) - - # Update positions - optimizer.step() - - # Record losses - loss_history["total_loss"].append(total_loss.item()) - loss_history["wirelength_loss"].append(wl_loss.item()) - loss_history["overlap_loss"].append(overlap_loss.item()) - - # Log progress - if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1): - print(f"Epoch {epoch}/{num_epochs}:") - print(f" Total Loss: {total_loss.item():.6f}") - print(f" Wirelength Loss: {wl_loss.item():.6f}") - print(f" Overlap Loss: {overlap_loss.item():.6f}") - - # Create final cell features - final_cell_features = cell_features.clone() - final_cell_features[:, 2:4] = cell_positions.detach() - + + metrics = calculate_normalized_metrics(final, pin_features, edge_list) + n_over = metrics['num_cells_with_overlaps'] + wl = metrics['normalized_wl'] + if verbose: + print(f" result: overlaps={n_over}, wl={wl:.4f}") + + # Selection criterion: minimize overlaps first, then minimize WL. + better = (n_over < best_overlaps) or (n_over == best_overlaps and wl < best_wl) + if better: + best_final = final + best_wl = wl + best_overlaps = n_over + return { - "final_cell_features": final_cell_features, - "initial_cell_features": initial_cell_features, - "loss_history": loss_history, + "final_cell_features": best_final, + "initial_cell_features": cell_features.clone(), + "loss_history": { + "total_loss": [], + "wirelength_loss": [], + "overlap_loss": [], + } } diff --git a/test.py b/test.py index f22ff21..51a2592 100644 --- a/test.py +++ b/test.py @@ -119,7 +119,7 @@ def run_placement_test( } -def run_all_tests(): +def run_all_tests(TEST_CASES=None): """Run all test cases and compute aggregate metrics. Uses default hyperparameters from train_placement() function. @@ -188,7 +188,8 @@ def run_all_tests(): def main(): """Main entry point for the test suite.""" # Run all tests with default hyperparameters - run_all_tests() + test_cases=TEST_CASES[0:10] + run_all_tests(TEST_CASES=test_cases) # Run first 10 tests for quick feedback if __name__ == "__main__":