Skip to content

Conversation

@lericson
Copy link

@lericson lericson commented Mar 6, 2023

Use existing primitives such as nn.GELU
Remove no-op modules
Update super()
Removed trailing whitespaces in code

Use existing primitives such as `nn.GELU`
Remove no-op modules
Update `super()`
Removed trailing whitespaces in code
@lericson
Copy link
Author

lericson commented Mar 6, 2023

I should note that this fixes a minor discrepancy in the code compared to the JAX reference, their code says (with some trivial rewriting):

x = self.encoder(x)

x = x[:, 0]

if repr_dim is not None:
  x = nn.Dense(repr_dim)(x)
  x = nn.tanh(x)

... whereas this is what this repository says:

        x = self.encoder(x)
        x = self.pre_logits(x)

        # only support cls token now
        x = x[:, 0]

        return self.head(x)

It doesn't actually matter, but the placement of the x = x[:, 0] is earlier in the reference code. This PR does the same. It just saves some cycles I guess.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant