From 2ba51dfc62bf45ed8d94b85eecf4a63210674e05 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 19 May 2022 11:36:33 +0800 Subject: [PATCH 1/5] add delay_penalty in rnnt loss --- k2/python/k2/rnnt_loss.py | 96 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 1 deletion(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 64e51a4c2..7b6538a77 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -200,6 +200,7 @@ def rnnt_loss_simple( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, + delay_penalty: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: @@ -226,6 +227,10 @@ def rnnt_loss_simple( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + delay_penalty: A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time-masking + encouraging the network to delay symbols. + See https://github.com/k2-fsa/k2/issues/955 for more details. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -255,6 +260,24 @@ def rnnt_loss_simple( boundary=boundary, modified=modified, ) + + if delay_penalty > 0.0: + B, S, T0 = px.shape + T = T0 if modified else T0 - 1 + if boundary is None: + offset = torch.tensor( + [(T - 1) / 2 * delay_penalty] * B, + dtype=px.dtype, + device=px.device, + ) + else: + offset = (boundary[:, 3] - 1) / 2 * delay_penalty + penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) + penalty = penalty.to(px.dtype) + px += penalty + scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad ) @@ -372,6 +395,7 @@ def rnnt_loss( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, + delay_penalty: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input, @@ -393,6 +417,10 @@ def rnnt_loss( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + delay_penalty: A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time-masking + encouraging the network to delay symbols. + See https://github.com/k2-fsa/k2/issues/955 for more details. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -412,6 +440,24 @@ def rnnt_loss( boundary=boundary, modified=modified, ) + + if delay_penalty > 0.0: + B, S, T0 = px.shape + T = T0 if modified else T0 - 1 + if boundary is None: + offset = torch.tensor( + [(T - 1) / 2 * delay_penalty] * B, + dtype=px.dtype, + device=px.device, + ) + else: + offset = (boundary[:, 3] - 1) / 2 * delay_penalty + penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) + penalty = penalty.to(px.dtype) + px += penalty + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": return -negated_loss @@ -622,7 +668,9 @@ def do_rnnt_pruning( lm_pruned = torch.gather( lm.unsqueeze(1).expand((B, T, S + 1, decoder_dim)), dim=2, - index=ranges.reshape((B, T, s_range, 1)).expand((B, T, s_range, decoder_dim)), + index=ranges.reshape((B, T, s_range, 1)).expand( + (B, T, s_range, decoder_dim) + ), ) return am_pruned, lm_pruned @@ -803,6 +851,7 @@ def rnnt_loss_pruned( termination_symbol: int, boundary: Tensor = None, modified: bool = False, + delay_penalty: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output @@ -833,6 +882,10 @@ def rnnt_loss_pruned( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + delay_penalty: A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time-masking + encouraging the network to delay symbols. + See https://github.com/k2-fsa/k2/issues/955 for more details. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -852,6 +905,24 @@ def rnnt_loss_pruned( boundary=boundary, modified=modified, ) + + if delay_penalty > 0.0: + B, S, T0 = px.shape + T = T0 if modified else T0 - 1 + if boundary is None: + offset = torch.tensor( + [(T - 1) / 2 * delay_penalty] * B, + dtype=px.dtype, + device=px.device, + ) + else: + offset = (boundary[:, 3] - 1) / 2 * delay_penalty + penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) + penalty = penalty.to(px.dtype) + px += penalty + negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": return -negated_loss @@ -1094,6 +1165,7 @@ def rnnt_loss_smoothed( am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, modified: bool = False, + delay_penalty: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: @@ -1127,6 +1199,10 @@ def rnnt_loss_smoothed( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. + delay_penalty: A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time-masking + encouraging the network to delay symbols. + See https://github.com/k2-fsa/k2/issues/955 for more details. reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -1159,6 +1235,24 @@ def rnnt_loss_smoothed( boundary=boundary, modified=modified, ) + + if delay_penalty > 0.0: + B, S, T0 = px.shape + T = T0 if modified else T0 - 1 + if boundary is None: + offset = torch.tensor( + [(T - 1) / 2 * delay_penalty] * B, + dtype=px.dtype, + device=px.device, + ) + else: + offset = (boundary[:, 3] - 1) / 2 * delay_penalty + penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) + penalty = penalty.to(px.dtype) + px += penalty + scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad ) From 5d6ed8b6adf39ab00c5e9f5c3eab83d4f984858d Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 19 May 2022 16:02:29 +0800 Subject: [PATCH 2/5] Fix comments --- k2/python/k2/rnnt_loss.py | 48 +++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 7b6538a77..cc051633c 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -266,17 +266,17 @@ def rnnt_loss_simple( T = T0 if modified else T0 - 1 if boundary is None: offset = torch.tensor( - [(T - 1) / 2 * delay_penalty] * B, + (T - 1) / 2, dtype=px.dtype, device=px.device, - ) + ).expand(B, 1, 1) else: - offset = (boundary[:, 3] - 1) / 2 * delay_penalty - penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + offset = (boundary[:, 3] - 1) / 2 + penalty = offset.reshape(B, 1, 1) - torch.arange( T0, device=px.device ).reshape(1, 1, T0) - penalty = penalty.to(px.dtype) - px += penalty + penalty = penalty * delay_penalty + px += penalty.to(px.dtype) scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad @@ -446,17 +446,17 @@ def rnnt_loss( T = T0 if modified else T0 - 1 if boundary is None: offset = torch.tensor( - [(T - 1) / 2 * delay_penalty] * B, + (T - 1) / 2, dtype=px.dtype, device=px.device, - ) + ).expand(B, 1, 1) else: - offset = (boundary[:, 3] - 1) / 2 * delay_penalty - penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + offset = (boundary[:, 3] - 1) / 2 + penalty = offset.reshape(B, 1, 1) - torch.arange( T0, device=px.device ).reshape(1, 1, T0) - penalty = penalty.to(px.dtype) - px += penalty + penalty = penalty * delay_penalty + px += penalty.to(px.dtype) negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": @@ -911,17 +911,17 @@ def rnnt_loss_pruned( T = T0 if modified else T0 - 1 if boundary is None: offset = torch.tensor( - [(T - 1) / 2 * delay_penalty] * B, + (T - 1) / 2, dtype=px.dtype, device=px.device, - ) + ).expand(B, 1, 1) else: - offset = (boundary[:, 3] - 1) / 2 * delay_penalty - penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + offset = (boundary[:, 3] - 1) / 2 + penalty = offset.reshape(B, 1, 1) - torch.arange( T0, device=px.device ).reshape(1, 1, T0) - penalty = penalty.to(px.dtype) - px += penalty + penalty = penalty * delay_penalty + px += penalty.to(px.dtype) negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) if reduction == "none": @@ -1241,17 +1241,17 @@ def rnnt_loss_smoothed( T = T0 if modified else T0 - 1 if boundary is None: offset = torch.tensor( - [(T - 1) / 2 * delay_penalty] * B, + (T - 1) / 2, dtype=px.dtype, device=px.device, - ) + ).expand(B, 1, 1) else: - offset = (boundary[:, 3] - 1) / 2 * delay_penalty - penalty = offset.reshape(B, 1, 1) - delay_penalty * torch.arange( + offset = (boundary[:, 3] - 1) / 2 + penalty = offset.reshape(B, 1, 1) - torch.arange( T0, device=px.device ).reshape(1, 1, T0) - penalty = penalty.to(px.dtype) - px += penalty + penalty = penalty * delay_penalty + px += penalty.to(px.dtype) scores_and_grads = mutual_information_recursion( px=px, py=py, boundary=boundary, return_grad=return_grad From 5231e2bfd8eb97b6eeafcbae3eff223f940bb07e Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 4 Oct 2022 11:36:08 +0800 Subject: [PATCH 3/5] add fast_emit to rnnt loss --- k2/python/k2/mutual_information.py | 25 +++++- k2/python/k2/rnnt_loss.py | 130 +++++++++-------------------- 2 files changed, 59 insertions(+), 96 deletions(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index 641ea32e9..583f27cf0 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -35,6 +35,7 @@ def forward( py: torch.Tensor, pxy_grads: List[Optional[torch.Tensor]], boundary: Optional[torch.Tensor] = None, + fast_emit_scale: float = 0.0, return_grad: bool = False, ) -> torch.Tensor: """ @@ -109,6 +110,10 @@ def forward( all sequences are of the same length. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). + return_grad: Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a @@ -163,6 +168,7 @@ def forward( ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) (px_grad, py_grad) = _k2.mutual_information_backward( px, py, boundary, p, ans_grad) + px_grad *= (1 + fast_emit_scale) ctx.save_for_backward(px_grad, py_grad) assert len(pxy_grads) == 2 pxy_grads[0] = px_grad @@ -173,19 +179,20 @@ def forward( @staticmethod def backward( ctx, ans_grad: Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: + ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None, None]: (px_grad, py_grad) = ctx.saved_tensors (B,) = ans_grad.shape ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1) px_grad *= ans_grad py_grad *= ans_grad - return (px_grad, py_grad, None, None, None) + return (px_grad, py_grad, None, None, None, None) def mutual_information_recursion( px: Tensor, py: Tensor, boundary: Optional[Tensor] = None, + fast_emit_scale: float = 0.0, return_grad: bool = False, ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: """A recursion that is useful in computing mutual information between two @@ -248,6 +255,10 @@ def mutual_information_recursion( ``y`` sequences respectively, and can be used if not all sequences are of the same length. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). + return_grad: Whether to return grads of ``px`` and ``py``, this grad standing for the occupation probability is the output of the backward with a @@ -291,8 +302,14 @@ def mutual_information_recursion( assert px.is_contiguous() assert py.is_contiguous() pxy_grads = [None, None] - scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads, - boundary, return_grad) + scores = MutualInformationRecursionFunction.apply( + px=px, + py=py, + pxy_grads=pxy_grads, + boundary=boundary, + fast_emit_scale=fast_emit_scale, + return_grad=return_grad + ) px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index cc051633c..55358e86e 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -200,7 +200,7 @@ def rnnt_loss_simple( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, - delay_penalty: float = 0.0, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]: @@ -227,10 +227,9 @@ def rnnt_loss_simple( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. - delay_penalty: A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time-masking - encouraging the network to delay symbols. - See https://github.com/k2-fsa/k2/issues/955 for more details. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -261,25 +260,12 @@ def rnnt_loss_simple( modified=modified, ) - if delay_penalty > 0.0: - B, S, T0 = px.shape - T = T0 if modified else T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px.dtype, - device=px.device, - ).expand(B, 1, 1) - else: - offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange( - T0, device=px.device - ).reshape(1, 1, T0) - penalty = penalty * delay_penalty - px += penalty.to(px.dtype) - scores_and_grads = mutual_information_recursion( - px=px, py=py, boundary=boundary, return_grad=return_grad + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale, + return_grad=return_grad ) negated_loss = scores_and_grads[0] if return_grad else scores_and_grads if reduction == "none": @@ -395,7 +381,7 @@ def rnnt_loss( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, - delay_penalty: float = 0.0, + fast_emit_scale : float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input, @@ -417,10 +403,9 @@ def rnnt_loss( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. - delay_penalty: A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time-masking - encouraging the network to delay symbols. - See https://github.com/k2-fsa/k2/issues/955 for more details. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -441,24 +426,12 @@ def rnnt_loss( modified=modified, ) - if delay_penalty > 0.0: - B, S, T0 = px.shape - T = T0 if modified else T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px.dtype, - device=px.device, - ).expand(B, 1, 1) - else: - offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange( - T0, device=px.device - ).reshape(1, 1, T0) - penalty = penalty * delay_penalty - px += penalty.to(px.dtype) - - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + negated_loss = mutual_information_recursion( + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale + ) if reduction == "none": return -negated_loss elif reduction == "mean": @@ -851,7 +824,7 @@ def rnnt_loss_pruned( termination_symbol: int, boundary: Tensor = None, modified: bool = False, - delay_penalty: float = 0.0, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A RNN-T loss with pruning, which uses a pruned 'joiner' network output @@ -882,10 +855,9 @@ def rnnt_loss_pruned( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. - delay_penalty: A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time-masking - encouraging the network to delay symbols. - See https://github.com/k2-fsa/k2/issues/955 for more details. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -906,24 +878,12 @@ def rnnt_loss_pruned( modified=modified, ) - if delay_penalty > 0.0: - B, S, T0 = px.shape - T = T0 if modified else T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px.dtype, - device=px.device, - ).expand(B, 1, 1) - else: - offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange( - T0, device=px.device - ).reshape(1, 1, T0) - penalty = penalty * delay_penalty - px += penalty.to(px.dtype) - - negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary) + negated_loss = mutual_information_recursion( + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale + ) if reduction == "none": return -negated_loss elif reduction == "mean": @@ -1165,7 +1125,7 @@ def rnnt_loss_smoothed( am_only_scale: float = 0.1, boundary: Optional[Tensor] = None, modified: bool = False, - delay_penalty: float = 0.0, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", return_grad: bool = False, ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: @@ -1199,10 +1159,9 @@ def rnnt_loss_smoothed( Most likely you will want begin_symbol and begin_frame to be zero. modified: if True, each time a real symbol is consumed a frame will also be consumed, so at most 1 symbol can appear per frame. - delay_penalty: A constant value to penalize symbol delay, this may be - needed when training with time masking, to avoid the time-masking - encouraging the network to delay symbols. - See https://github.com/k2-fsa/k2/issues/955 for more details. + fast_emit_scale: + Implement fast_emit proposed in https://arxiv.org/pdf/2010.11148.pdf + The idea is to scale px_grad with (1 + fast_emit_scale). reduction: Specifies the reduction to apply to the output: `none`, `mean` or `sum`. `none`: no reduction will be applied. @@ -1236,25 +1195,12 @@ def rnnt_loss_smoothed( modified=modified, ) - if delay_penalty > 0.0: - B, S, T0 = px.shape - T = T0 if modified else T0 - 1 - if boundary is None: - offset = torch.tensor( - (T - 1) / 2, - dtype=px.dtype, - device=px.device, - ).expand(B, 1, 1) - else: - offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange( - T0, device=px.device - ).reshape(1, 1, T0) - penalty = penalty * delay_penalty - px += penalty.to(px.dtype) - scores_and_grads = mutual_information_recursion( - px=px, py=py, boundary=boundary, return_grad=return_grad + px=px, + py=py, + boundary=boundary, + fast_emit_scale=fast_emit_scale, + return_grad=return_grad ) negated_loss = scores_and_grads[0] if return_grad else scores_and_grads if reduction == "none": From 7c11d96cd9f2116bd89065c778131c84cffab8da Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 4 Oct 2022 11:48:36 +0800 Subject: [PATCH 4/5] Minor fixes --- k2/python/k2/mutual_information.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/k2/python/k2/mutual_information.py b/k2/python/k2/mutual_information.py index f1b3cc156..a7567d0ad 100644 --- a/k2/python/k2/mutual_information.py +++ b/k2/python/k2/mutual_information.py @@ -304,12 +304,12 @@ def mutual_information_recursion( pxy_grads = [None, None] scores = MutualInformationRecursionFunction.apply( - px=px, - py=py, - pxy_grads=pxy_grads, - boundary=boundary, - fast_emit_scale=fast_emit_scale, - return_grad=return_grad + px, + py, + pxy_grads, + boundary, + fast_emit_scale, + return_grad ) px_grad, py_grad = pxy_grads return (scores, (px_grad, py_grad)) if return_grad else scores From be847577ebc13bfa97561bf2cb6c933b39abbf34 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 4 Oct 2022 11:49:30 +0800 Subject: [PATCH 5/5] fix flake8 --- k2/python/k2/rnnt_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index e6132b047..960edaa53 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -385,7 +385,7 @@ def rnnt_loss( termination_symbol: int, boundary: Optional[Tensor] = None, modified: bool = False, - fast_emit_scale : float = 0.0, + fast_emit_scale: float = 0.0, reduction: Optional[str] = "mean", ) -> Tensor: """A normal RNN-T loss, which uses a 'joiner' network output as input,