From 77cef7d4439babdf4db27580dc4aecab3f78ebf4 Mon Sep 17 00:00:00 2001 From: Sidhartha Parhi Date: Tue, 24 Mar 2026 13:08:46 -0500 Subject: [PATCH 1/3] Optimization improvements --- placement.py | 125 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 7 deletions(-) diff --git a/placement.py b/placement.py index d70412d..929e9da 100644 --- a/placement.py +++ b/placement.py @@ -295,7 +295,7 @@ def wirelength_attraction_loss(cell_features, pin_features, edge_list): # Total wirelength total_wirelength = torch.sum(smooth_manhattan) - + return total_wirelength / edge_list.shape[0] # Normalize by number of edges @@ -343,6 +343,8 @@ def overlap_repulsion_loss(cell_features, pin_features, edge_list): Returns: Scalar loss value (should be 0 when no overlaps exist) """ + # def calculate_overlap_1d() + N = cell_features.shape[0] if N <= 1: return torch.tensor(0.0, requires_grad=True) @@ -357,17 +359,90 @@ def overlap_repulsion_loss(cell_features, pin_features, edge_list): # Delete this placeholder and add your implementation: # Placeholder - returns a constant loss (REPLACE THIS!) 
- return torch.tensor(1.0, requires_grad=True) + # return torch.tensor(1.0, requires_grad=True) + + areas = cell_features[:, CellFeatureIdx.AREA] # = areas + # num_pins = cell_features[:, CellFeatureIdx.NUM_PINS] = num_pins_per_cell.float() + x = cell_features[:, CellFeatureIdx.X] # = 0.0 # x position (initialized to 0) + y = cell_features[:, CellFeatureIdx.Y] # = 0.0 # y position (initialized to 0) + w = cell_features[:, CellFeatureIdx.WIDTH] # = cell_widths + h = cell_features[:, CellFeatureIdx.HEIGHT] # = cell_heights + + # # Calculate absolute pin positions + cell_indices = pin_features[:, 0].long() + pin_absolute_x = x[cell_indices] + pin_features[:, 1] + pin_absolute_y = y[cell_indices] + pin_features[:, 2] + + # # Get source and target pin positions for each edge + # src_pins = edge_list[:, 0].long() + # tgt_pins = edge_list[:, 1].long() + + # src_x = pin_absolute_x[src_pins] + # src_y = pin_absolute_y[src_pins] + # tgt_x = pin_absolute_x[tgt_pins] + # tgt_y = pin_absolute_y[tgt_pins] + + # TODO: Optimize torch.triu + overlap_x = w.unsqueeze(0) + w.unsqueeze(1) + overlap_x /= 2. + overlap_x -= torch.abs(x.unsqueeze(0) - x.unsqueeze(1)) + overlap_x = torch.relu(overlap_x) + overlap_x = torch.triu(overlap_x, diagonal=1) + + overlap_y = h.unsqueeze(0) + h.unsqueeze(1) + overlap_y /= 2. + overlap_y -= torch.abs(y.unsqueeze(0) - y.unsqueeze(1)) + overlap_y = torch.relu(overlap_y) + overlap_y = torch.triu(overlap_y, diagonal=1) + + loss = overlap_x * overlap_y + # loss *= areas.unsqueeze(0) + areas.unsqueeze(1) + # print(loss.mean()) + # breakpoint() + + # loss += 0.001*(x**2 + y**2) + # loss += 0.001*torch.mean(x**2 + y**2) + # loss += 0.001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2) + # loss = torch.where( + # loss == 0, + # loss + 0.001*torch.mean(x**2 + y**2), + # loss + # ) + + # canvas_area = (torch.max(x+w/2.) - torch.min(x-w/2.)) * (torch.max(y+h/2.) 
- torch.min(y-h/2.)) + # canvas_area = canvas_area - 27*(torch.sum(w)**0.5 * torch.sum(h)**0.5) + # print(loss.min(), loss.max(), canvas_area) + # loss = canvas_area + # loss += 1.*canvas_area + # loss = torch.where( + # loss == 0, + # loss + 1.*canvas_area, + # loss + # ) + + # dist_sq = (x.unsqueeze(0) - x.unsqueeze(1))**2 + (y.unsqueeze(0) - y.unsqueeze(1))**2 + # dist_sq = dist_sq / torch.max(dist_sq) + # loss = torch.where( + # loss == 0, + # loss + .0001*dist_sq, # 0.00001* + # loss + # ) + loss *= .2 # 200000. + # loss = torch.triu(loss, diagonal=1) + # breakpoint() + loss = torch.mean(loss) + + return loss def train_placement( cell_features, pin_features, edge_list, - num_epochs=1000, + num_epochs=40000, lr=0.01, - lambda_wirelength=1.0, - lambda_overlap=10.0, + lambda_wirelength=100000000.0, + lambda_overlap=1000000000.0, verbose=True, log_interval=100, ): @@ -400,6 +475,9 @@ def train_placement( # Create optimizer optimizer = optim.Adam([cell_positions], lr=lr) + # scheduler = optim.lr_scheduler.CosineAnnealingLR( + # optimizer, T_max=num_epochs, eta_min=lr*0.3 + # ) # Track loss history loss_history = { @@ -417,24 +495,57 @@ def train_placement( cell_features_current[:, 2:4] = cell_positions # Calculate losses + # if epoch < num_epochs//2: + # if epoch % 2 == 0: + # if epoch < 1.8*num_epochs: + # b = (epoch // (num_epochs//5)) % 2 == 0 + # else: + # b = .75*num_epochs < epoch < .85*num_epochs + if epoch >= .8*num_epochs: + lambda_overlap_ = lambda_overlap + lambda_wirelength_ = 0. + elif (epoch // (num_epochs//8)) % 2 == 0: + lambda_overlap_ = 1000. + lambda_wirelength_ = lambda_wirelength + else: + lambda_overlap_ = lambda_overlap + lambda_wirelength_ = 1000. 
wl_loss = wirelength_attraction_loss( cell_features_current, pin_features, edge_list ) + x = cell_features_current[:, CellFeatureIdx.X] # = 0.0 # x position (initialized to 0) + y = cell_features_current[:, CellFeatureIdx.Y] # = 0.0 # y position (initialized to 0) + # cell_indices = pin_features[:, 0].long() + # pin_absolute_x = x[cell_indices] + pin_features[:, 1] + # pin_absolute_y = y[cell_indices] + pin_features[:, 2] + # wl_loss += .05*torch.mean(x**2 + y**2) + # wl_loss += .001*torch.mean(x**2 + y**2) + # wl_loss += .001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2) + # loss = torch.where( + # loss == 0, + # loss + 0.001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2), + # loss + # ) + # overlap_loss = torch.tensor(0.) + # else: overlap_loss = overlap_repulsion_loss( cell_features_current, pin_features, edge_list ) + # wl_loss = torch.tensor(0.) # Combined loss - total_loss = lambda_wirelength * wl_loss + lambda_overlap * overlap_loss + total_loss = lambda_wirelength_ * wl_loss + lambda_overlap_ * overlap_loss # Backward pass total_loss.backward() # Gradient clipping to prevent extreme updates - torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) + # torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) # Update positions optimizer.step() + # scheduler.step() + # print(scheduler.get_lr()) # Record losses loss_history["total_loss"].append(total_loss.item()) From 5307d4d02376cc9f48bfa6f39ca87ae710f4bab7 Mon Sep 17 00:00:00 2001 From: Sidhartha Parhi Date: Tue, 24 Mar 2026 14:50:33 -0500 Subject: [PATCH 2/3] Experiments --- README.md | 10 ++++++++++ placement.py | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index df0c441..2f7e3b5 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,16 @@ We will review submissions on a rolling basis. | 2 | Add Yours! 
| | | | | +- Gradient Clipping + - LR Scheduler +- Wirelength loss +- Make loss reflect the fact that every step in the "GAN" leads to a better final state. + - Currently, the loss doesn't reflect it much at all, but it does automatically get better with each "GAN" step. +- Optimizer (and optim hyperparams) +- "Budging it" + +- All valid permutations -> min of wirelength (except make it differentiable) + ## Leaderboard (sorted by overlap) (OLD; test suite has been updated; see above) diff --git a/placement.py b/placement.py index 929e9da..61c881e 100644 --- a/placement.py +++ b/placement.py @@ -501,7 +501,7 @@ def train_placement( # b = (epoch // (num_epochs//5)) % 2 == 0 # else: # b = .75*num_epochs < epoch < .85*num_epochs - if epoch >= .8*num_epochs: + if epoch >= 0.8*num_epochs: # 35000: lambda_overlap_ = lambda_overlap lambda_wirelength_ = 0. elif (epoch // (num_epochs//8)) % 2 == 0: @@ -519,7 +519,7 @@ def train_placement( # pin_absolute_x = x[cell_indices] + pin_features[:, 1] # pin_absolute_y = y[cell_indices] + pin_features[:, 2] # wl_loss += .05*torch.mean(x**2 + y**2) - # wl_loss += .001*torch.mean(x**2 + y**2) + wl_loss += .001*torch.mean(x**2 + y**2) # wl_loss += .001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2) # loss = torch.where( # loss == 0, @@ -540,7 +540,7 @@ def train_placement( total_loss.backward() # Gradient clipping to prevent extreme updates - # torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) + torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) # Update positions optimizer.step() From b8dad4930495db68f82ab4e37a680dbcef9b3786 Mon Sep 17 00:00:00 2001 From: Sidhartha Parhi Date: Tue, 31 Mar 2026 10:15:48 -0500 Subject: [PATCH 3/3] Implemented multi-stage alternating optimization and initial centering regularization to find the Pareto Front. 
--- .gitignore | 3 + README.md | 60 ++++------ placement.py | 330 +++++++++++++++++++++++++++++---------------------- 3 files changed, 213 insertions(+), 180 deletions(-) diff --git a/.gitignore b/.gitignore index fdd0c6d..5e78125 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,7 @@ *.gif *.bmp +placement_experiments.py +README_AND_NOTES.md + **/__pycache__/** \ No newline at end of file diff --git a/README.md b/README.md index 2f7e3b5..e761714 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,8 @@ The deadline is when all intern slots for summer 2026 are filled. We will review 1. **Fork this repository.** 2. Solve the placement problem using your preferred tools or scripts. -3. Run the test script to evaluate your solution and obtain the overlap and wirelength metrics. -4. Submit a pull request with your updated leaderboard entry and instructions for me to access your actual submission (it's fine if it's public). +3. Run the first 10 tests to evaluate your solution and obtain the overlap and wirelength metrics. Report Average Overlap, Wirelength and total Runtime. *Test cases 11 and 12 are extra credit, give them a shot if you have some time.* +4. Submit a pull request with your updated leaderboard entry and instructions for me to access your actual submission (it's fine if it's public). Note: You can use any libraries or frameworks you like, but please ensure that your code is well-documented and easy to follow. @@ -27,30 +27,11 @@ You may submit multiple solutions to try and increase your score. We will review submissions on a rolling basis. -## New Leaderboard (sorted by overlap) +## Leaderboard (sorted by overlap) | Rank | Name | Overlap | Wirelength (um) | Runtime (s) | Notes | |------|-----------------|-------------|-----------------|-------------|----------------------| -| 1 | example | 0.5000 | 0.5 | 10 | example submission | -| 2 | Add Yours! 
| | | | | - - -- Gradient Clipping - - LR Scheduler -- Wirelength loss -- Make loss reflect the fact that every step in the "GAN" leads to a better final state. - - Currently, the loss doesn't reflect it much at all, but it does automatically get better with each "GAN" step. -- Optimizer (and optim hyperparams) -- "Budging it" - -- All valid permutations -> min of wirelength (except make it differentiable) - - -## Leaderboard (sorted by overlap) (OLD; test suite has been updated; see above) - -| Rank | Name | Overlap | Wirelength (um) | Runtime (s) | Notes | -|------|-----------------|-------------|-----------------|-------------|----------------------| -| 1 | Shashank Shriram | 0.0000 | 0.1310 | 11.32 | 🏎️💥 | +| 1 | Sidhartha Parhi | 0.0000 | 0.2549 | 458.28 | Multi-stage alternating optimization and initial centering regularization to find the Pareto Front. | | 2 | Brayden Rudisill | 0.0000 | 0.2611 | 50.51 | Timed on a mac air | | 3 | manuhalapeth | 0.0000 | 0.2630 | 196.8 | | | 4 | Neil Teje | 0.0000 | 0.2700 | 24.00s | | @@ -58,22 +39,23 @@ We will review submissions on a rolling basis. | 6 | William Pan | 0.0000 | 0.2848 | 155.33s | | | 7 | Ashmit Dutta | 0.0000 | 0.2870 | 995.58 | Spent my entire morning (12 am - 6 am) doing this :P | | 8 | Pawan Paleja | 0.0000 | 0.3311 | 1.74s | Implemented hint for loss func, cosine annealing on learning rate with warmup, std annealing on lambda weight. Used optuna to tune hyperparam. Tested on gh codespaces 2-core. 
| -| 9 | Gabriel Del Monte | 0.0000 | 0.3427 | 606.07 | | -| 10 | Aleksey Valouev| 0.0000 | 0.3577 | 118.98 | | -| 11 | Mohul Shukla | 0.0000 | 0.5048 | 54.60s | | -| 12 | Ryan Hulke | 0.0000 | 0.5226 | 166.24 | | -| 13 | Neel Shah | 0.0000 | 0.5445 | 45.40 | Zero overlaps on all tests, adaptive schedule + early stop | -| 14 | Nawel Asgar | 0.0000 | 0.5675 | 81.49 | Adaptive penalty scaling with cubic gradients and design-size optimization -| 15 | Shiva Baghel | 0.0000 | 0.5885 | 491.00 | Stable zero-overlap with balanced optimization | -| 16 | Vansh Jain | 0.0000 | 0.9352 | 86.36 | | -| 17 | Akash Pai | 0.0006 | 0.4933 | 326.25s | | -| 18 | Zade Mahayni | 0.00665 | 0.5157 | 127.4 | Will try again tomorrow | -| 19 | Nithin Yanna | 0.0148 | 0.5034 | 247.30s | aggressive overlap penalty with quadratic scaling | -| 20 | Sean Ko | 0.0271 | .5138 | 31.83s | lr increase, decrease epoch, increase lambda overlap and decreased lambda wire_length + log penalty loss | -| 21 | Keya Gohil | 0.0155 | 0.4678 | 1513.07 | Still working | -| 22 | Prithvi Seran | 0.0499 | 0.4890 | 398.58 | | -| 23 | partcl example | 0.8 | 0.4 | 5 | example | -| 24 | Add Yours! 
| | | | | +| 9 | Shashank Shriram | 0.0000 | 0.3312 | 11.32 | 🏎️💥 | +| 10 | Gabriel Del Monte | 0.0000 | 0.3427 | 606.07 | | +| 11 | Aleksey Valouev| 0.0000 | 0.3577 | 118.98 | | +| 12 | Mohul Shukla | 0.0000 | 0.5048 | 54.60s | | +| 13 | Ryan Hulke | 0.0000 | 0.5226 | 166.24 | | +| 14 | Neel Shah | 0.0000 | 0.5445 | 45.40 | Zero overlaps on all tests, adaptive schedule + early stop | +| 15 | Nawel Asgar | 0.0000 | 0.5675 | 81.49 | Adaptive penalty scaling with cubic gradients and design-size optimization +| 16 | Shiva Baghel | 0.0000 | 0.5885 | 491.00 | Stable zero-overlap with balanced optimization | +| 17 | Vansh Jain | 0.0000 | 0.9352 | 86.36 | | +| 18 | Akash Pai | 0.0006 | 0.4933 | 326.25s | | +| 19 | Zade Mahayni | 0.00665 | 0.5157 | 127.4 | Will try again tomorrow | +| 20 | Nithin Yanna | 0.0148 | 0.5034 | 247.30s | aggressive overlap penalty with quadratic scaling | +| 21 | Sean Ko | 0.0271 | .5138 | 31.83s | lr increase, decrease epoch, increase lambda overlap and decreased lambda wire_length + log penalty loss | +| 22 | Keya Gohil | 0.0155 | 0.4678 | 1513.07 | Still working | +| 23 | Prithvi Seran | 0.0499 | 0.4890 | 398.58 | | +| 24 | partcl example | 0.8 | 0.4 | 5 | example | +| 25 | Add Yours! | | | | | > **To add your results:** > Insert a new row in the table above with your name, overlap, wirelength, and any notes. Ensure you sort by overlap. 
diff --git a/placement.py b/placement.py index 61c881e..b50e721 100644 --- a/placement.py +++ b/placement.py @@ -44,7 +44,6 @@ import torch import torch.optim as optim - # Feature index enums for cleaner code access class CellFeatureIdx(IntEnum): """Indices for cell feature tensor columns.""" @@ -83,6 +82,26 @@ class PinFeatureIdx(IntEnum): # Output directory OUTPUT_DIR = os.path.dirname(os.path.abspath(__file__)) +# Training Hyperparameters +LR = 0.01 +PARETO_NUM_EPOCHS = 5000 +TRAINING_STAGES = { + "0": 32000, + "1": 8000, + "2": 20000, + "3": 10000, + "4": 10000, + "5": 30000, + "6": 50000, + "7": 15000 +} +LAMBDA_WIRELENGTH = 100000000.0 +LAMBDA_OVERLAP = 65000000.0 +LAMBDA_PARETO = 1000.0 # 100.0 +LAMBDA_CENTERING = 1.0 +ALPHA_MANHATTAN = 0.1 # Smoothing parameter +COS_LR_MIN_FCTR = 0.01 + # ======= SETUP ======= def generate_placement_input(num_macros, num_std_cells): @@ -262,8 +281,8 @@ def wirelength_attraction_loss(cell_features, pin_features, edge_list): Returns: Scalar loss value """ - if edge_list.shape[0] == 0: - return torch.tensor(0.0, requires_grad=True) + # if edge_list.shape[0] == 0: + # return torch.tensor(0.0, requires_grad=True) # Update absolute pin positions based on cell positions cell_positions = cell_features[:, 2:4] # [N, 2] @@ -284,21 +303,36 @@ def wirelength_attraction_loss(cell_features, pin_features, edge_list): # Calculate smooth approximation of Manhattan distance # Using log-sum-exp approximation for differentiability - alpha = 0.1 # Smoothing parameter dx = torch.abs(src_x - tgt_x) dy = torch.abs(src_y - tgt_y) # Smooth L1 distance with numerical stability - smooth_manhattan = alpha * torch.logsumexp( - torch.stack([dx / alpha, dy / alpha], dim=0), dim=0 + smooth_manhattan = ALPHA_MANHATTAN * torch.logsumexp( + torch.stack([dx / ALPHA_MANHATTAN, dy / ALPHA_MANHATTAN], dim=0), dim=0 ) # Total wirelength total_wirelength = torch.sum(smooth_manhattan) - - return total_wirelength / edge_list.shape[0] # Normalize by number of edges + 
+ # Normalize by number of edges + total_wirelength = total_wirelength / edge_list.shape[0] + + return total_wirelength + + +wirelength_attraction_loss_jit = torch.compile(wirelength_attraction_loss) + + +@torch.compile +def centering_regularization(cell_features, pin_features, edge_list): + """Calculate regularization term to center the cells on the chip.""" + x = cell_features[:, CellFeatureIdx.X] + y = cell_features[:, CellFeatureIdx.Y] + centering_reg = LAMBDA_CENTERING * (torch.sum((x**2 + y**2)**0.5) / x.shape[0]) + return centering_reg +@torch.compile def overlap_repulsion_loss(cell_features, pin_features, edge_list): """Calculate loss to prevent cell overlaps. @@ -343,94 +377,31 @@ def overlap_repulsion_loss(cell_features, pin_features, edge_list): Returns: Scalar loss value (should be 0 when no overlaps exist) """ - # def calculate_overlap_1d() - - N = cell_features.shape[0] - if N <= 1: - return torch.tensor(0.0, requires_grad=True) - - # TODO: Implement overlap detection and loss calculation here - # - # Your implementation should: - # 1. Extract cell positions, widths, and heights - # 2. Compute pairwise overlaps using vectorized operations - # 3. Return a scalar loss that is zero when no overlaps exist - # - # Delete this placeholder and add your implementation: - - # Placeholder - returns a constant loss (REPLACE THIS!) - # return torch.tensor(1.0, requires_grad=True) - - areas = cell_features[:, CellFeatureIdx.AREA] # = areas + def calculate_overlap_1d(x_or_y, w_or_h): + overlap = (w_or_h.unsqueeze(0) + w_or_h.unsqueeze(1)) / 2. 
+ overlap = overlap - torch.abs(x_or_y.unsqueeze(0) - x_or_y.unsqueeze(1)) + overlap = torch.relu(overlap) + overlap = torch.triu(overlap, diagonal=1) + return overlap + + # N = cell_features.shape[0] + # if N <= 1: + # return torch.tensor(0.0, requires_grad=True) + + # areas = cell_features[:, CellFeatureIdx.AREA] # = areas # num_pins = cell_features[:, CellFeatureIdx.NUM_PINS] = num_pins_per_cell.float() - x = cell_features[:, CellFeatureIdx.X] # = 0.0 # x position (initialized to 0) - y = cell_features[:, CellFeatureIdx.Y] # = 0.0 # y position (initialized to 0) - w = cell_features[:, CellFeatureIdx.WIDTH] # = cell_widths - h = cell_features[:, CellFeatureIdx.HEIGHT] # = cell_heights + x = cell_features[:, CellFeatureIdx.X] + y = cell_features[:, CellFeatureIdx.Y] + w = cell_features[:, CellFeatureIdx.WIDTH] + h = cell_features[:, CellFeatureIdx.HEIGHT] - # # Calculate absolute pin positions - cell_indices = pin_features[:, 0].long() - pin_absolute_x = x[cell_indices] + pin_features[:, 1] - pin_absolute_y = y[cell_indices] + pin_features[:, 2] - - # # Get source and target pin positions for each edge - # src_pins = edge_list[:, 0].long() - # tgt_pins = edge_list[:, 1].long() - - # src_x = pin_absolute_x[src_pins] - # src_y = pin_absolute_y[src_pins] - # tgt_x = pin_absolute_x[tgt_pins] - # tgt_y = pin_absolute_y[tgt_pins] - - # TODO: Optimize torch.triu - overlap_x = w.unsqueeze(0) + w.unsqueeze(1) - overlap_x /= 2. - overlap_x -= torch.abs(x.unsqueeze(0) - x.unsqueeze(1)) - overlap_x = torch.relu(overlap_x) - overlap_x = torch.triu(overlap_x, diagonal=1) - - overlap_y = h.unsqueeze(0) + h.unsqueeze(1) - overlap_y /= 2. 
- overlap_y -= torch.abs(y.unsqueeze(0) - y.unsqueeze(1)) - overlap_y = torch.relu(overlap_y) - overlap_y = torch.triu(overlap_y, diagonal=1) + overlap_x = calculate_overlap_1d(x, w) + overlap_y = calculate_overlap_1d(y, h) loss = overlap_x * overlap_y - # loss *= areas.unsqueeze(0) + areas.unsqueeze(1) - # print(loss.mean()) - # breakpoint() - - # loss += 0.001*(x**2 + y**2) - # loss += 0.001*torch.mean(x**2 + y**2) - # loss += 0.001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2) - # loss = torch.where( - # loss == 0, - # loss + 0.001*torch.mean(x**2 + y**2), - # loss - # ) - - # canvas_area = (torch.max(x+w/2.) - torch.min(x-w/2.)) * (torch.max(y+h/2.) - torch.min(y-h/2.)) - # canvas_area = canvas_area - 27*(torch.sum(w)**0.5 * torch.sum(h)**0.5) - # print(loss.min(), loss.max(), canvas_area) - # loss = canvas_area - # loss += 1.*canvas_area - # loss = torch.where( - # loss == 0, - # loss + 1.*canvas_area, - # loss - # ) - - # dist_sq = (x.unsqueeze(0) - x.unsqueeze(1))**2 + (y.unsqueeze(0) - y.unsqueeze(1))**2 - # dist_sq = dist_sq / torch.max(dist_sq) - # loss = torch.where( - # loss == 0, - # loss + .0001*dist_sq, # 0.00001* - # loss - # ) - loss *= .2 # 200000. 
- # loss = torch.triu(loss, diagonal=1) - # breakpoint() - loss = torch.mean(loss) + # loss = torch.mean(loss) + num_non_zero = (loss != 0).sum().clamp(min=1) + loss = loss.sum() / num_non_zero return loss @@ -439,10 +410,10 @@ def train_placement( cell_features, pin_features, edge_list, - num_epochs=40000, - lr=0.01, - lambda_wirelength=100000000.0, - lambda_overlap=1000000000.0, + num_epochs=None, + lr=LR, + lambda_wirelength=LAMBDA_WIRELENGTH, + lambda_overlap=LAMBDA_OVERLAP, verbose=True, log_interval=100, ): @@ -465,6 +436,72 @@ def train_placement( - initial_cell_features: Original cell positions (for comparison) - loss_history: Loss values over time """ + N = cell_features.shape[0] + + def pareto_alternating_optimization(training_stage, epoch_curr): + """Alternate objectives to find an optimal solution within the Pareto front.""" + nonlocal scheduler + if (epoch_curr // PARETO_NUM_EPOCHS) % 2 == 0: + lambda_overlap_final = LAMBDA_PARETO if training_stage in {0, 2, 6} else lambda_overlap + lambda_wirelength_final = lambda_wirelength + else: + lambda_overlap_final = lambda_overlap + lambda_wirelength_final = LAMBDA_PARETO if training_stage in {0, 2, 6} else lambda_wirelength + if scheduler is not None: + scheduler = None + return lambda_overlap_final, lambda_wirelength_final + + def overlap_optimization(training_stage): + """Prioritize overlap objective.""" + nonlocal scheduler + lambda_overlap_final = lambda_overlap + lambda_wirelength_final = 0. 
+ if scheduler is None and N >= 25 and training_stage in {5, 7}: + for pg in optimizer.param_groups: + pg['lr'] = lr + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=get_stage_epochs(N, str(training_stage), TRAINING_STAGES), eta_min=lr * COS_LR_MIN_FCTR + ) + return lambda_overlap_final, lambda_wirelength_final + + def get_stage_epochs(N, training_stage, training_stages): + num_epochs_stage = training_stages[training_stage] + training_stage = int(training_stage) + if training_stage == 1: + if N < 50: + num_epochs_stage = 1500 + elif N < 150: + num_epochs_stage = 4000 + elif N < 250: + num_epochs_stage = 6000 + elif training_stage in {3, 5, 7}: + if N < 150: + num_epochs_stage = 1000 + elif N < 250: + num_epochs_stage = 2000 + return num_epochs_stage + + def get_cumulative_epochs(N, training_stages): + epoch_stages = set() + num_epochs_tot = 0 + for training_stage in training_stages.keys(): + num_epochs_tot += get_stage_epochs(N, training_stage, training_stages) + epoch_stages.add(num_epochs_tot) + + return num_epochs_tot, epoch_stages + + # @torch.compile + # def freeze_adam_state(optimizer, params): + # for p in params: + # # p.grad.zero_() + # p.grad = None + # state = optimizer.state[p] + # if state: + # state['exp_avg'].zero_() + # state['exp_avg_sq'].zero_() + # if 'max_exp_avg_sq' in state: + # state['max_exp_avg_sq'].zero_() + # Clone features and create learnable positions cell_features = cell_features.clone() initial_cell_features = cell_features.clone() @@ -474,11 +511,26 @@ def train_placement( cell_positions.requires_grad_(True) # Create optimizer - optimizer = optim.Adam([cell_positions], lr=lr) + optimizer = optim.Adam([cell_positions], lr=lr) #, foreach=False, fused=False) + # optimizer = optim.SGD([cell_positions], lr=lr) + scheduler = None # scheduler = optim.lr_scheduler.CosineAnnealingLR( # optimizer, T_max=num_epochs, eta_min=lr*0.3 # ) + # Initialize training stage and epochs + training_stage = 0 + epoch_curr = 0 + 
epoch_stages = None + if num_epochs is None: + num_epochs, epoch_stages = get_cumulative_epochs(N, TRAINING_STAGES) + + # Initialize macro position freezing + freeze_macros = freeze_std_cells = False + areas = cell_features[:, CellFeatureIdx.AREA] + idx_macros = (areas >= MIN_MACRO_AREA) & (areas < MAX_MACRO_AREA) + idx_std_cells = torch.isin(areas, torch.tensor(STANDARD_CELL_AREAS)) + # Track loss history loss_history = { "total_loss": [], @@ -494,63 +546,57 @@ def train_placement( cell_features_current = cell_features.clone() cell_features_current[:, 2:4] = cell_positions - # Calculate losses - # if epoch < num_epochs//2: - # if epoch % 2 == 0: - # if epoch < 1.8*num_epochs: - # b = (epoch // (num_epochs//5)) % 2 == 0 - # else: - # b = .75*num_epochs < epoch < .85*num_epochs - if epoch >= 0.8*num_epochs: # 35000: - lambda_overlap_ = lambda_overlap - lambda_wirelength_ = 0. - elif (epoch // (num_epochs//8)) % 2 == 0: - lambda_overlap_ = 1000. - lambda_wirelength_ = lambda_wirelength + # Update training stage and macro position freezing + if epoch in epoch_stages: + training_stage += 1 + if training_stage == 3 or training_stage >= 5: + freeze_macros = False + if training_stage >= 6: + freeze_std_cells = True + elif training_stage == 2 or training_stage == 4: + freeze_macros = True + + epoch_curr = 0 + + # Set training hyperparams and optimization method + if training_stage % 2 == 0: + lambda_overlap_final, lambda_wirelength_final = pareto_alternating_optimization(training_stage, epoch_curr) else: - lambda_overlap_ = lambda_overlap - lambda_wirelength_ = 1000. 
- wl_loss = wirelength_attraction_loss( - cell_features_current, pin_features, edge_list - ) - x = cell_features_current[:, CellFeatureIdx.X] # = 0.0 # x position (initialized to 0) - y = cell_features_current[:, CellFeatureIdx.Y] # = 0.0 # y position (initialized to 0) - # cell_indices = pin_features[:, 0].long() - # pin_absolute_x = x[cell_indices] + pin_features[:, 1] - # pin_absolute_y = y[cell_indices] + pin_features[:, 2] - # wl_loss += .05*torch.mean(x**2 + y**2) - wl_loss += .001*torch.mean(x**2 + y**2) - # wl_loss += .001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2) - # loss = torch.where( - # loss == 0, - # loss + 0.001*torch.mean(pin_absolute_x**2 + pin_absolute_y**2), - # loss - # ) - # overlap_loss = torch.tensor(0.) - # else: - overlap_loss = overlap_repulsion_loss( - cell_features_current, pin_features, edge_list - ) - # wl_loss = torch.tensor(0.) + lambda_overlap_final, lambda_wirelength_final = overlap_optimization(training_stage) + + # Calculate losses + overlap_loss = overlap_repulsion_loss(cell_features_current, pin_features, edge_list) + wl_loss = wirelength_attraction_loss_jit(cell_features_current, pin_features, edge_list) + # Stage 0 should also regularize for centering the cells on the chip, + # as this leads to an easier manifold to optimize, and it reduces wirelength on average. + if training_stage in {0, 1}: + wl_loss += centering_regularization(cell_features_current, pin_features, edge_list) # Combined loss - total_loss = lambda_wirelength_ * wl_loss + lambda_overlap_ * overlap_loss + total_loss = lambda_wirelength_final * wl_loss + lambda_overlap_final * overlap_loss # Backward pass total_loss.backward() # Gradient clipping to prevent extreme updates - torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=5.0) + torch.nn.utils.clip_grad_norm_([cell_positions], max_norm=2.5) + + # Freeze appropriate cell positions based on training stage + if freeze_macros: + cell_positions.grad[idx_macros] = 0. 
+ elif freeze_std_cells: + cell_positions[idx_std_cells].grad = None # Update positions optimizer.step() - # scheduler.step() - # print(scheduler.get_lr()) + if scheduler is not None: + scheduler.step() # Record losses - loss_history["total_loss"].append(total_loss.item()) - loss_history["wirelength_loss"].append(wl_loss.item()) - loss_history["overlap_loss"].append(overlap_loss.item()) + if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1): + loss_history["total_loss"].append(total_loss.item()) + loss_history["wirelength_loss"].append(wl_loss.item()) + loss_history["overlap_loss"].append(overlap_loss.item()) # Log progress if verbose and (epoch % log_interval == 0 or epoch == num_epochs - 1): @@ -559,6 +605,8 @@ def train_placement( print(f" Wirelength Loss: {wl_loss.item():.6f}") print(f" Overlap Loss: {overlap_loss.item():.6f}") + epoch_curr += 1 + # Create final cell features final_cell_features = cell_features.clone() final_cell_features[:, 2:4] = cell_positions.detach()