Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ function vision_transformer() {

# Smoke-test the word_language_model example: one dry-run epoch with the
# default settings, then one dry-run epoch for each listed --model value.
# Failures are reported through the script's error helper rather than
# aborting, so the remaining model types still get exercised.
function word_language_model() {
  # Keep the loop variable out of the script's global scope.
  local model

  # $CUDA_FLAG is deliberately left unquoted: presumably it may be empty, in
  # which case it must expand to no argument at all rather than to "".
  uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"

  for model in "RNN_TANH" "RNN_RELU" "LSTM" "GRU" "Transformer"; do
    # Name the failing model type so a failure can be traced to one run.
    uv run main.py --model "$model" --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model ($model) failed"
  done
}

function gcn() {
Expand Down
34 changes: 30 additions & 4 deletions word_language_model/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.onnx

import data
import model
from model import PositionalEncoding, RNNModel, TransformerModel

parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM/GRU/Transformer Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2',
Expand Down Expand Up @@ -108,9 +108,9 @@ def batchify(data, bsz):

ntokens = len(corpus.dictionary)
if args.model == 'Transformer':
model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
else:
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

criterion = nn.NLLLoss()

Expand Down Expand Up @@ -243,7 +243,33 @@ def export_onnx(path, batch_size, seq_len):

# Load the best saved model.
with open(args.save, 'rb') as f:
model = torch.load(f)
if args.model == 'Transformer':
Comment thread
dvrogozh marked this conversation as resolved.
safe_globals = [
PositionalEncoding,
TransformerModel,
torch.nn.functional.relu,
torch.nn.modules.activation.MultiheadAttention,
torch.nn.modules.container.ModuleList,
torch.nn.modules.dropout.Dropout,
torch.nn.modules.linear.Linear,
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
torch.nn.modules.normalization.LayerNorm,
torch.nn.modules.sparse.Embedding,
torch.nn.modules.transformer.TransformerEncoder,
torch.nn.modules.transformer.TransformerEncoderLayer,
]
else:
safe_globals = [
RNNModel,
torch.nn.modules.dropout.Dropout,
torch.nn.modules.linear.Linear,
torch.nn.modules.rnn.GRU,
torch.nn.modules.rnn.LSTM,
torch.nn.modules.rnn.RNN,
torch.nn.modules.sparse.Embedding,
]
with torch.serialization.safe_globals(safe_globals):
model = torch.load(f)
# After loading, the RNN parameters are no longer a contiguous chunk of memory.
# flatten_parameters() makes them contiguous again, which speeds up the forward pass.
# Currently, only the RNN model supports the flatten_parameters function.
Expand Down
2 changes: 1 addition & 1 deletion word_language_model/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
torch<2.6
torch>=2.6
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`torch.serialization.safe_globals`, which I've used in this PR, first appeared in torch 2.5. The default value of the `weights_only` argument of `torch.load` was changed to `True` in 2.6. Since we plan to bump the dependency to 2.6 anyway in order to use `torch.accelerator`, I think we can update the requirement to 2.6 right away.