diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 5af78c0dd..a7567d0ad 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -35,6 +35,7 @@ def forward( py: torch.Tensor, pxy_grads: List[Optional[torch.Tensor]], boundary: Optional[torch.Tensor] = None, + fast_emit_scale: float = 0.0, return_grad: bool = False, ) -> torch.Tensor: """ @@ -109,6 +110,10 @@ def forward( all sequences are of the same length. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). + return_grad: Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a @@ -163,6 +168,7 @@ def forward( ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _k2.mutual_information_backward( px, py, boundary, p, ans_grad) + px_grad *= (1 + fast_emit_scale) ctx.save_for_backward(px_grad, py_grad) assert len(pxy_grads) == 2 pxy_grads[0] = px_grad @@ -173,19 +179,20 @@ def forward( @staticmethod def backward( ctx, ans_grad: Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None, None]: (px_grad, py_grad) = ctx.saved_tensors (B,) = ans_grad.shape ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1) px_grad *= ans_grad py_grad *= ans_grad - return (px_grad, py_grad, None, None, None) + return (px_grad, py_grad, None, None, None, None) def mutual_information_recursion( px: Tensor, py: Tensor, boundary: Optional[Tensor] = None, + fast_emit_scale: float = 0.0, return_grad: bool = False, ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: """A recursion that is useful in computing mutual information between two @@ -248,6 +255,10 @@ def mutual_information_recursion( ``y`` sequences respectively, and can be used if not all sequences are of the same length. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). + return_grad: Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a @@ -292,8 +303,14 @@ def mutual_information_recursion( px, py = px.contiguous(), py.contiguous() pxy_grads = [None, None] - scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads, - boundary, return_grad) + scores = MutualInformationRecursionFunction.apply( + px, + py, + pxy_grads, + boundary, + fast_emit_scale, + return_grad + ) px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 1917ccbfe..960edaa53 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -202,6 +202,7 @@ def rnnt_loss_simple( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: @@ -228,6 +229,9 @@ def rnnt_loss_simple( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -258,8 +262,13 @@ def rnnt_loss_simple( boundary=boundary, modified=modified, ) + scores_and_grads = mutual_information_recursion( - px=px, py=py, boundary=boundary, return_grad=return_grad + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale, + return_grad=return_grad ) negated_loss = scores_and_grads[0] if return_grad else scores_and_grads if reduction == "none": @@ -376,6 +385,7 @@ def rnnt_loss( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input, @@ -397,6 +407,9 @@ def rnnt_loss( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -416,7 +429,13 @@ def rnnt_loss( boundary=boundary, modified=modified, ) - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + + negated_loss = mutual_information_recursion( + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale + ) if reduction == "none": return -negated_loss elif reduction == "mean": @@ -957,6 +976,7 @@ def rnnt_loss_pruned( termination_symbol: int, boundary: Tensor = None, modified: bool = False, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output @@ -987,6 +1007,9 @@ def rnnt_loss_pruned( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -1006,7 +1029,13 @@ def rnnt_loss_pruned( boundary=boundary, modified=modified, ) - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + + negated_loss = mutual_information_recursion( + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale + ) if reduction == "none": return -negated_loss elif reduction == "mean": @@ -1250,6 +1279,7 @@ def rnnt_loss_smoothed( am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, modified: bool = False, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: @@ -1283,6 +1313,9 @@ def rnnt_loss_smoothed( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -1315,8 +1348,13 @@ def rnnt_loss_smoothed( boundary=boundary, modified=modified, ) + scores_and_grads = mutual_information_recursion( - px=px, py=py, boundary=boundary, return_grad=return_grad + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale, + return_grad=return_grad ) negated_loss = scores_and_grads[0] if return_grad else scores_and_grads if reduction == "none":