From 66fb3516f0595aab0335d6f9594428b8817736ee Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 4 Jan 2023 16:06:32 +0800 Subject: [PATCH 1/2] fix lattice length --- k2/csrc/rnnt_decode.cu | 187 ++++++++++++++++++++++++++-- k2/csrc/rnnt_decode.h | 75 ++++++++++- k2/python/csrc/torch/rnnt_decode.cu | 16 +++ k2/python/k2/rnnt_decode.py | 10 +- 4 files changed, 276 insertions(+), 12 deletions(-) diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index 1224d7b58..1e3d16331 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -247,7 +247,7 @@ RaggedShape RnntDecodingStreams::ExpandArcs() { return unpruned_arcs_shape; } -Renumbering RnntDecodingStreams::DoFisrtPassPruning( +Renumbering RnntDecodingStreams::DoFirstPassPruning( RaggedShape &unpruned_arcs_shape, const Array2 &logprobs) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(unpruned_arcs_shape.NumAxes(), 4); @@ -438,7 +438,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { auto unpruned_arcs_shape = ExpandArcs(); // (2) Do initial pruning. - auto pass1_renumbering = DoFisrtPassPruning(unpruned_arcs_shape, logprobs); + auto pass1_renumbering = DoFirstPassPruning(unpruned_arcs_shape, logprobs); // pass1_arcs_shape has a shape of [stream][context][state][arc] auto pass1_arcs_shape = @@ -507,6 +507,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { idx01 = uas_row_ids2_data[idx012], idx0 = uas_row_ids1_data[idx01], num_graph_states = num_graph_states_data[idx0]; int64_t this_state = this_states_values_data[idx012]; + int32_t this_graph_state = this_state % num_graph_states; double this_score = this_scores_data[idx012]; // handle the implicit epsilon self-loop @@ -515,7 +516,21 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { // we assume termination symbol to be 0 here. scores_data[arc_idx] = this_score + logprobs_acc(idx01, 0); ArcInfo ai; - ai.graph_arc_idx01 = -1; + /* + Track state index for self-loop arcs. 
+ It's lucky that type int32_t has range [-2147483648, 2147483647] + there is one more negative value than there are positive values. + state (0) --> graph_arc_idx01 (-1) + state (1) --> graph_arc_idx01 (-2) + state (2) --> graph_arc_idx01 (-3) + state (2147483647) --> graph_arc_idx01 (-2147483648) + + Actually, super final state has no self-loop. + So definitely there are enough negative values + to represent positive state index. + */ + ai.graph_arc_idx01 = -this_graph_state - 1; + K2_CHECK_LT(ai.graph_arc_idx01, 0); ai.score = logprobs_acc(idx01, 0); ai.label = 0; arcs_data[arc_idx] = ai; @@ -526,8 +541,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; int64_t this_context_state = this_state / num_graph_states; - int32_t this_graph_state = this_state % num_graph_states, - graph_idx0x = graph_row_split1_data[this_graph_state], + int32_t graph_idx0x = graph_row_split1_data[this_graph_state], graph_idx01 = graph_idx0x + idx3 - 1; // minus 1 here as // epsilon self-loop // takes the position 0. @@ -711,9 +725,164 @@ void RnntDecodingStreams::GatherPrevFrames( } } +void RnntDecodingStreams::GetFinalArcs() { + NVTX_RANGE(K2_FUNC); + /* + This function handles last two steps of the generated lattice. + Relationship of variables in these two steps are: + + arcs: last frame arcs final arcs + states: {last frame state} ---------------> {final states} ---------> {super final state} # noqa + + Super final state has no leaving arcs. + */ + + int32_t frames = prev_frames_.size(); + + // with shape [stream][state][arc] + auto last_frame_shape = prev_frames_[frames - 1]->shape; + + // Note: last_frame_arc_data is non-const + // The original "dest_state" attribute for each element in last_frame_arc_data + // is state index processed by function GroupStatesByContexts. 
+ // In this function, source states in last_frame is expanded again, + // and those expanded destination states are NOT grouped to save time. + // So "dest_state" should be re-assigned to a new value. + ArcInfo *last_frame_arc_data = prev_frames_[frames - 1]->values.Data(); + const int32_t *lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(), + *lfs_row_ids1_data = last_frame_shape.RowIds(1).Data(), + *lfs_row_splits2_data = last_frame_shape.RowSplits(2).Data(), + *lfs_row_splits1_data = last_frame_shape.RowSplits(1).Data(); + + const int32_t *num_graph_states_data = num_graph_states_.Data(); + const int32_t *const *graph_row_splits1_ptr_data = graphs_.shape.RowSplits(1); + const Arc *const *graphs_arcs_data = graphs_.values.Data(); + + // Name meaning of final_graph_states: + // "final_" means it's for "final states". + // "_graph_states" means it stores the state index in decoding graph. + // Though this variable could be calculated both in + // lambda_get_final_arcs_shape and lambda_populate_final_arcs, + // to save time, it is calculated and cached during the former and + // used in the latter. + Array1 final_graph_states(c_, last_frame_shape.NumElements()); + int32_t* final_graph_states_data = final_graph_states.Data(); + + // Calculate num_arcs for each final state. + Array1 num_final_arcs(c_, last_frame_shape.NumElements() + 1); + int32_t *num_final_arcs_data = num_final_arcs.Data(); + + K2_EVAL( + c_, last_frame_shape.NumElements(), lambda_get_final_arcs_shape, + (int32_t idx012) { + // place here to save one kernel. + num_final_arcs_data[idx012] = 0; + + int32_t idx01 = lfs_row_ids2_data[idx012], // state_idx01 + idx0 = lfs_row_ids1_data[idx01], // stream_idx0 + arc_idx0x = lfs_row_splits1_data[idx0], + arc_idx0xx = lfs_row_splits2_data[arc_idx0x], + arc_idx12 = idx012 - arc_idx0xx; + + ArcInfo& ai = last_frame_arc_data[idx012]; + + // Re-assign dest_state to a new value. + // See more detail comment at previous last_frame_arc_data definition. 
+ ai.dest_state = arc_idx12; + + if (ai.label == -1) { + num_final_arcs_data[idx012] = 0; + // -(num_graph_states_data[idx0]) for state not expandable. + final_graph_states_data[idx012] = -(num_graph_states_data[idx0]); + return; + } + int dest_state = -1; + const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; + const Arc *graph_arcs_data = graphs_arcs_data[idx0]; + if (ai.graph_arc_idx01 < 0) { + // For implicit self-loop arcs. + dest_state = -(ai.graph_arc_idx01 + 1); + K2_CHECK_GE(dest_state, 0); + K2_CHECK_LE(dest_state, num_graph_states_data[idx0]); + } else { + // For other arcs shown in the decoding graph. + dest_state = graph_arcs_data[ai.graph_arc_idx01].dest_state; + } + K2_CHECK_GE(dest_state, 0); + + final_graph_states_data[idx012] = dest_state; + // Plus one for the implicit epsilon self-loop. + num_final_arcs_data[idx012] = graph_row_split1_data[dest_state + 1] - + graph_row_split1_data[dest_state] + 1; + }); + + ExclusiveSum(num_final_arcs, &num_final_arcs); + auto final_arcs_shape = RaggedShape2(&num_final_arcs, nullptr, -1); + final_arcs_shape = ComposeRaggedShapes(last_frame_shape, final_arcs_shape); + // [steam][state][arc][arc] --> [stream][arc][arc] + // could be viewd as [strem][final state][arc] + final_arcs_shape = RemoveAxis(final_arcs_shape, 1); + const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data(); + const int32_t *fas_row_ids2_data = final_arcs_shape.RowIds(2).Data(); + const int32_t *fas_row_splits2_data = final_arcs_shape.RowSplits(2).Data(); + + auto final_arcs = Ragged(final_arcs_shape); + ArcInfo *final_arcs_data = final_arcs.values.Data(); + K2_EVAL( + c_, final_arcs_shape.NumElements(), lambda_populate_final_arcs, + (int32_t idx012) { + const int32_t idx01 = fas_row_ids2_data[idx012], // state + idx0 = fas_row_ids1_data[idx01], // stream + idx01x = fas_row_splits2_data[idx01], + arc_idx2 = idx012 - idx01x; + + const Arc *graph_arcs_data = graphs_arcs_data[idx0]; + const int32_t 
*graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; + int32_t graph_state_idx0 = final_graph_states_data[idx01]; + + int32_t ai_graph_arc_idx01 = 0; + int32_t ai_arc_label = 0; + if (graph_state_idx0 < 0) { + /* + Could be one of following two cases: + case 1: not expandable if graph_state_idx0 == -(num_graph_states_data[idx0]) # noqa + case 2: implicit self-loop if graph_state_idx0 > -(num_graph_states_data[idx0]) # noqa + */ + K2_DCHECK_GT(graph_state_idx0, -(num_graph_states_data[idx0])); + ai_arc_label = 0; + ai_graph_arc_idx01 = -1; + } else { + // For arcs shown in decoding graph. + int32_t graph_arc_idx0x = graph_row_split1_data[graph_state_idx0]; + // arc_idx2 could be viewed as graph_arc_idx1, + // since final_arcs_shape has 3 axes where arc_idx2 is calculated, + // while decoding_graph only has 2 axes where arc_idx2 is used. + ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx2; + auto graph_arc = graph_arcs_data[ai_graph_arc_idx01]; + ai_arc_label = graph_arc.label; + } + ArcInfo ai; + // ai.dest_state will be overwritted by FormatOutput + // just initialize it as -1 here + ai.dest_state = -1; + ai.graph_arc_idx01 = ai_graph_arc_idx01; + ai.score = 0.0; + ai.label = ai_arc_label; + final_arcs_data[idx012] = ai; + }); + + prev_frames_.emplace_back(std::make_shared>(final_arcs)); +} + void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, bool allow_partial, FsaVec *ofsa, Array1 *out_map) { + FormatOutput(num_frames, allow_partial, false /* is_final */, ofsa, out_map); +} + +void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, + bool allow_partial, bool is_final, + FsaVec *ofsa, Array1 *out_map) { NVTX_RANGE(K2_FUNC); K2_CHECK(!attached_) << "You can only get outputs after calling TerminateAndFlushToStreams()"; @@ -723,6 +892,10 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, GatherPrevFrames(num_frames); + if (is_final) { + GetFinalArcs(); + } + int32_t frames = prev_frames_.size(); auto 
last_frame_shape = prev_frames_[frames - 1]->shape; @@ -873,11 +1046,11 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, int32_t dest_state_idx012 = oarc_idx01x_next + arc_info.dest_state; arc.dest_state = dest_state_idx012 - oarc_idx0xx; - // graph_arc_idx01 == -1 means this is a implicit epsilon self-loop + // graph_arc_idx01 < 0 means this is an implicit epsilon self-loop // arc_info.label == -1 means this is the final arc before last // frame this is non-accessible arc, we set its label to 0 here to // make the generated lattice a valid k2 fsa. - if (arc_info.graph_arc_idx01 == -1 || arc_info.label == -1) { + if (arc_info.graph_arc_idx01 <= -1 || arc_info.label == -1) { arc.label = 0; arc_info.graph_arc_idx01 = -1; } else { diff --git a/k2/csrc/rnnt_decode.h b/k2/csrc/rnnt_decode.h index 9224ca807..58ff1e3bf 100644 --- a/k2/csrc/rnnt_decode.h +++ b/k2/csrc/rnnt_decode.h @@ -94,10 +94,25 @@ struct RnntDecodingConfig { struct ArcInfo { // The arc-index within the RnntDecodingStream::graph that corresponds to this - // arc, or -1 if this arc is a "termination symbol" (these do not appear in - // the graph). + // arc if non-negative. + // There is an implicit self-loop arc for each state, which is represented + // by -(state_index + 1), see following comments of dest_state_in_graph. int32_t graph_arc_idx01; + // Note: + // 1. To save memory, value of this variable is calculated + // from graph_arc_idx01. + // 2. It is different from variable dest_state. + // dest_state_in_graph is the destination state index in decoding graph. + // dest_state below is the state index in "generated lattice". + // There are two kinds of arcs in decoding graph: + // 1. Implicit self-loop arcs, dest_state of these arcs are calculated + // with -(graph_arc_idx01 + 1). + // (Note, graph_arc_idx01 is negative for these arcs) + // 2. 
Other arcs shown in decoding graph, dest_state of these arcs are + // calculated with graph_arcs_data[ai.graph_arc_idx01].dest_state + // int32_t dest_state_in_graph; + // The score on the arc; contains both the graph score (if any) and the score // from the RNN-T joiner. float score; @@ -199,6 +214,38 @@ class RnntDecodingStreams { void FormatOutput(const std::vector &num_frames, bool allow_partial, FsaVec *ofsa, Array1 *out_map); + /* + Generate the lattice. + Note: Almost the same with previous overloaded version, + except for an extra `is_final` argument. + + Note: The prev_frames_ only contains decoded by current object, in order to + generate the lattice we will first gather all the previous frames from + individual streams. + + @param [in] num_frames A vector containing the number of frames we want + to gather for each stream (note: the frames we have + ever received). + It MUST satisfy `num_frames.size() == num_streams_`, and + `num_frames[i] <= srcs_[i].prev_frames.size()`. + @param [in] allow_partial If true and there is no final state active, + we will treat all the states on the last frame + to be final state. If false, we only + care about the real final state in the decoding + graph on the last frame when generating lattice. + @param [in] is_final If true, function GetFinalArcs() will be called. + If false, the same with previous overloaded version. + @param [out] ofsa The output lattice will write to here, its num_axes + equals to 3, will be re-allocated. + @param [out] out_map It is an Array1 with Dim() equals to + ofsa.NumElements() containing the idx01 into the graph of + each individual streams, mapping current arc in ofsa to + original decoding graphs. It may contain -1 which means + this arc is a "termination symbol". 
+ */ + void FormatOutput(const std::vector &num_frames, bool allow_partial, + bool is_final, FsaVec *ofsa, Array1 *out_map); + /* Terminate the decoding process of current RnntDecodingStreams object, it will update the states & scores of each individual stream and split & @@ -261,8 +308,30 @@ class RnntDecodingStreams { @return Return the renumbering object indicating which arc will be kept. */ - Renumbering DoFisrtPassPruning(RaggedShape &unprund_arcs_shape, + Renumbering DoFirstPassPruning(RaggedShape &unprund_arcs_shape, const Array2 &logprobs); + + /* + Get final arcs when last frame is received, i.e. passing is_final=True to + function `FormatOutput`. + Comparing with openfst, a valid fsa in k2 needs arcs with label==-1 + pointing to a super final state. This function is handling these arcs. + See detail of the problem solved by this function at + https://github.com/k2-fsa/k2/pull/1089 + + If we name variables for last two steps of a lattice as: + arcs: last frame arcs final arcs + states: {last frame state} ---------------> {final states} ---------> {super final state} + + This function mainly does the following steps: + 1. get last_frame from prev_frames_ + 2. expand last frame and get final states + 3. re-assign dest state of last frame arcs to final states + 4. populate final arcs + 5. append final arcs to prev_frames_ + */ + void GetFinalArcs(); + /* Group states by contexts. 
diff --git a/k2/python/csrc/torch/rnnt_decode.cu b/k2/python/csrc/torch/rnnt_decode.cu index 0a81deaed..7bda90397 100644 --- a/k2/python/csrc/torch/rnnt_decode.cu +++ b/k2/python/csrc/torch/rnnt_decode.cu @@ -152,6 +152,22 @@ static void PybindRnntDecodingStreams(py::module &m) { torch::Tensor out_map_tensor = ToTorch(out_map); return std::make_pair(ofsa, out_map_tensor); }); + + streams.def("format_output", + [](PyClass &self, std::vector &num_frames, + bool allow_partial, bool is_final) + -> std::pair { + DeviceGuard guard(self.Context()); + FsaVec ofsa; + Array1 out_map; + self.FormatOutput(num_frames, + allow_partial, + is_final, + &ofsa, + &out_map); + torch::Tensor out_map_tensor = ToTorch(out_map); + return std::make_pair(ofsa, out_map_tensor); + }); } } // namespace k2 diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py index 95f6aa8ab..7d2fe559c 100644 --- a/k2/python/k2/rnnt_decode.py +++ b/k2/python/k2/rnnt_decode.py @@ -149,7 +149,8 @@ def terminate_and_flush_to_streams(self) -> None: def format_output( self, num_frames: List[int], - allow_partial: bool = False + allow_partial: bool = False, + is_final: bool = False, ) -> Fsa: """ Generate the lattice Fsa currently got. @@ -173,6 +174,11 @@ def format_output( If false, we only care about the real final state in the decoding graph on the last frame when generating lattice. Default False. + is_final: + If true, function GetFinalArcs() will be called. + See detail of the problem solved by GetFinalArcs() at + https://github.com/k2-fsa/k2/pull/1089 + Returns: Return the lattice Fsa with all the attributes propagated. 
@@ -181,7 +187,7 @@ def format_output( assert len(num_frames) == self.num_streams ragged_arcs, out_map = self.streams.format_output( - num_frames, allow_partial + num_frames, allow_partial, is_final, ) fsa = Fsa(ragged_arcs) From de17468d383d41f9e04607c268738d2ca3c35c85 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 27 Jul 2023 11:33:57 +0800 Subject: [PATCH 2/2] Fix lattice length & allow partial --- k2/csrc/intersect_dense.cu | 9 +++-- k2/csrc/rnnt_decode.cu | 75 +++++++++++++++++++++----------------- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/k2/csrc/intersect_dense.cu b/k2/csrc/intersect_dense.cu index fbb53ca54..7291a3a3f 100644 --- a/k2/csrc/intersect_dense.cu +++ b/k2/csrc/intersect_dense.cu @@ -778,7 +778,7 @@ class MultiGraphDenseIntersect { void DoStep(int32_t t) { NVTX_RANGE(K2_FUNC); Step &step = steps_[t], &prev_step = steps_[t - 1]; - int32_t scores_num_cols = b_fsas_.scores.Dim1(); + int32_t scores_num_cols = b_fsas_.scores.Dim1(); const float minus_inf = -std::numeric_limits::infinity(); // Divide by two because each arc is repeated twice in arc_scores (once for @@ -814,9 +814,10 @@ class MultiGraphDenseIntersect { backward_dest_prob = prev_state_scores_data[dest_state_scores_index_backward]; - // Assign negative infinity (-inf) to both the forward and backward scores, - // if the label on the carc is out-of-range, i.e., the label in the decoding - // graph (a_fsas) does not exist in the neural-net output (b_fsas). + // Assign negative infinity (-inf) to both the forward and backward + // scores, if the label on the carc is out-of-range, i.e., the label + // in the decoding graph (a_fsas) does not exist in the neural-net + // output (b_fsas). 
float b_score_forward; float b_score_backward; if (carc.label_plus_one <= scores_num_cols) { diff --git a/k2/csrc/rnnt_decode.cu b/k2/csrc/rnnt_decode.cu index b1a514158..4c3b18681 100644 --- a/k2/csrc/rnnt_decode.cu +++ b/k2/csrc/rnnt_decode.cu @@ -489,6 +489,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { const auto logprobs_acc = logprobs.Accessor(); const Arc *const *graphs_arcs_data = graphs_.values.Data(); + K2_EVAL( c_, cur_num_arcs, lambda_populate_arcs_states_scores, (int32_t arc_idx) { // Init renumber_arcs to 0, place here to save one kernel. @@ -530,7 +531,7 @@ void RnntDecodingStreams::Advance(const Array2 &logprobs) { So definitely there are enough negative values to represent positive state index. */ - ai.graph_arc_idx01 = -this_graph_state - 1; + ai.graph_arc_idx01 = -(this_graph_state + 1); K2_CHECK_LT(ai.graph_arc_idx01, 0); ai.score = logprobs_acc(idx01, 0); ai.label = 0; @@ -743,7 +744,7 @@ void RnntDecodingStreams::GetFinalArcs() { int32_t frames = prev_frames_.size(); - // with shape [stream][state][arc] + // with shape [stream][context][state][arc] auto last_frame_shape = prev_frames_[frames - 1]->shape; // Note: last_frame_arc_data is non-const @@ -753,8 +754,10 @@ void RnntDecodingStreams::GetFinalArcs() { // and those expanded destination states are NOT grouped to save time. // So "dest_state" should be re-assigned to a new value. 
ArcInfo *last_frame_arc_data = prev_frames_[frames - 1]->values.Data(); - const int32_t *lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(), + const int32_t *lfs_row_ids3_data = last_frame_shape.RowIds(3).Data(), + *lfs_row_ids2_data = last_frame_shape.RowIds(2).Data(), *lfs_row_ids1_data = last_frame_shape.RowIds(1).Data(), + *lfs_row_splits3_data = last_frame_shape.RowSplits(3).Data(), *lfs_row_splits2_data = last_frame_shape.RowSplits(2).Data(), *lfs_row_splits1_data = last_frame_shape.RowSplits(1).Data(); @@ -778,34 +781,34 @@ void RnntDecodingStreams::GetFinalArcs() { K2_EVAL( c_, last_frame_shape.NumElements(), lambda_get_final_arcs_shape, - (int32_t idx012) { + (int32_t idx0123) { // place here to save one kernel. - num_final_arcs_data[idx012] = 0; + num_final_arcs_data[idx0123] = 0; - int32_t idx01 = lfs_row_ids2_data[idx012], // state_idx01 - idx0 = lfs_row_ids1_data[idx01], // stream_idx0 - arc_idx0x = lfs_row_splits1_data[idx0], - arc_idx0xx = lfs_row_splits2_data[arc_idx0x], - arc_idx12 = idx012 - arc_idx0xx; + int32_t idx012 = lfs_row_ids3_data[idx0123], // state_idx012 + idx01 = lfs_row_ids2_data[idx012], // context_idx01 + idx0 = lfs_row_ids1_data[idx01], // stream_idx0 + arc_idx01x = lfs_row_splits2_data[idx01], + arc_idx01xx = lfs_row_splits3_data[arc_idx01x], + arc_idx23 = idx0123 - arc_idx01xx; - ArcInfo& ai = last_frame_arc_data[idx012]; + ArcInfo& ai = last_frame_arc_data[idx0123]; // Re-assign dest_state to a new value. // See more detail comment at previous last_frame_arc_data definition. - ai.dest_state = arc_idx12; + ai.dest_state = arc_idx23; if (ai.label == -1) { - num_final_arcs_data[idx012] = 0; // -(num_graph_states_data[idx0]) for state not expandable. 
- final_graph_states_data[idx012] = -(num_graph_states_data[idx0]); + final_graph_states_data[idx0123] = -(num_graph_states_data[idx0]); return; } - int dest_state = -1; + int32_t dest_state = -1; const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; const Arc *graph_arcs_data = graphs_arcs_data[idx0]; if (ai.graph_arc_idx01 < 0) { // For implicit self-loop arcs. - dest_state = -(ai.graph_arc_idx01 + 1); + dest_state = -ai.graph_arc_idx01 - 1; K2_CHECK_GE(dest_state, 0); K2_CHECK_LE(dest_state, num_graph_states_data[idx0]); } else { @@ -814,35 +817,40 @@ void RnntDecodingStreams::GetFinalArcs() { } K2_CHECK_GE(dest_state, 0); - final_graph_states_data[idx012] = dest_state; + final_graph_states_data[idx0123] = dest_state; // Plus one for the implicit epsilon self-loop. - num_final_arcs_data[idx012] = graph_row_split1_data[dest_state + 1] - + num_final_arcs_data[idx0123] = graph_row_split1_data[dest_state + 1] - graph_row_split1_data[dest_state] + 1; }); + ExclusiveSum(num_final_arcs, &num_final_arcs); + auto final_arcs_shape = RaggedShape2(&num_final_arcs, nullptr, -1); final_arcs_shape = ComposeRaggedShapes(last_frame_shape, final_arcs_shape); - // [steam][state][arc][arc] --> [stream][arc][arc] - // could be viewd as [strem][final state][arc] - final_arcs_shape = RemoveAxis(final_arcs_shape, 1); - const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data(); - const int32_t *fas_row_ids2_data = final_arcs_shape.RowIds(2).Data(); - const int32_t *fas_row_splits2_data = final_arcs_shape.RowSplits(2).Data(); + // [steam][context][state][arc][arc] --> [stream][context][arc][arc] + // could be viewd as [strem][context][final state][arc] + final_arcs_shape = RemoveAxis(final_arcs_shape, 2); + const int32_t *fas_row_ids1_data = final_arcs_shape.RowIds(1).Data(), + *fas_row_ids2_data = final_arcs_shape.RowIds(2).Data(), + *fas_row_ids3_data = final_arcs_shape.RowIds(3).Data(), + *fas_row_splits3_data = final_arcs_shape.RowSplits(3).Data(); auto 
final_arcs = Ragged(final_arcs_shape); ArcInfo *final_arcs_data = final_arcs.values.Data(); + K2_EVAL( c_, final_arcs_shape.NumElements(), lambda_populate_final_arcs, - (int32_t idx012) { - const int32_t idx01 = fas_row_ids2_data[idx012], // state + (int32_t idx0123) { + const int32_t idx012 = fas_row_ids3_data[idx0123], // state + idx01 = fas_row_ids2_data[idx012], // context idx0 = fas_row_ids1_data[idx01], // stream - idx01x = fas_row_splits2_data[idx01], - arc_idx2 = idx012 - idx01x; + idx012x = fas_row_splits3_data[idx012], + arc_idx3 = idx0123 - idx012x; const Arc *graph_arcs_data = graphs_arcs_data[idx0]; const int32_t *graph_row_split1_data = graph_row_splits1_ptr_data[idx0]; - int32_t graph_state_idx0 = final_graph_states_data[idx01]; + int32_t graph_state_idx0 = final_graph_states_data[idx012]; int32_t ai_graph_arc_idx01 = 0; int32_t ai_arc_label = 0; @@ -861,7 +869,7 @@ void RnntDecodingStreams::GetFinalArcs() { // arc_idx2 could be viewed as graph_arc_idx1, // since final_arcs_shape has 3 axes where arc_idx2 is calculated, // while decoding_graph only has 2 axes where arc_idx2 is used. 
- ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx2; + ai_graph_arc_idx01 = graph_arc_idx0x + arc_idx3; auto graph_arc = graph_arcs_data[ai_graph_arc_idx01]; ai_arc_label = graph_arc.label; } @@ -872,7 +880,7 @@ void RnntDecodingStreams::GetFinalArcs() { ai.graph_arc_idx01 = ai_graph_arc_idx01; ai.score = 0.0; ai.label = ai_arc_label; - final_arcs_data[idx012] = ai; + final_arcs_data[idx0123] = ai; }); prev_frames_.emplace_back(std::make_shared>(final_arcs)); @@ -1053,7 +1061,7 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, K2_EVAL( c_, num_streams_, lambda_set_start_offset, (int32_t stream_idx) { num_padded_frames_data[stream_idx] = - frames - num_padded_frames_data[stream_idx]; + frames - num_padded_frames_data[stream_idx] - 1; K2_CHECK_LE(0, num_padded_frames_data[stream_idx]); }); } @@ -1117,11 +1125,12 @@ void RnntDecodingStreams::FormatOutput(const std::vector &num_frames, // make the generated lattice a valid k2 fsa. if (arc_info.graph_arc_idx01 <= -1 || arc_info.label == -1) { arc.label = 0; + out_map_data[oarc_idx01234] = -1; } else { arc.label = graph_arcs_data[arc_info.graph_arc_idx01].label; + out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01; } arc.score = arc_info.score; - out_map_data[oarc_idx01234] = arc_info.graph_arc_idx01; } arcs_out_data[oarc_idx01234] = arc; if (arc_map_b != nullptr) {