Issue with resuming training #31

@cjchristopher

Hi team - a question about the following line of code:

mask_prob = self.args.mask_prob + (.8 - self.args.mask_prob) * (epoch - 1) / 20

Can I ask what purpose this serves, and why? When resuming training from a checkpoint whose epoch count exceeds 28, this value is greater than 1 and causes assertion failures. I note that when training starts from epoch 0 with no interruption, mask_prob remains fixed at its initialized value (0.2), but after any interruption training resumes with a higher mask_prob.
Should this just be mask_prob = self.args.mask_prob?
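
To make the arithmetic concrete, here is a minimal standalone sketch of the formula above, assuming mask_prob is initialized to 0.2. The function names are hypothetical and not taken from the repository, and the clamped variant is only one possible workaround, not a confirmed fix.

```python
def scheduled_mask_prob(base_prob: float, epoch: int,
                        target: float = 0.8, ramp_epochs: int = 20) -> float:
    """Same expression as the line quoted above, with self.args.mask_prob as base_prob."""
    return base_prob + (target - base_prob) * (epoch - 1) / ramp_epochs


def clamped_mask_prob(base_prob: float, epoch: int,
                      target: float = 0.8, ramp_epochs: int = 20) -> float:
    """Hypothetical variant: ramp from base_prob to target, then hold at target."""
    # Limit the ramp progress to [0, 1] so the result never leaves
    # the interval between base_prob and target.
    progress = min(max((epoch - 1) / ramp_epochs, 0.0), 1.0)
    return base_prob + (target - base_prob) * progress


# Compare both versions at a few epochs (0.2 is the assumed default mask_prob).
for epoch in (1, 21, 27, 28):
    raw = scheduled_mask_prob(0.2, epoch)
    safe = clamped_mask_prob(0.2, epoch)
    print(f"epoch {epoch}: unclamped={raw:.2f} clamped={safe:.2f}")
```

With a base of 0.2 the unclamped value grows by 0.03 per epoch, so it passes 0.8 after epoch 21 and is about 1.01 at epoch 28, which is no longer a valid probability. Whether the ramp to 0.8 is intended at all is the open question here; if it is, clamping as sketched would keep resumed runs within range.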
