-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder.py
More file actions
48 lines (44 loc) · 1.93 KB
/
decoder.py
File metadata and controls
48 lines (44 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch.nn as nn
import torch
class GreedySearchDecoder(nn.Module):
    """Greedy decoding wrapper around a seq2seq encoder/decoder pair.

    Runs the encoder once over the input sequence, then decodes one token
    per step, always taking the single highest-scoring vocabulary entry
    and feeding it back in as the next decoder input.

    :param encoder: encoder module, called as
        ``encoder(input_seq, input_length)`` and expected to return
        ``(encoder_outputs, encoder_hidden)``.
    :param decoder: decoder module, called as
        ``decoder(decoder_input, decoder_hidden, encoder_outputs)`` and
        expected to return ``(decoder_output, decoder_hidden)``; must
        expose an ``n_layers`` attribute.
    :param device: torch device on which new tensors are allocated.
    :param SOS_token: vocabulary index of the start-of-sequence token
        used to seed decoding.
    """

    def __init__(self, encoder, decoder, device, SOS_token):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.SOS_token = SOS_token

    def forward(self, input_seq, input_length, max_length):
        """Greedily decode up to ``max_length`` tokens for one input.

        :param input_seq: input token tensor passed straight to the
            encoder (assumes batch size 1 — TODO confirm with encoder).
        :param input_length: sequence-length tensor for the encoder.
        :param max_length: number of decoding steps; decoding always runs
            exactly this many steps (there is no EOS early-stop here).
        :return: tuple ``(all_tokens, all_scores)`` — 1-D tensors holding,
            per step, the chosen token index and its score.
        """
        # Forward input through encoder model (runs once, up front).
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to
        # the decoder — keep only the first n_layers of the hidden state.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with a single SOS_token, shape (1, 1).
        decoder_input = torch.full(
            (1, 1), self.SOS_token, device=self.device, dtype=torch.long)
        # Empty accumulators; grown one step at a time via torch.cat below.
        all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
        all_scores = torch.zeros([0], device=self.device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            # Greedy choice: most likely token and its score
            # (presumably a softmax probability — depends on the decoder).
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Re-add the batch dimension so the token can be fed back in
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores