Missing Output Projection Layer in SAMFormer Pytorch Implementation #20

@dnsch

Hi,

It seems the SAMFormer architecture in samformer_pytorch/samformer/samformer.py differs from the one described in the paper.

In the paper, the architecture is described as follows:

$$f(X) = \big[ X + A(X)XW_V W_O \big] W$$

with

  • $W \in \mathbb{R}^{L \times H}$
  • $W_Q \in \mathbb{R}^{L \times d_m}$
  • $W_K \in \mathbb{R}^{L \times d_m}$
  • $W_V \in \mathbb{R}^{L \times d_m}$
  • $W_O \in \mathbb{R}^{d_m \times L}$

and

  • $X \text{ the input variables}$
  • $L \text{ the input sequence length}$
  • $H \text{ the prediction horizon}$
  • $d_m \text{ the hidden dimension}$
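As a quick shape check of the formula above, assume $X \in \mathbb{R}^{D \times L}$ with $D$ the number of channels, and $A(X) \in \mathbb{R}^{D \times D}$ the channel-wise attention matrix ($D$ and these two shapes are my assumptions; they are not stated above). Under that assumption the attention branch maps back to the shape of $X$, so the residual sum is well defined:

$$\underbrace{A(X)}_{D \times D}\,\underbrace{X}_{D \times L}\,\underbrace{W_V}_{L \times d_m}\,\underbrace{W_O}_{d_m \times L} \in \mathbb{R}^{D \times L}$$

$$f(X) = \big[ X + A(X) X W_V W_O \big] W \in \mathbb{R}^{D \times H}$$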

However, the PyTorch implementation in samformer.py is written like this:

self.compute_keys = nn.Linear(seq_len, hid_dim)
self.compute_queries = nn.Linear(seq_len, hid_dim)
self.compute_values = nn.Linear(seq_len, seq_len)
self.linear_forecaster = nn.Linear(seq_len, pred_horizon)

and the forward pass is implemented in a way that corresponds to:

$$f(X) = \big[ X + A(X)XW_V \big] W$$

with

  • $W \in \mathbb{R}^{L \times H}$
  • $W_Q \in \mathbb{R}^{L \times d_m}$
  • $W_K \in \mathbb{R}^{L \times d_m}$
  • $W_V \in \mathbb{R}^{L \times L}$

so $W_O$ is missing and $W_V$ is a square $L \times L$ matrix instead of $L \times d_m$.

I therefore assume that this is a bug, which could be fixed by adding the output layer and changing the shape of the values matrix:

self.compute_keys = nn.Linear(seq_len, hid_dim)
self.compute_queries = nn.Linear(seq_len, hid_dim)
self.compute_values = nn.Linear(seq_len, hid_dim)  # W_V: was (seq_len, seq_len)
self.output_layer = nn.Linear(hid_dim, seq_len)  # W_O: new output projection
self.linear_forecaster = nn.Linear(seq_len, pred_horizon)

and later adding the output layer computation in the forward pass:

att_score = self.output_layer(att_score)

after using scaled_dot_product_attention.
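To make the proposal concrete, here is a minimal, self-contained sketch of what the fixed module could look like. It is not the repository's actual class: `SAMFormerSketch` and `_attention` are hypothetical names, the attention is a plain softmax stand-in for the repository's scaled_dot_product_attention, and everything else in the real model (normalization, dropout and so on) is left out. Note that `nn.Linear` adds biases, which the formula from the paper does not have.

```python
import math

import torch
import torch.nn as nn


def _attention(queries, keys, values):
    # Plain scaled dot-product attention (stand-in for the repo's function).
    scores = queries @ keys.transpose(-2, -1) / math.sqrt(queries.shape[-1])
    return torch.softmax(scores, dim=-1) @ values


class SAMFormerSketch(nn.Module):
    """Hypothetical minimal module illustrating the proposed fix."""

    def __init__(self, seq_len, hid_dim, pred_horizon):
        super().__init__()
        self.compute_keys = nn.Linear(seq_len, hid_dim)       # W_K: L x d_m
        self.compute_queries = nn.Linear(seq_len, hid_dim)    # W_Q: L x d_m
        self.compute_values = nn.Linear(seq_len, hid_dim)     # W_V: L x d_m
        self.output_layer = nn.Linear(hid_dim, seq_len)       # W_O: d_m x L
        self.linear_forecaster = nn.Linear(seq_len, pred_horizon)  # W: L x H

    def forward(self, x):
        # x is assumed to have shape (batch, channels, seq_len).
        queries = self.compute_queries(x)
        keys = self.compute_keys(x)
        values = self.compute_values(x)
        att_score = _attention(queries, keys, values)  # (batch, channels, hid_dim)
        att_score = self.output_layer(att_score)       # back to (batch, channels, seq_len)
        out = x + att_score                             # residual connection
        return self.linear_forecaster(out)              # (batch, channels, pred_horizon)
```

With `x = torch.randn(2, 3, 16)` and `SAMFormerSketch(seq_len=16, hid_dim=4, pred_horizon=8)`, the output has shape `(2, 3, 8)`. Without the output layer, the residual sum `x + att_score` would only be shape-compatible if the values projection mapped back to `seq_len`, which is presumably why the current code uses `nn.Linear(seq_len, seq_len)`.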

With this additional output layer, the number of parameters of the PyTorch model matches the number of parameters in the TensorFlow implementation.
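For anyone who wants to reproduce the count by hand, the following sketch computes the parameters of just the linear layers listed above, before and after the change. The helper names are hypothetical, biases are counted because `nn.Linear` uses them by default, and any other parameters of the real model are ignored, so this only shows the difference between the two variants and does not by itself confirm the match with TensorFlow.

```python
def linear_params(in_features, out_features, bias=True):
    # Weight matrix plus optional bias vector, as in nn.Linear.
    return in_features * out_features + (out_features if bias else 0)


def current_param_count(seq_len, hid_dim, pred_horizon):
    # Layers as currently written in samformer.py (square values matrix, no W_O).
    return (
        linear_params(seq_len, hid_dim)          # compute_keys
        + linear_params(seq_len, hid_dim)        # compute_queries
        + linear_params(seq_len, seq_len)        # compute_values
        + linear_params(seq_len, pred_horizon)   # linear_forecaster
    )


def proposed_param_count(seq_len, hid_dim, pred_horizon):
    # Layers after the proposed fix (values project to hid_dim, plus W_O).
    return (
        linear_params(seq_len, hid_dim)          # compute_keys
        + linear_params(seq_len, hid_dim)        # compute_queries
        + linear_params(seq_len, hid_dim)        # compute_values
        + linear_params(hid_dim, seq_len)        # output_layer
        + linear_params(seq_len, pred_horizon)   # linear_forecaster
    )


if __name__ == "__main__":
    # Illustrative sizes only; substitute the configuration you actually use.
    L, d_m, H = 512, 16, 96
    print("current: ", current_param_count(L, d_m, H))
    print("proposed:", proposed_param_count(L, d_m, H))
```

The two counts differ by $2 L d_m + d_m - L^2$, so for $d_m \ll L$ the current implementation has considerably more parameters than the formulation in the paper.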
