From c1f0b82c39f090edd2840cd43bcc8c5e3cf87098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Vigil-V=C3=A1squez?= Date: Wed, 21 Jan 2026 11:30:31 +0100 Subject: [PATCH] fix: correct n-grams encoding Behavior wasn't deterministic for some reason, producing errors in some cases. --- src/encoding.jl | 6 ++++-- test/encoding.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/encoding.jl b/src/encoding.jl index e793174..f5f69ba 100644 --- a/src/encoding.jl +++ b/src/encoding.jl @@ -437,13 +437,15 @@ and shift operations. # References - [Torchhd documentation](https://torchhd.readthedocs.io/en/stable/generated/torchhd.ngrams.html) - """ function ngrams(vs::AbstractVector{<:AbstractHV}, n::Int = 3) l = length(vs) p = l - n + 1 @assert 1 <= n <= length(vs) "`n` must be 1 ≤ n ≤ $l" - return bundle([bind([shift(vs[i + j], j) for j in 0:(n - 1)]) for i in 1:p]) + return map( + s -> bindsequence(s), + (vs[f:(f + (n - 1))] for f in 1:p) + ) |> multiset end """ diff --git a/test/encoding.jl b/test/encoding.jl index ff6180f..0d851a2 100644 --- a/test/encoding.jl +++ b/test/encoding.jl @@ -37,7 +37,8 @@ end @testset "ngrams" begin - @test ngrams(hvs) == Bool.([0, 1, 0, 0, 1]) + @test ngrams(hvs).v == Bool.([0, 1, 0, 0, 1]) + @test ngrams(hvs) == bundle([hvs[1] * ρ(hvs[2]) * ρ(hvs[3], 2), hvs[2] * ρ(hvs[3]) * ρ(hvs[4], 2), hvs[3] * ρ(hvs[4]) * ρ(hvs[5], 2)]) @test_throws AssertionError ngrams(hvs, 0) @test_throws AssertionError ngrams(hvs, length(hvs) + 1) end