Sampling, batch_all, non-zero, optimizer-flag #33
```python
if args.loss_ignore_zero is True:
    nnz = tf.count_nonzero(losses, dtype=tf.float32)
else:
    nnz = tf.reduce_sum(tf.to_float(tf.greater(losses, args.loss_ignore_zero or 1e-5)))
```
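The quoted counting logic can be sketched standalone with NumPy (the TF 1.x ops `tf.count_nonzero`, `tf.to_float`, and `tf.greater` map directly onto NumPy equivalents; the sample losses and the helper name are illustrative, not from the repo):

```python
import numpy as np

def count_nonzero_losses(losses, loss_ignore_zero):
    """Mirror the quoted TF 1.x logic: loss_ignore_zero is either True
    (exact non-zero count) or a positive float threshold below which a
    loss is treated as zero."""
    losses = np.asarray(losses, dtype=np.float32)
    if loss_ignore_zero is True:
        # Exact count of strictly non-zero entries.
        return float(np.count_nonzero(losses))
    # Count entries strictly above the threshold (falling back to 1e-5).
    threshold = loss_ignore_zero or 1e-5
    return float(np.sum(losses > threshold))

losses = [0.0, 2e-4, 0.5, 1e-6]
print(count_nonzero_losses(losses, True))   # 3.0 (every non-zero entry)
print(count_nonzero_losses(losses, 1e-3))   # 1.0 (only 0.5 exceeds 1e-3)
```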
Is the point of this supposed to be just for logging?
No, actually. It's type-magic and can be a little obscure, which is why I still need to write documentation in the README :)
The else case happens when loss_ignore_zero is given an additional float argument, so one can call it as `--loss_ignore_zero 1e-3`, for example, in order to count anything below 1e-3 as zero.
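The optional-value behaviour described here is argparse's `nargs='?'` pattern. A minimal standalone sketch (the `positive_float` validator is a stand-in for the repo's `common.positive_float`, whose exact implementation I'm assuming):

```python
import argparse

def positive_float(value):
    # Stand-in for common.positive_float: parse a float and reject non-positives.
    fvalue = float(value)
    if fvalue <= 0:
        raise argparse.ArgumentTypeError('expected a positive float, got %r' % value)
    return fvalue

parser = argparse.ArgumentParser()
parser.add_argument(
    '--loss_ignore_zero', default=False, const=True, nargs='?', type=positive_float)

# Flag absent: attribute stays at default=False.
print(parser.parse_args([]).loss_ignore_zero)                              # False
# Bare flag: nargs='?' with no value falls back to const=True.
print(parser.parse_args(['--loss_ignore_zero']).loss_ignore_zero)          # True
# Flag with a value: parsed through the positive_float validator.
print(parser.parse_args(['--loss_ignore_zero', '1e-3']).loss_ignore_zero)  # 0.001
```

So a single flag cleanly covers three modes: off, exact non-zero counting, and thresholded counting.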
Read our paper; we explain them in there :) But really, it's not a good time investment to play with that parameter.
I am going to delete that comment because it makes no sense, sorry. 😆 I currently have it printed and highlighted in front of me, trying to get to grips with it!
```python
    help='Enable the super-mega-advanced top-secret sampling stabilizer.')

parser.add_argument(
    '--loss_ignore_zero', default=False, const=True, nargs='?', type=common.positive_float,
```
Ohh, I misread this to mean it could only be boolean. I am going to start playing with this then. 🍾
So this is a branch I have had lying around for a while which adds a lot of things. I'm not merging it yet, as I still want to add these changes to the README and test them thoroughly, but here it is for others to try out.
It implements the trick I mentioned in #4, as well as some more variants we cover in the paper. If you give this a try, please leave feedback here.