-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargs.py
More file actions
40 lines (36 loc) · 3.52 KB
/
Copy pathargs.py
File metadata and controls
40 lines (36 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse the given arguments.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, which is what the previous
            zero-argument version of this function always did.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined
        below (attribute name == option name without the leading dashes).

    Raises:
        SystemExit: raised by argparse on unknown options, values that fail
            type conversion, or values outside a declared ``choices`` list.
    """
    parser = argparse.ArgumentParser()

    # General
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument(
        "--num_classes", type=int, default=32, help="number of classes"
    )

    # GNN
    parser.add_argument(
        "--gnn_type",
        type=str,
        choices=["gin", "gatv2"],
        default="gatv2",
        help="GNN type",
    )
    parser.add_argument(
        "--gnn_hidden_dim",
        type=int,
        default=256,
        help="hidden dimension of GNN",
    )
    parser.add_argument(
        "--gnn_output_dim",
        type=int,
        default=256,
        help="output dimension of GNN",
    )
    parser.add_argument(
        "--num_gnn_layers", type=int, default=8, help="number of GNN layers"
    )
    parser.add_argument(
        "--gnn_num_heads",
        type=int,
        default=4,
        help="number of attention heads for each GNN layer",
    )
    # Values noted by the original author: "last", "cat", "max", "lstm".
    # They are NOT enforced through `choices`; any string is accepted here.
    parser.add_argument(
        "--jk_mode", type=str, default="last", help="Jumping Knowledge mode"
    )
    parser.add_argument(
        "--gnn_dropout", type=float, default=0.2, help="dropout rate for GNN"
    )

    # Transformer
    parser.add_argument(
        "--transformer_num_heads",
        type=int,
        default=8,
        help="number of attention heads for each transformer layer",
    )
    parser.add_argument(
        "--transformer_hidden_dim",
        type=int,
        default=512,
        help="hidden dimension of transformer",
    )
    parser.add_argument(
        "--transformer_feedforward_dim",
        type=int,
        default=2048,
        help="feedforward dimension of transformer",
    )
    parser.add_argument(
        "--num_transformer_layers",
        type=int,
        default=8,
        help="number of transformer layers",
    )
    parser.add_argument(
        "--transformer_dropout",
        type=float,
        default=0.2,
        help="dropout rate for transformer",
    )

    # Weight initialization
    parser.add_argument(
        "--model_init_method",
        type=str,
        choices=["xavier", "default"],
        default="xavier",
        help="Weight initialization method: xavier or default",
    )

    # Batching and training length
    parser.add_argument(
        "--batch_size", type=int, default=16, help="batch size"
    )
    parser.add_argument(
        "--size_aware_batching",
        action="store_true",
        help="use size-aware batching",
    )
    parser.add_argument(
        "--max_nodes_per_batch",
        type=int,
        default=10000,
        help="max nodes per batch",
    )
    parser.add_argument(
        "--num_epochs", type=int, default=2000, help="number of epochs"
    )

    # Positional encoding
    parser.add_argument(
        "--k",
        type=int,
        default=16,
        help="number of eigenvectors used for Laplacian positional encoding",
    )

    # Loss
    parser.add_argument(
        "--loss_type",
        type=str,
        choices=["ce", "cb_focal"],
        default="cb_focal",
        help="loss function to use: ce (CrossEntropy), cb_focal (ClassBalancedFocalLoss)",
    )
    parser.add_argument(
        "--use_class_weights",
        action="store_true",
        help="use class weights for training",
    )
    parser.add_argument(
        "--label_smoothing",
        type=float,
        default=0.0,
        help="label smoothing for training",
    )

    # Masking and input-feature switches
    parser.add_argument(
        "--dynamic_mask_ratio",
        action="store_true",
        help="use dynamic mask ratio",
    )
    parser.add_argument(
        "--delete_node_width_embedding",
        action="store_true",
        help="delete node width embedding",
    )
    parser.add_argument(
        "--random_flip_posenc",
        action="store_true",
        help="randomly flip the positional encoding",
    )
    parser.add_argument(
        "--mask_ratio", type=float, default=0.2, help="ratio of nodes to mask"
    )
    parser.add_argument(
        "--stratified_masking",
        action="store_true",
        help="use stratified masking to ensure class coverage",
    )

    # Optimizer
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="learning rate"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-4, help="weight decay"
    )

    # Transfer learning
    parser.add_argument(
        "--transfer_learning",
        action="store_true",
        help="use transfer learning",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="",
        help="path to the model to transfer learning from",
    )

    # argv=None makes argparse read sys.argv[1:], preserving the old behavior.
    return parser.parse_args(argv)