From fca879884f7f3705da3ed60a7fa1a4b16e86880a Mon Sep 17 00:00:00 2001
From: Anish Kamatam
Date: Wed, 8 Apr 2026 09:25:39 -0700
Subject: [PATCH] first submission

---
 README.md    |  29 +--
 placement.py | 588 ++++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 457 insertions(+), 160 deletions(-)

diff --git a/README.md b/README.md
index cf27bfb..f8e8f3c 100644
--- a/README.md
+++ b/README.md
@@ -41,20 +41,21 @@ We will review submissions on a rolling basis.
 | 8 | Shashank Shriram | 0.0000 | 0.3312 | 11.32 | 🏎️💥 |
 | 9 | Gabriel Del Monte | 0.0000 | 0.3427 | 606.07 | |
 | 10 | Aleksey Valouev| 0.0000 | 0.3577 | 118.98 | |
-| 11 | Mohul Shukla | 0.0000 | 0.5048 | 54.60s | |
-| 12 | Ryan Hulke | 0.0000 | 0.5226 | 166.24 | |
-| 13 | Neel Shah | 0.0000 | 0.5445 | 45.40 | Zero overlaps on all tests, adaptive schedule + early stop |
-| 14 | Nawel Asgar | 0.0000 | 0.5675 | 81.49 | Adaptive penalty scaling with cubic gradients and design-size optimization
-| 15 | Shiva Baghel | 0.0000 | 0.5885 | 491.00 | Stable zero-overlap with balanced optimization |
-| 16 | Vansh Jain | 0.0000 | 0.9352 | 86.36 | |
-| 17 | Akash Pai | 0.0006 | 0.4933 | 326.25s | |
-| 18 | Zade Mahayni | 0.00665 | 0.5157 | 127.4 | Will try again tomorrow |
-| 19 | Nithin Yanna | 0.0148 | 0.5034 | 247.30s | aggressive overlap penalty with quadratic scaling |
-| 20 | Sean Ko | 0.0271 | .5138 | 31.83s | lr increase, decrease epoch, increase lambda overlap and decreased lambda wire_length + log penalty loss |
-| 21 | Keya Gohil | 0.0155 | 0.4678 | 1513.07 | Still working |
-| 22 | Prithvi Seran | 0.0499 | 0.4890 | 398.58 | |
-| 23 | partcl example | 0.8 | 0.4 | 5 | example |
-| 24 | Add Yours! | | | | |
+| 11 | Anish Kamatam | 0.0000 | 0.3727 | 574.00 | Spectral init, greedy legalization, multi-restart, Optuna-tuned HPs, gotta try again |
+| 12 | Mohul Shukla | 0.0000 | 0.5048 | 54.60 | |
+| 13 | Ryan Hulke | 0.0000 | 0.5226 | 166.24 | |
+| 14 | Neel Shah | 0.0000 | 0.5445 | 45.40 | Zero overlaps on all tests, adaptive schedule + early stop |
+| 15 | Nawel Asgar | 0.0000 | 0.5675 | 81.49 | Adaptive penalty scaling with cubic gradients and design-size optimization |
+| 16 | Shiva Baghel | 0.0000 | 0.5885 | 491.00 | Stable zero-overlap with balanced optimization |
+| 17 | Vansh Jain | 0.0000 | 0.9352 | 86.36 | |
+| 18 | Akash Pai | 0.0006 | 0.4933 | 326.25 | |
+| 19 | Zade Mahayni | 0.00665 | 0.5157 | 127.40 | Will try again tomorrow |
+| 20 | Nithin Yanna | 0.0148 | 0.5034 | 247.30 | aggressive overlap penalty with quadratic scaling |
+| 21 | Keya Gohil | 0.0155 | 0.4678 | 1513.07 | Still working |
+| 22 | Sean Ko | 0.0271 | 0.5138 | 31.83 | lr increase, decrease epoch, increase lambda overlap and decreased lambda wire_length + log penalty loss |
+| 23 | Prithvi Seran | 0.0499 | 0.4890 | 398.58 | |
+| 24 | partcl example | 0.8 | 0.4 | 5 | example |
+| 25 | Add Yours! | | | | |
 
 > **To add your results:**
 > Insert a new row in the table above with your name, overlap, wirelength, and any notes. Ensure you sort by overlap.

diff --git a/placement.py b/placement.py
index d70412d..3d7b026 100644
--- a/placement.py
+++ b/placement.py
@@ -38,6 +38,7 @@
 - Add visualization of optimization progress over time
 """
 
+import math
 import os
 from enum import IntEnum

@@ -246,216 +247,511 @@ def generate_placement_input(num_macros, num_std_cells):
 
 # ======= OPTIMIZATION CODE (edit this part) =======
 
-def wirelength_attraction_loss(cell_features, pin_features, edge_list):
-    """Calculate loss based on total wirelength to minimize routing.
-
-    This is a REFERENCE IMPLEMENTATION showing how to write a differentiable loss function.
-
-    The loss computes the Manhattan distance between connected pins and minimizes
-    the total wirelength across all edges.
-
-    Args:
-        cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height]
-        pin_features: [P, 7] tensor with pin information
-        edge_list: [E, 2] tensor with edges
-
-    Returns:
-        Scalar loss value
-    """
+_WL_ALPHA = 0.03
+
+
+def wirelength_attraction_loss(cell_features, pin_features, edge_list):
+    """Smooth log-sum-exp wirelength loss; the smoothing width is _WL_ALPHA."""
     if edge_list.shape[0] == 0:
         return torch.tensor(0.0, requires_grad=True)
 
-    # Update absolute pin positions based on cell positions
-    cell_positions = cell_features[:, 2:4]  # [N, 2]
+    cell_positions = cell_features[:, 2:4]
     cell_indices = pin_features[:, 0].long()
 
-    # Calculate absolute pin positions
     pin_absolute_x = cell_positions[cell_indices, 0] + pin_features[:, 1]
     pin_absolute_y = cell_positions[cell_indices, 1] + pin_features[:, 2]
 
-    # Get source and target pin positions for each edge
     src_pins = edge_list[:, 0].long()
     tgt_pins = edge_list[:, 1].long()
-    src_x = pin_absolute_x[src_pins]
-    src_y = pin_absolute_y[src_pins]
-    tgt_x = pin_absolute_x[tgt_pins]
-    tgt_y = pin_absolute_y[tgt_pins]
-
-    # Calculate smooth approximation of Manhattan distance
-    # Using log-sum-exp approximation for differentiability
-    alpha = 0.1  # Smoothing parameter
-    dx = torch.abs(src_x - tgt_x)
-    dy = torch.abs(src_y - tgt_y)
+    dx = torch.abs(pin_absolute_x[src_pins] - pin_absolute_x[tgt_pins])
+    dy = torch.abs(pin_absolute_y[src_pins] - pin_absolute_y[tgt_pins])
 
-    # Smooth L1 distance with numerical stability
-    smooth_manhattan = alpha * torch.logsumexp(
-        torch.stack([dx / alpha, dy / alpha], dim=0), dim=0
+    a = _WL_ALPHA
+    smooth_manhattan = a * torch.logsumexp(
+        torch.stack([dx / a, dy / a], dim=0), dim=0
     )
 
-    # Total wirelength
-    total_wirelength = torch.sum(smooth_manhattan)
-
-    return total_wirelength / edge_list.shape[0]  # Normalize by number of edges
+    # Normalize by edge count so the weight schedule is design-size independent.
+    return torch.sum(smooth_manhattan) / edge_list.shape[0]
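

# Illustrative sanity check (my addition, not called anywhere): for two terms,
# a * logsumexp(dx/a, dy/a) is sandwiched between max(dx, dy) and
# max(dx, dy) + a*ln(2), so with a small _WL_ALPHA the loss above behaves as a
# differentiable surrogate for the dominant span of each edge rather than the
# exact |dx| + |dy| Manhattan length.
def _wl_smoothing_demo():
    """Hypothetical helper comparing the smooth proxy to a hard max."""
    dx = torch.tensor([3.0, 0.5])
    dy = torch.tensor([1.0, 0.2])
    a = _WL_ALPHA
    smooth = a * torch.logsumexp(torch.stack([dx / a, dy / a], dim=0), dim=0)
    hard = torch.maximum(dx, dy)
    assert torch.all(smooth >= hard)  # log-sum-exp upper-bounds the max
    assert torch.all(smooth <= hard + a * math.log(2.0) + 1e-6)
    return smooth, hard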


 def overlap_repulsion_loss(cell_features, pin_features, edge_list):
-    """Calculate loss to prevent cell overlaps.
-
-    TODO: IMPLEMENT THIS FUNCTION
-
-    This is the main challenge. You need to implement a differentiable loss function
-    that penalizes overlapping cells. The loss should:
-
-    1. Be zero when no cells overlap
-    2. Increase as overlap area increases
-    3. Use only differentiable PyTorch operations (no if statements on tensors)
-    4. Work efficiently with vectorized operations
-
-    HINTS:
-    - Two axis-aligned rectangles overlap if they overlap in BOTH x and y dimensions
-    - For rectangles centered at (x1, y1) and (x2, y2) with widths (w1, w2) and heights (h1, h2):
-      * x-overlap occurs when |x1 - x2| < (w1 + w2) / 2
-      * y-overlap occurs when |y1 - y2| < (h1 + h2) / 2
-    - Use torch.relu() to compute positive overlaps: overlap_x = relu((w1+w2)/2 - |x1-x2|)
-    - Overlap area = overlap_x * overlap_y
-    - Consider all pairs of cells: use broadcasting with unsqueeze
-    - Use torch.triu() to avoid counting each pair twice (only consider i < j)
-    - Normalize the loss appropriately (by number of pairs or total area)
-
-    RECOMMENDED APPROACH:
-    1. Extract positions, widths, heights from cell_features
-    2. Compute all pairwise distances using broadcasting:
-       positions_i = positions.unsqueeze(1)  # [N, 1, 2]
-       positions_j = positions.unsqueeze(0)  # [1, N, 2]
-       distances = positions_i - positions_j  # [N, N, 2]
-    3. Calculate minimum separation distances for each pair
-    4. Use relu to get positive overlap amounts
-    5. Multiply overlaps in x and y to get overlap areas
-    6. Mask to only consider upper triangle (i < j)
-    7. Sum and normalize
+    """Linear overlap-area penalty with strong gradients at all overlap magnitudes.
+
+    The gradient of each pair's overlap area w.r.t. position is proportional
+    to the perpendicular overlap dimension, giving a consistent push-apart
+    force for both large macro overlaps and small std-cell overlaps. The sum
+    is normalized by N (not N^2) so the per-pair gradient stays strong when
+    only a few overlaps remain.
 
     Args:
         cell_features: [N, 6] tensor with [area, num_pins, x, y, width, height]
-        pin_features: [P, 7] tensor with pin information (not used here)
-        edge_list: [E, 2] tensor with edges (not used here)
+        pin_features: [P, 7] tensor (unused)
+        edge_list: [E, 2] tensor (unused)
 
     Returns:
-        Scalar loss value (should be 0 when no overlaps exist)
+        Scalar loss value, zero when no cells overlap.
     """
     N = cell_features.shape[0]
     if N <= 1:
         return torch.tensor(0.0, requires_grad=True)
 
-    # TODO: Implement overlap detection and loss calculation here
-    #
-    # Your implementation should:
-    # 1. Extract cell positions, widths, and heights
-    # 2. Compute pairwise overlaps using vectorized operations
-    # 3. Return a scalar loss that is zero when no overlaps exist
-    #
-    # Delete this placeholder and add your implementation:
-
-    # Placeholder - returns a constant loss (REPLACE THIS!)
-    return torch.tensor(1.0, requires_grad=True)
+    positions = cell_features[:, 2:4]
+    widths = cell_features[:, 4]
+    heights = cell_features[:, 5]
+
+    dx = torch.abs(positions[:, 0].unsqueeze(1) - positions[:, 0].unsqueeze(0))
+    dy = torch.abs(positions[:, 1].unsqueeze(1) - positions[:, 1].unsqueeze(0))
+
+    half_w_sum = (widths.unsqueeze(1) + widths.unsqueeze(0)) * 0.5
+    half_h_sum = (heights.unsqueeze(1) + heights.unsqueeze(0)) * 0.5
+
+    overlap_x = torch.relu(half_w_sum - dx)
+    overlap_y = torch.relu(half_h_sum - dy)
+    overlap_area = overlap_x * overlap_y
+
+    mask = torch.triu(torch.ones(N, N, dtype=torch.bool, device=cell_features.device), diagonal=1)
+    masked_overlaps = overlap_area[mask]
+
+    return masked_overlaps.sum() / N


def _spectral_initial_placement(cell_features, pin_features, edge_list, scale_exp=0.5):
    """Compute connectivity-aware initial positions via graph Laplacian eigenvectors."""
    N = cell_features.shape[0]
    if edge_list.shape[0] == 0 or N <= 2 or N > 500:
        return None

    # Map pin-level edges to cell-level edges and drop self-loops.
    pin_to_cell = pin_features[:, 0].long()
    src_cells = pin_to_cell[edge_list[:, 0].long()]
    tgt_cells = pin_to_cell[edge_list[:, 1].long()]

    valid = src_cells != tgt_cells
    src_cells = src_cells[valid]
    tgt_cells = tgt_cells[valid]
    if src_cells.shape[0] == 0:
        return None

    A = torch.zeros(N, N)
    ones = torch.ones(src_cells.shape[0])
    A.view(-1).scatter_add_(0, src_cells * N + tgt_cells, ones)
    A.view(-1).scatter_add_(0, tgt_cells * N + src_cells, ones)

    L = torch.diag(A.sum(dim=1)) - A

    try:
        eigvals, eigvecs = torch.linalg.eigh(L)
    except Exception:
        return None

    # Skip the constant eigenvector; use the next two as (x, y) coordinates.
    x_raw = eigvecs[:, 1]
    y_raw = eigvecs[:, 2]

    x_range = x_raw.max() - x_raw.min()
    y_range = y_raw.max() - y_raw.min()
    if x_range < 1e-10 or y_range < 1e-10:
        return None

    x_norm = (x_raw - x_raw.min()) / x_range - 0.5
    y_norm = (y_raw - y_raw.min()) / y_range - 0.5

    # Spread proportionally to total cell area so the seed roughly fills the die.
    scale = cell_features[:, 0].sum().item() ** scale_exp
    return torch.stack([x_norm * scale, y_norm * scale], dim=1)
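

# Illustrative sketch (my addition, not called anywhere): for a 4-cell chain
# a-b-c-d, the Fiedler vector (the Laplacian eigenvector paired with the
# second-smallest eigenvalue) orders the cells along one axis, which is why
# eigvecs[:, 1] and eigvecs[:, 2] above make a reasonable connectivity-aware
# seed placement.
def _spectral_demo():
    A = torch.tensor([[0.0, 1.0, 0.0, 0.0],
                      [1.0, 0.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0, 1.0],
                      [0.0, 0.0, 1.0, 0.0]])
    L = torch.diag(A.sum(dim=1)) - A
    eigvals, eigvecs = torch.linalg.eigh(L)  # eigenvalues in ascending order
    fiedler = eigvecs[:, 1]
    diffs = fiedler[1:] - fiedler[:-1]
    # Strictly monotone along the chain (up to an overall sign flip).
    assert torch.all(diffs > 0) or torch.all(diffs < 0)
    return fiedler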


def _count_discrete_overlaps(positions, widths, heights):
    """Vectorized count of overlapping cell pairs (no Python loops)."""
    N = positions.shape[0]
    dx = torch.abs(positions[:, 0].unsqueeze(1) - positions[:, 0].unsqueeze(0))
    dy = torch.abs(positions[:, 1].unsqueeze(1) - positions[:, 1].unsqueeze(0))
    half_w = (widths.unsqueeze(1) + widths.unsqueeze(0)) * 0.5
    half_h = (heights.unsqueeze(1) + heights.unsqueeze(0)) * 0.5
    ovlp = torch.relu(half_w - dx) * torch.relu(half_h - dy)
    tri = torch.triu(torch.ones(N, N, dtype=torch.bool, device=positions.device), diagonal=1)
    return (ovlp[tri] > 1e-10).sum().item()


def _legalize_placement(positions, widths, heights):
    """Greedy cell-by-cell legalization that drives overlaps to zero.

    Processes cells largest-first. For each conflicting cell, generates exact
    displacement candidates from overlapping neighbors (4 per neighbor: the
    minimum shift in each cardinal direction to separate) and picks the
    closest valid one. Falls back to a spiral search if no candidate is
    overlap-free.
    """
    import numpy as np  # local import: numpy is only needed for legalization

    N = positions.shape[0]
    pos = positions.detach().clone()
    px = pos[:, 0].numpy().copy()
    py = pos[:, 1].numpy().copy()
    w = widths.detach().numpy().copy()
    h = heights.detach().numpy().copy()

    # Place big cells first; small cells are easier to slot in later.
    areas = w * h
    order = np.argsort(-areas)

    placed_x = np.empty(N, dtype=np.float64)
    placed_y = np.empty(N, dtype=np.float64)
    placed_w = np.empty(N, dtype=np.float64)
    placed_h = np.empty(N, dtype=np.float64)
    n_placed = 0

    for idx in order:
        cx, cy = float(px[idx]), float(py[idx])
        cw, ch = float(w[idx]), float(h[idx])

        if n_placed == 0:
            placed_x[0], placed_y[0] = cx, cy
            placed_w[0], placed_h[0] = cw, ch
            n_placed = 1
            continue

        def _has_any_overlap(tx, ty):
            adx = np.abs(tx - placed_x[:n_placed])
            ady = np.abs(ty - placed_y[:n_placed])
            min_sx = (cw + placed_w[:n_placed]) * 0.5
            min_sy = (ch + placed_h[:n_placed]) * 0.5
            return np.any((adx < min_sx) & (ady < min_sy))

        if not _has_any_overlap(cx, cy):
            placed_x[n_placed], placed_y[n_placed] = cx, cy
            placed_w[n_placed], placed_h[n_placed] = cw, ch
            n_placed += 1
            continue

        adx = np.abs(cx - placed_x[:n_placed])
        ady = np.abs(cy - placed_y[:n_placed])
        min_sx = (cw + placed_w[:n_placed]) * 0.5
        min_sy = (ch + placed_h[:n_placed]) * 0.5
        conflicts = np.where((adx < min_sx) & (ady < min_sy))[0]

        candidates = []
        margin = 1e-3
        for k in conflicts:
            sep_x = min_sx[k] + margin
            sep_y = min_sy[k] + margin
            candidates.append((placed_x[k] + sep_x, cy))
            candidates.append((placed_x[k] - sep_x, cy))
            candidates.append((cx, placed_y[k] + sep_y))
            candidates.append((cx, placed_y[k] - sep_y))

        best_pos = None
        best_dist = float('inf')
        for tx, ty in candidates:
            d = (tx - cx) ** 2 + (ty - cy) ** 2
            if d < best_dist and not _has_any_overlap(tx, ty):
                best_dist = d
                best_pos = (tx, ty)

        if best_pos is None:
            # Spiral search over a widening grid of offsets around the original spot.
            for radius_step in range(1, 2000):
                step = 0.5 * radius_step
                found = False
                for sx in [-step, 0, step]:
                    for sy in [-step, 0, step]:
                        if sx == 0 and sy == 0:
                            continue
                        tx, ty = cx + sx, cy + sy
                        if not _has_any_overlap(tx, ty):
                            best_pos = (tx, ty)
                            found = True
                            break
                    if found:
                        break
                if found:
                    break

        if best_pos is None:
            # Last resort: keep the original position but still register the
            # cell so later cells can avoid it.
            best_pos = (cx, cy)

        px[idx], py[idx] = best_pos
        placed_x[n_placed] = px[idx]
        placed_y[n_placed] = py[idx]
        placed_w[n_placed] = cw
        placed_h[n_placed] = ch
        n_placed += 1

    result = pos.clone()
    result[:, 0] = torch.tensor(px, dtype=pos.dtype)
    result[:, 1] = torch.tensor(py, dtype=pos.dtype)
    return result
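

# Worked micro-example for the candidate generation above (my addition, not
# called anywhere): two 2x2 cells whose centers sit 1.0 apart in x need a
# center separation of (2+2)/2 = 2.0, so the cheapest legal move slides the
# second cell to x = 2.0 plus the small margin used above.
def _legalize_demo():
    pos = torch.tensor([[0.0, 0.0], [1.0, 0.0]])  # overlapping pair
    w = torch.tensor([2.0, 2.0])
    h = torch.tensor([2.0, 2.0])
    assert _count_discrete_overlaps(pos, w, h) == 1
    legal = _legalize_placement(pos, w, h)
    assert _count_discrete_overlaps(legal, w, h) == 0
    return legal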


def _get_lr(epoch, warmup_epochs, total_epochs, peak_lr, min_lr=1e-4):
    """Cosine annealing with linear warmup."""
    if epoch < warmup_epochs:
        return peak_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(total_epochs - warmup_epochs, 1)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
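

# Quick numeric check of the schedule above (my addition, not called
# anywhere): with peak_lr=0.05, 10 warmup epochs, and 100 total epochs, the
# LR ramps linearly to the peak, then cosine-decays toward min_lr.
def _lr_schedule_demo():
    lrs = [_get_lr(e, warmup_epochs=10, total_epochs=100, peak_lr=0.05)
           for e in range(100)]
    assert abs(lrs[9] - 0.05) < 1e-12    # end of warmup hits the peak
    assert lrs[9] >= lrs[50] >= lrs[99]  # cosine decay afterwards
    assert lrs[99] >= 1e-4               # never below min_lr
    return lrs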


_DEFAULT_HP = {
    "wl_alpha": 0.03,
    "spectral_scale_exp": 0.4,
    "lam_ol_scale": 283.56,
    "lam_ol_maint_scale": 31.16,
    "lam_wl_full": 0.938,
    "lam_wl_sub": 0.123,
    "drift_weight": 0.0008,
    "annealing_threshold": 214,
    "grad_norm_base": 3.597,
    "grad_norm_exp": 0.557,
    "lr_floor_mult": 0.278,
    "anneal_ol_floor": 0.062,
    "anneal_ol_decay": 1.529,
    "anneal_wl_ramp": 1.512,
    "anneal_ol_gate": 0.042,
    "anneal_wl_ungated": 0.143,
    "p2_epochs": 3000,
    "p2_lam_mult": 3.0,
    "p4_epoch_base": 5000,
    "p4_epoch_nref": 50.0,
    "p4_lr": 0.0291,
    "p4_wl_guard": 24.827,
    "p4_guard_mult": 1.5,
    "p4_guard_max": 80.0,
    "p4_clip": 2.218,
    "p5_epochs": 500,
    "p5_lr": 0.00286,
    "p5_ol_weight": 10.895,
    "p5_clip": 1.6,
}


def _single_train_run(cell_features, pin_features, edge_list, use_spectral,
                      num_epochs, lr, verbose, log_interval, hp=None):
    """Run one full optimization pass (phases 1-5). Returns (final_cell_features, final_wl)."""
    global _WL_ALPHA
    hp = {**_DEFAULT_HP, **(hp or {})}
    _WL_ALPHA = hp["wl_alpha"]
    cell_features = cell_features.clone()
    N = cell_features.shape[0]
    widths = cell_features[:, CellFeatureIdx.WIDTH]
    heights = cell_features[:, CellFeatureIdx.HEIGHT]

    if use_spectral:
        spectral_pos = _spectral_initial_placement(
            cell_features, pin_features, edge_list,
            scale_exp=hp["spectral_scale_exp"],
        )
        if spectral_pos is not None:
            cell_features[:, 2:4] = spectral_pos

    cell_positions = cell_features[:, 2:4].clone().detach().requires_grad_(True)

    # Scale penalty weights and the gradient-clip ceiling with design size.
    n_ref = 50.0
    scale = max(1.0, N / n_ref)
    lam_ol_active = hp["lam_ol_scale"] * scale
    lam_ol_maint = hp["lam_ol_maint_scale"] * scale
    lam_wl_full = hp["lam_wl_full"]
    lam_wl_sub = hp["lam_wl_sub"]
    drift_weight = hp["drift_weight"]
    USE_ANNEALING = N >= hp["annealing_threshold"]
    max_grad_norm = max(hp["grad_norm_base"], hp["grad_norm_base"] * scale ** hp["grad_norm_exp"])

    optimizer = optim.Adam([cell_positions], lr=lr)
    warmup_epochs = max(int(num_epochs * 0.05), 10)

    overlap_resolved = False
    cur_lam_ol = lam_ol_active
    cur_lam_wl = lam_wl_sub

    # ---- Phase 1: Main training loop ----
    for epoch in range(num_epochs):
        current_lr = _get_lr(epoch, warmup_epochs, num_epochs, lr)
        if not USE_ANNEALING and not overlap_resolved:
            current_lr = max(current_lr, lr * hp["lr_floor_mult"])
        for pg in optimizer.param_groups:
            pg["lr"] = current_lr

        optimizer.zero_grad()
        cf_cur = cell_features.clone()
        cf_cur[:, 2:4] = cell_positions

        wl_loss = wirelength_attraction_loss(cf_cur, pin_features, edge_list)
        ol_loss = overlap_repulsion_loss(cf_cur, pin_features, edge_list)

        if USE_ANNEALING:
            # Large designs: decay the overlap weight over time and ramp the
            # wirelength weight once the overlap loss clears the gate.
            progress = epoch / num_epochs
            lam_ol = lam_ol_active * max(hp["anneal_ol_floor"], 1.0 - progress * hp["anneal_ol_decay"])
            ol_cleared = ol_loss.item() < hp["anneal_ol_gate"]
            lam_wl = lam_wl_full * min(1.0, progress * hp["anneal_wl_ramp"]) if ol_cleared else lam_wl_full * hp["anneal_wl_ungated"]
        else:
            lam_ol = cur_lam_ol
            lam_wl = cur_lam_wl

        # The small drift term keeps the placement centered.
        total_loss = lam_wl * wl_loss + lam_ol * ol_loss + drift_weight * ((cell_positions - cell_positions.mean(dim=0, keepdim=True)) ** 2).sum() / N
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=max_grad_norm)
        optimizer.step()

        if not USE_ANNEALING:
            # Small designs: switch between aggressive and maintenance weights
            # based on the exact (discrete) overlap count.
            with torch.no_grad():
                discrete_count = _count_discrete_overlaps(cell_positions, widths, heights)
            if not overlap_resolved and discrete_count == 0:
                overlap_resolved = True
                cur_lam_ol = lam_ol_maint
                cur_lam_wl = lam_wl_full
            elif overlap_resolved and discrete_count > 0:
                overlap_resolved = False
                cur_lam_ol = lam_ol_active
                cur_lam_wl = lam_wl_sub

        if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1):
            if USE_ANNEALING:
                with torch.no_grad():
                    discrete_count = _count_discrete_overlaps(cell_positions, widths, heights)
            print(f"Epoch {epoch}/{num_epochs} overlaps={discrete_count}:"
                  f" WL={wl_loss.item():.6f} OL={ol_loss.item():.6f} "
                  f"lam_ol={lam_ol:.1f} lam_wl={lam_wl:.3f} lr={current_lr:.6f}")

    # ---- Phase 2: Pure overlap reduction ----
    with torch.no_grad():
        pre_ol_count = _count_discrete_overlaps(cell_positions, widths, heights)
    if pre_ol_count > 0:
        ol_opt = optim.Adam([cell_positions], lr=lr)
        for ol_epoch in range(hp["p2_epochs"]):
            ol_opt.zero_grad()
            cf_ol = cell_features.clone()
            cf_ol[:, 2:4] = cell_positions
            ol_loss = overlap_repulsion_loss(cf_ol, pin_features, edge_list)
            (lam_ol_active * hp["p2_lam_mult"] * ol_loss).backward()
            ol_opt.step()
            if ol_epoch % 500 == 0:
                with torch.no_grad():
                    c = _count_discrete_overlaps(cell_positions, widths, heights)
                if verbose:
                    print(f"  OL-phase {ol_epoch}: overlaps={c} loss={ol_loss.item():.6f}")
                if c == 0:
                    break

    # ---- Phase 3: Greedy legalization ----
    with torch.no_grad():
        legalized_pos = _legalize_placement(cell_positions, widths, heights)
        post_count = _count_discrete_overlaps(legalized_pos, widths, heights)
    cell_positions = legalized_pos.detach().clone().requires_grad_(True)

    # ---- Phase 4: Wirelength optimization with an overlap guard ----
    if post_count == 0:
        ft_epochs = max(300, int(hp["p4_epoch_base"] * (hp["p4_epoch_nref"] / N) ** 0.5))
        ft_opt = optim.Adam([cell_positions], lr=hp["p4_lr"])
        wl_guard = hp["p4_wl_guard"]

        for ft_epoch in range(ft_epochs):
            ft_opt.zero_grad()
            cf_ft = cell_features.clone()
            cf_ft[:, 2:4] = cell_positions
            wl_loss = wirelength_attraction_loss(cf_ft, pin_features, edge_list)
            ol_loss = overlap_repulsion_loss(cf_ft, pin_features, edge_list)
            (wl_loss + wl_guard * ol_loss).backward()
            torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=hp["p4_clip"])
            ft_opt.step()

            # Escalate the guard if overlaps start creeping back in.
            if ft_epoch % 500 == 0 and ft_epoch > 0:
                with torch.no_grad():
                    c = _count_discrete_overlaps(cell_positions, widths, heights)
                if c > 10:
                    wl_guard = min(wl_guard * hp["p4_guard_mult"], hp["p4_guard_max"])

        with torch.no_grad():
            legalized_pos = _legalize_placement(cell_positions, widths, heights)
            post_count = _count_discrete_overlaps(legalized_pos, widths, heights)
        cell_positions = legalized_pos.detach().clone().requires_grad_(True)

    # ---- Phase 5: Post-legalization wirelength cleanup ----
    if post_count == 0:
        final_opt = optim.Adam([cell_positions], lr=hp["p5_lr"])
        for _ in range(hp["p5_epochs"]):
            final_opt.zero_grad()
            cf_final = cell_features.clone()
            cf_final[:, 2:4] = cell_positions
            wl_loss = wirelength_attraction_loss(cf_final, pin_features, edge_list)
            ol_loss = overlap_repulsion_loss(cf_final, pin_features, edge_list)
            (wl_loss + hp["p5_ol_weight"] * ol_loss).backward()
            torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=hp["p5_clip"])
            final_opt.step()

        with torch.no_grad():
            legalized_pos = _legalize_placement(cell_positions, widths, heights)
        cell_positions = legalized_pos.detach().clone()

    final_cf = cell_features.clone()
    final_cf[:, 2:4] = cell_positions.detach()
    with torch.no_grad():
        final_wl = wirelength_attraction_loss(final_cf, pin_features, edge_list).item()
    return final_cf, final_wl
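

# The _DEFAULT_HP values come from an external hyperparameter search (see the
# leaderboard note about Optuna-tuned HPs). A minimal sketch of how such a
# search could drive the `hp` dict is below -- it assumes `optuna` is
# installed and that `make_validation_design()` (hypothetical) returns a
# (cell_features, pin_features, edge_list) triple to score against. It is
# kept commented out so the module has no hard optuna dependency.
#
#   import optuna
#
#   def _objective(trial):
#       hp = {
#           "wl_alpha": trial.suggest_float("wl_alpha", 0.01, 0.1, log=True),
#           "lam_ol_scale": trial.suggest_float("lam_ol_scale", 50.0, 500.0),
#       }
#       cf, pf, el = make_validation_design()  # hypothetical helper
#       _, wl = _single_train_run(cf, pf, el, use_spectral=True,
#                                 num_epochs=1000, lr=0.05,
#                                 verbose=False, log_interval=100, hp=hp)
#       return wl
#
#   study = optuna.create_study(direction="minimize")
#   study.optimize(_objective, n_trials=50)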

 def train_placement(
     cell_features,
     pin_features,
     edge_list,
-    num_epochs=1000,
-    lr=0.01,
+    num_epochs=5000,
+    lr=0.05,
     lambda_wirelength=1.0,
     lambda_overlap=10.0,
     verbose=True,
     log_interval=100,
+    hp=None,
 ):
-    """Train the placement optimization using gradient descent.
+    """Placement optimizer: multi-restart for small N, annealing for large N.
 
     Args:
         cell_features: [N, 6] tensor with cell properties
         pin_features: [P, 7] tensor with pin properties
         edge_list: [E, 2] tensor with edge connectivity
-        num_epochs: Number of optimization iterations
-        lr: Learning rate for Adam optimizer
-        lambda_wirelength: Weight for wirelength loss
-        lambda_overlap: Weight for overlap loss
+        num_epochs: Max optimization iterations per restart
+        lr: Peak learning rate for Adam
+        lambda_wirelength: Base weight for wirelength loss (unused, kept for API compatibility)
+        lambda_overlap: Base weight for overlap loss (unused, kept for API compatibility)
         verbose: Whether to print progress
         log_interval: How often to print progress
+        hp: Optional hyperparameter dict (merged over _DEFAULT_HP)
 
     Returns:
-        Dictionary with:
-        - final_cell_features: Optimized cell positions
-        - initial_cell_features: Original cell positions (for comparison)
-        - loss_history: Loss values over time
+        Dictionary with final_cell_features, initial_cell_features, loss_history.
     """
-    # Clone features and create learnable positions
-    cell_features = cell_features.clone()
     initial_cell_features = cell_features.clone()
+    N = cell_features.shape[0]
+    n_restarts = max(1, int(400 / N)) if N < 200 else 1
 
-    # Make only cell positions require gradients
-    cell_positions = cell_features[:, 2:4].clone().detach()
-    cell_positions.requires_grad_(True)
-
-    # Create optimizer
-    optimizer = optim.Adam([cell_positions], lr=lr)
+    best_cf = None
+    best_wl = float('inf')
 
-    # Track loss history
-    loss_history = {
-        "total_loss": [],
-        "wirelength_loss": [],
-        "overlap_loss": [],
-    }
+    for restart in range(n_restarts):
+        cf_run = cell_features.clone()
 
-    # Training loop
-    for epoch in range(num_epochs):
-        optimizer.zero_grad()
+        if restart > 0:
+            # Re-scatter cells radially; the spread scales with sqrt(total area).
+            total_area = cf_run[:, 0].sum().item()
+            spread = (total_area ** 0.5) * 0.6
+            angles = torch.rand(N) * 2 * math.pi
+            radii = torch.rand(N) * spread
+            cf_run[:, 2] = radii * torch.cos(angles)
+            cf_run[:, 3] = radii * torch.sin(angles)
 
-        # Create cell_features with current positions
-        cell_features_current = cell_features.clone()
-        cell_features_current[:, 2:4] = cell_positions
+        use_spectral = (restart == 0)
+        run_verbose = verbose and (restart == 0)
 
-        # Calculate losses
-        wl_loss = wirelength_attraction_loss(
-            cell_features_current, pin_features, edge_list
-        )
-        overlap_loss = overlap_repulsion_loss(
-            cell_features_current, pin_features, edge_list
-        )
+        final_cf, final_wl = _single_train_run(
+            cf_run, pin_features, edge_list, use_spectral,
+            num_epochs, lr, run_verbose, log_interval, hp=hp,
+        )
 
-        # Combined loss
-        total_loss = lambda_wirelength * wl_loss + lambda_overlap * overlap_loss
-
-        # Backward pass
-        total_loss.backward()
-
-        # Gradient clipping to prevent extreme updates
-        torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0)
+        overlap_count = len(calculate_cells_with_overlaps(final_cf))
+        is_best = overlap_count == 0 and final_wl < best_wl
+        if is_best:
+            best_wl = final_wl
+            best_cf = final_cf
 
-        # Update positions
-        optimizer.step()
-
-        # Record losses
-        loss_history["total_loss"].append(total_loss.item())
-        loss_history["wirelength_loss"].append(wl_loss.item())
-        loss_history["overlap_loss"].append(overlap_loss.item())
-
-        # Log progress
-        if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1):
-            print(f"Epoch {epoch}/{num_epochs}:")
-            print(f"  Total Loss: {total_loss.item():.6f}")
-            print(f"  Wirelength Loss: {wl_loss.item():.6f}")
-            print(f"  Overlap Loss: {overlap_loss.item():.6f}")
+        if verbose:
+            tag = "*best*" if is_best else ""
+            print(f"Restart {restart+1}/{n_restarts}: wl={final_wl:.4f} "
+                  f"overlaps={overlap_count} {tag}")
 
-    # Create final cell features
-    final_cell_features = cell_features.clone()
-    final_cell_features[:, 2:4] = cell_positions.detach()
+    if best_cf is None:
+        # No restart reached zero overlaps; fall back to the last run.
+        best_cf = final_cf
 
     return {
-        "final_cell_features": final_cell_features,
+        "final_cell_features": best_cf,
         "initial_cell_features": initial_cell_features,
-        "loss_history": loss_history,
+        "loss_history": {},  # per-epoch histories are not tracked in this version
     }
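

# Minimal end-to-end usage sketch (my addition, illustrative only). Tensor
# layouts follow the docstrings above: each cell row is
# [area, num_pins, x, y, width, height], pin column 0 is the owning cell
# index with columns 1-2 as pin offsets, and edges index pins. For a real
# design, generate_placement_input() above is the intended entry point; a
# tiny N=2 case like this triggers many restarts but stays fast at this size.
def _toy_placement_run():
    cell_features = torch.tensor([
        [4.0, 1.0, 0.0, 0.0, 2.0, 2.0],
        [4.0, 1.0, 0.5, 0.0, 2.0, 2.0],  # starts overlapping cell 0
    ])
    pin_features = torch.zeros(2, 7)
    pin_features[0, 0] = 0.0  # pin 0 sits at the center of cell 0
    pin_features[1, 0] = 1.0  # pin 1 sits at the center of cell 1
    edge_list = torch.tensor([[0, 1]])

    result = train_placement(cell_features, pin_features, edge_list,
                             num_epochs=200, verbose=False)
    final = result["final_cell_features"]
    # Legalization should leave the toy pair overlap-free.
    assert len(calculate_cells_with_overlaps(final)) == 0
    return final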