From 7aebec7090cd571117c6c7fc70a4a7f691a19c24 Mon Sep 17 00:00:00 2001
From: Li Baoming <1508269885@qq.com>
Date: Tue, 21 Apr 2026 10:04:17 +0800
Subject: [PATCH 1/2] feat: Add cpu random sample op

---
 src/base/random_sample.h              | 164 ++++++++++++++
 src/cpu/random_sample/random_sample.h | 305 ++++++++++++++++++++++++++
 tests/test_random_sample.py           | 222 +++++++++++++++++++
 3 files changed, 691 insertions(+)
 create mode 100644 src/base/random_sample.h
 create mode 100644 src/cpu/random_sample/random_sample.h
 create mode 100644 tests/test_random_sample.py

diff --git a/src/base/random_sample.h b/src/base/random_sample.h
new file mode 100644
index 00000000..e9066436
--- /dev/null
+++ b/src/base/random_sample.h
@@ -0,0 +1,164 @@
+#ifndef INFINI_OPS_BASE_RANDOM_SAMPLE_H_
+#define INFINI_OPS_BASE_RANDOM_SAMPLE_H_
+
+#include <cassert>
+#include <cstdint>
+#include <optional>
+
+#include "operator.h"
+
+namespace infini::ops {
+
+class RandomSample : public Operator {
+ public:
+  // clang-format off
+  //
+  // logits: [batch_size, vocab_size] or [vocab_size] (batch_size=1)
+  // out:    [batch_size] sampled token ids (int32/int64)
+  // valid:  [batch_size] uint8 (0 or 1), whether the sample is valid
+  //
+  // Per-batch parameters support two modes:
+  //   - optional has value:  per-batch tensor of shape [batch_size]
+  //   - optional is nullopt: use the scalar *_val for all requests
+  //
+  // When both (optional == nullopt) and (*_val == default), the
+  // corresponding filtering is disabled.
+  //
+  // seed:   per-request RNG seed. The same seed + offset produces identical
+  //         results (reproducibility). Different requests should use
+  //         different seeds.
+  // offset: per-step counter, incremented each decode step within a request.
+  //         Ensures different steps produce different samples even with the
+  //         same seed.
+  //
+  // clang-format on
+  RandomSample(const Tensor logits, Tensor out, Tensor valid,
+               std::optional<Tensor> temperature, float temperature_val,
+               std::optional<Tensor> top_k, int top_k_val,
+               std::optional<Tensor> top_p, float top_p_val,
+               std::optional<Tensor> min_p, float min_p_val,
+               std::uint64_t seed, std::uint64_t offset,
+               bool deterministic)
+      : logits_dtype_{logits.dtype()},
+        out_dtype_{out.dtype()},
+        ndim_{logits.ndim()},
+        batch_size_{ndim_ == 2 ? logits.size(-2) : 1},
+        vocab_size_{logits.size(-1)},
+        logits_strides_{logits.strides()},
+        temperature_{temperature},
+        temperature_val_{temperature_val},
+        top_k_{top_k},
+        top_k_val_{top_k_val},
+        top_p_{top_p},
+        top_p_val_{top_p_val},
+        min_p_{min_p},
+        min_p_val_{min_p_val},
+        seed_{seed},
+        offset_{offset},
+        deterministic_{deterministic} {
+    assert((ndim_ == 1 || ndim_ == 2) &&
+           "`RandomSample` requires 1D [vocab_size] or 2D [batch, vocab_size] "
+           "logits");
+    assert(out.ndim() == 1 && out.size(0) == batch_size_ &&
+           "`RandomSample` requires 1D output [batch_size]");
+    assert(valid.ndim() == 1 && valid.size(0) == batch_size_ &&
+           "`RandomSample` requires 1D valid [batch_size]");
+    assert((out_dtype_ == DataType::kInt32 || out_dtype_ == DataType::kInt64) &&
+           "`RandomSample` requires int32 or int64 output");
+    ValidateParams(temperature, top_k, top_p, min_p);
+  }
+
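+  // A minimal call sketch (hypothetical call site, not part of this patch):
+  // constructed through a concrete backend such as the CPU specialization,
+  // the seed/offset-only overload below performs plain multinomial sampling
+  // from softmax(logits) with no filtering:
+  //
+  //   Operator<kCpu, RandomSample> sample{logits, out, valid, /*seed=*/42,
+  //                                       /*offset=*/step};
+  //   sample(logits, out, valid, /*seed=*/42, /*offset=*/step);
+
+  // Simplified constructor: no filtering, default temperature.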
+  RandomSample(const Tensor logits, Tensor out, Tensor valid,
+               std::uint64_t seed, std::uint64_t offset)
+      : RandomSample{logits, out, valid,
+                     std::nullopt, 1.0f,
+                     std::nullopt, 0,
+                     std::nullopt, 1.0f,
+                     std::nullopt, 0.0f,
+                     seed, offset, false} {}
+
+  virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
+                          std::optional<Tensor> temperature,
+                          float temperature_val,
+                          std::optional<Tensor> top_k, int top_k_val,
+                          std::optional<Tensor> top_p, float top_p_val,
+                          std::optional<Tensor> min_p, float min_p_val,
+                          std::uint64_t seed, std::uint64_t offset,
+                          bool deterministic) const = 0;
+
+  virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
+                          std::uint64_t seed, std::uint64_t offset) const {
+    return operator()(logits, out, valid,
+                      temperature_, temperature_val_,
+                      top_k_, top_k_val_,
+                      top_p_, top_p_val_,
+                      min_p_, min_p_val_,
+                      seed, offset, deterministic_);
+  }
+
+ protected:
+  static void ValidateIntParam(std::optional<Tensor> t, Tensor::Size batch_size) {
+    if (!t.has_value()) return;
+    const auto& tensor = *t;
+    assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
+           "per-batch int param must be 1D [batch_size]");
+    assert((tensor.dtype() == DataType::kInt32 ||
+            tensor.dtype() == DataType::kInt64) &&
+           "per-batch int param must be int32 or int64");
+  }
+
+  static void ValidateFloatParam(std::optional<Tensor> t, Tensor::Size batch_size) {
+    if (!t.has_value()) return;
+    const auto& tensor = *t;
+    assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
+           "per-batch float param must be 1D [batch_size]");
+    assert((tensor.dtype() == DataType::kFloat32 ||
+            tensor.dtype() == DataType::kFloat64 ||
+            tensor.dtype() == DataType::kFloat16 ||
+            tensor.dtype() == DataType::kBFloat16) &&
+           "per-batch float param must be float16/bfloat16/float32/float64");
+  }
+
+  void ValidateParams(std::optional<Tensor> temperature, std::optional<Tensor> top_k,
+                      std::optional<Tensor> top_p, std::optional<Tensor> min_p) const {
+    ValidateFloatParam(temperature, batch_size_);
+    ValidateIntParam(top_k, batch_size_);
+    ValidateFloatParam(top_p, batch_size_);
+    ValidateFloatParam(min_p, batch_size_);
+  }
+
+  const DataType logits_dtype_;
+
+  const DataType out_dtype_;
+
+  Tensor::Size ndim_{0};
+
+  Tensor::Size batch_size_{1};
+
+  Tensor::Size vocab_size_{0};
+
+  Tensor::Strides logits_strides_;
+
+  // Per-batch or scalar sampling parameters.
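+  // For example, a per-batch top_k_ of [1, 0, 40, 0] applies greedy decoding
+  // to request 0, top-40 filtering to request 2, and no top-k filtering to
+  // requests 1 and 3; when an optional is nullopt, the scalar below applies.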
+  std::optional<Tensor> temperature_;
+  float temperature_val_{1.0f};
+
+  std::optional<Tensor> top_k_;
+  int top_k_val_{0};
+
+  std::optional<Tensor> top_p_;
+  float top_p_val_{1.0f};
+
+  std::optional<Tensor> min_p_;
+  float min_p_val_{0.0f};
+
+  std::uint64_t seed_{0};
+
+  std::uint64_t offset_{0};
+
+  bool deterministic_{false};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/cpu/random_sample/random_sample.h b/src/cpu/random_sample/random_sample.h
new file mode 100644
index 00000000..4b49c3db
--- /dev/null
+++ b/src/cpu/random_sample/random_sample.h
@@ -0,0 +1,305 @@
+#ifndef INFINI_OPS_CPU_RANDOM_SAMPLE_H_
+#define INFINI_OPS_CPU_RANDOM_SAMPLE_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <numeric>
+#include <optional>
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "base/random_sample.h"
+#include "common/generic_utils.h"
+#include "cpu/caster_.h"
+#include "data_type.h"
+#include "tensor.h"
+
+namespace infini::ops {
+
+template <>
+class Operator<kCpu, RandomSample>
+    : public RandomSample,
+      Caster<kCpu> {
+ public:
+  using RandomSample::RandomSample;
+
+  void operator()(const Tensor logits, Tensor out, Tensor valid,
+                  std::optional<Tensor> temperature, float temperature_val,
+                  std::optional<Tensor> top_k, int top_k_val,
+                  std::optional<Tensor> top_p, float top_p_val,
+                  std::optional<Tensor> min_p, float min_p_val,
+                  std::uint64_t seed, std::uint64_t offset,
+                  bool deterministic) const override {
+    DispatchFunc<List<float, double, Float16, BFloat16>,
+                 List<std::int32_t, std::int64_t>>(
+        {logits_dtype_, out_dtype_},
+        [&](auto tag1, auto tag2) {
+          using T = typename decltype(tag1)::type;
+          using OutT = typename decltype(tag2)::type;
+          Compute<T, OutT>(logits, out, valid, temperature, temperature_val,
+                           top_k, top_k_val, top_p, top_p_val, min_p,
+                           min_p_val, seed, offset, deterministic);
+        },
+        "`Operator<kCpu, RandomSample>::operator()`");
+  }
+
+  void operator()(const Tensor logits, Tensor out, Tensor valid,
+                  std::uint64_t seed, std::uint64_t offset) const override {
+    return operator()(logits, out, valid,
+                      std::nullopt, temperature_val_,
+                      std::nullopt, top_k_val_,
+                      std::nullopt, top_p_val_,
+                      std::nullopt, min_p_val_,
+                      seed, offset, deterministic_);
+  }
+
+ private:
+  // Resolve a per-batch parameter: use the tensor if provided, else the
+  // scalar. Handles dtype conversion (int32/int64) and stride-based indexing.
+  template <typename ValType>
+  ValType GetParam(std::optional<Tensor> tensor, ValType scalar_val,
+                   Tensor::Size batch_idx) const {
+    if (tensor.has_value()) {
+      const auto& t = *tensor;
+      auto stride = t.strides().empty() ? 1 : t.strides()[0];
+      auto offset = batch_idx * stride;
+      switch (t.dtype()) {
+        case DataType::kInt32:
+          return static_cast<ValType>(
+              static_cast<const std::int32_t*>(t.data())[offset]);
+        case DataType::kInt64:
+          return static_cast<ValType>(
+              static_cast<const std::int64_t*>(t.data())[offset]);
+        default:
+          assert(false && "unsupported dtype for int param");
+          return scalar_val;
+      }
+    }
+    return scalar_val;
+  }
+
+  // Resolve a per-batch float parameter, handling dtype and strides.
+  float GetFloatParam(std::optional<Tensor> tensor, float scalar_val,
+                      Tensor::Size batch_idx) const {
+    if (tensor.has_value()) {
+      const auto& t = *tensor;
+      auto stride = t.strides().empty() ? 1 : t.strides()[0];
+      auto offset = batch_idx * stride;
+      switch (t.dtype()) {
+        case DataType::kFloat32:
+          return static_cast<const float*>(t.data())[offset];
+        case DataType::kFloat64:
+          return static_cast<float>(
+              static_cast<const double*>(t.data())[offset]);
+        case DataType::kFloat16:
+          return static_cast<const Float16*>(t.data())[offset].ToFloat();
+        case DataType::kBFloat16:
+          return static_cast<const BFloat16*>(t.data())[offset].ToFloat();
+        default:
+          assert(false && "unsupported dtype for float param");
+          return scalar_val;
+      }
+    }
+    return scalar_val;
+  }
+
+  template <typename T, typename OutT>
+  void Compute(const Tensor logits, Tensor out, Tensor valid,
+               std::optional<Tensor> temperature, float temperature_val,
+               std::optional<Tensor> top_k, int top_k_val,
+               std::optional<Tensor> top_p, float top_p_val,
+               std::optional<Tensor> min_p, float min_p_val,
+               std::uint64_t seed, std::uint64_t offset,
+               bool deterministic) const {
+    assert(valid.dtype() == DataType::kUInt8 &&
+           "`RandomSample` requires uint8 valid tensor");
+
+    const auto* logits_ptr = static_cast<const T*>(logits.data());
+    auto* out_ptr = static_cast<OutT*>(out.data());
+    auto* valid_ptr = static_cast<std::uint8_t*>(valid.data());
+
+    auto vocab_size = static_cast<Tensor::Size>(vocab_size_);
+
+    // Stride-based indexing, matching the CausalSoftmax pattern.
+    auto batch_stride = ndim_ == 2 ? logits_strides_[0] : 0;
+    auto col_stride = logits_strides_[ndim_ - 1];
+
+    auto sample_batch = [&](Tensor::Size b, std::vector<float>& probs) {
+      const T* logits_row = logits_ptr + b * batch_stride;
+
+      // --- Step 1: Temperature scaling + Softmax ---
+      float temp = GetFloatParam(temperature, temperature_val, b);
+      float inv_temp = (temp > 0.f) ? (1.f / temp) : 0.f;
+
+      float max_val = Cast<float>(logits_row[0 * col_stride]) * inv_temp;
+      for (Tensor::Size j = 1; j < vocab_size; ++j) {
+        float v = Cast<float>(logits_row[j * col_stride]) * inv_temp;
+        if (v > max_val) max_val = v;
+      }
+
+      float sum = 0.f;
+      for (Tensor::Size j = 0; j < vocab_size; ++j) {
+        float v = std::exp(
+            Cast<float>(logits_row[j * col_stride]) * inv_temp - max_val);
+        probs[j] = v;
+        sum += v;
+      }
+
+      if (sum <= 0.f) {
+        out_ptr[b] = static_cast<OutT>(0);
+        valid_ptr[b] = false;
+        return;
+      }
+
+      for (Tensor::Size j = 0; j < vocab_size; ++j) {
+        probs[j] /= sum;
+      }
+
+      // --- Step 2: top_k filtering ---
+      int k = GetParam(top_k, top_k_val, b);
+      if (k > 0 && static_cast<Tensor::Size>(k) < vocab_size) {
+        // Find the k-th largest value using nth_element.
+        std::vector<std::pair<float, Tensor::Size>> indexed(vocab_size);
+        for (Tensor::Size j = 0; j < vocab_size; ++j) {
+          indexed[j] = {probs[j], j};
+        }
+        // Partial sort: top k elements at the front.
+        std::nth_element(
+            indexed.begin(), indexed.begin() + k, indexed.end(),
+            [](const auto& a, const auto& b) { return a.first > b.first; });
+
+        // Mask out everything beyond top-k.
+        for (Tensor::Size j = static_cast<Tensor::Size>(k); j < vocab_size;
+             ++j) {
+          probs[indexed[j].second] = 0.f;
+        }
+
+        // Renormalize.
+        float renorm_sum = 0.f;
+        for (Tensor::Size j = 0; j < vocab_size; ++j) {
+          renorm_sum += probs[j];
+        }
+        if (renorm_sum > 0.f) {
+          for (Tensor::Size j = 0; j < vocab_size; ++j) {
+            probs[j] /= renorm_sum;
+          }
+        }
+      }
+
+      // --- Step 3: top_p filtering ---
+      float p = GetFloatParam(top_p, top_p_val, b);
+      if (p > 0.f && p < 1.f) {
+        // Sort indices by probability descending.
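+        // Worked example (illustrative): probs = {0.5, 0.3, 0.2} with
+        // p = 0.7 keeps the two largest entries (the cumulative sum reaches
+        // 0.8 >= 0.7 at the second), zeroes the rest, and renormalizes.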
+        std::vector<Tensor::Size> sorted_idx(vocab_size);
+        std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+        std::sort(sorted_idx.begin(), sorted_idx.end(),
+                  [&](Tensor::Size a, Tensor::Size b) {
+                    return probs[a] > probs[b];
+                  });
+
+        float cumsum = 0.f;
+        Tensor::Size cutoff = vocab_size;
+        for (Tensor::Size i = 0; i < vocab_size; ++i) {
+          cumsum += probs[sorted_idx[i]];
+          if (cumsum >= p) {
+            cutoff = i + 1;
+            break;
+          }
+        }
+        // Zero out everything beyond the cutoff.
+        for (Tensor::Size i = cutoff; i < vocab_size; ++i) {
+          probs[sorted_idx[i]] = 0.f;
+        }
+
+        // Renormalize.
+        float renorm_sum = 0.f;
+        for (Tensor::Size j = 0; j < vocab_size; ++j) {
+          renorm_sum += probs[j];
+        }
+        if (renorm_sum > 0.f) {
+          for (Tensor::Size j = 0; j < vocab_size; ++j) {
+            probs[j] /= renorm_sum;
+          }
+        }
+      }
+
+      // --- Step 4: min_p filtering ---
+      float mp = GetFloatParam(min_p, min_p_val, b);
+      if (mp > 0.f) {
+        // Find the max probability.
+        float max_prob = *std::max_element(probs.begin(), probs.end());
+        float threshold = max_prob * mp;
+
+        for (Tensor::Size j = 0; j < vocab_size; ++j) {
+          if (probs[j] < threshold) {
+            probs[j] = 0.f;
+          }
+        }
+
+        // Renormalize.
+        float renorm_sum = 0.f;
+        for (Tensor::Size j = 0; j < vocab_size; ++j) {
+          renorm_sum += probs[j];
+        }
+        if (renorm_sum > 0.f) {
+          for (Tensor::Size j = 0; j < vocab_size; ++j) {
+            probs[j] /= renorm_sum;
+          }
+        }
+      }
+
+      // --- Step 5: Sample from the CDF ---
+      std::mt19937 rng(static_cast<std::uint32_t>(seed) +
+                       static_cast<std::uint32_t>(b) + offset);
+      std::uniform_real_distribution<float> dist(0.f, 1.f);
+      float u = dist(rng);
+
+      float cdf = 0.f;
+      Tensor::Size sampled = 0;
+      bool found = false;
+      for (Tensor::Size j = 0; j < vocab_size; ++j) {
+        cdf += probs[j];
+        if (cdf > u) {
+          sampled = j;
+          found = true;
+          break;
+        }
+      }
+
+      if (!found) {
+        // Fallback: pick the last non-zero-probability token.
+        for (Tensor::Size j = vocab_size; j > 0; --j) {
+          if (probs[j - 1] > 0.f) {
+            sampled = j - 1;
+            found = true;
+            break;
+          }
+        }
+      }
+
+      out_ptr[b] = static_cast<OutT>(sampled);
+      valid_ptr[b] = found;
+    };
+
+    if (deterministic) {
+      std::vector<float> probs(vocab_size);
+      for (Tensor::Size b = 0; b < batch_size_; ++b) {
+        sample_batch(b, probs);
+      }
+    } else {
+      std::vector<float> probs(vocab_size);
+#pragma omp parallel for firstprivate(probs)
+      for (Tensor::Size b = 0; b < batch_size_; ++b) {
+        sample_batch(b, probs);
+      }
+    }
+  }
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_random_sample.py b/tests/test_random_sample.py
new file mode 100644
index 00000000..9b8c6b47
--- /dev/null
+++ b/tests/test_random_sample.py
@@ -0,0 +1,222 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import empty_strided, randn_strided
+
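+
+# A pure-PyTorch reference for the op's top-p step, kept as an illustrative
+# sketch of the semantics (the tests below compare against argmax rather
+# than this helper; it assumes 1-D `probs` and 0 < p < 1):
+def _torch_topp_reference(probs, p):
+    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+    cumsum = torch.cumsum(sorted_probs, dim=-1)
+    # Keep the smallest prefix whose cumulative mass reaches p.
+    cutoff = int(torch.searchsorted(cumsum, torch.tensor(p)).item()) + 1
+    keep = sorted_idx[:cutoff]
+    filtered = torch.zeros_like(probs)
+    filtered[keep] = probs[keep]
+    return filtered / filtered.sum()
+
+
+# Only CPU implementation exists for now.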
+_CPU_ONLY = pytest.mark.parametrize("device", ("cpu",))
+
+
+# --- Helpers ---
+
+
+def _random_sample(
+    logits,
+    out,
+    valid,
+    temperature=None,
+    temperature_val=1.0,
+    top_k=None,
+    top_k_val=0,
+    top_p=None,
+    top_p_val=1.0,
+    min_p=None,
+    min_p_val=0.0,
+    seed=0,
+    offset=0,
+    deterministic=True,
+):
+    infini.ops.random_sample(
+        logits,
+        out,
+        valid,
+        temperature,
+        temperature_val,
+        top_k,
+        top_k_val,
+        top_p,
+        top_p_val,
+        min_p,
+        min_p_val,
+        seed,
+        offset,
+        deterministic,
+    )
+    return out, valid
+
+
+def _torch_argmax_sample(logits, dtype=torch.int64):
+    # Cast to the output dtype: torch.equal is False for mismatched dtypes.
+    return torch.argmax(logits, dim=-1).to(dtype)
+
+
+# --- Tests ---
+
+
+@pytest.mark.parametrize("batch_size, vocab_size", ((1, 16), (4, 128), (8, 256)))
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_greedy_topk1(batch_size, vocab_size, dtype, device):
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, top_k_val=1, seed=42)
+
+    expected = _torch_argmax_sample(logits, out.dtype)
+    assert torch.equal(out, expected), f"top_k=1 should give argmax, got {out}, expected {expected}"
+    assert valid.all(), "all samples should be valid"
+
+
+@pytest.mark.parametrize("batch_size, vocab_size", ((1, 16), (4, 64)))
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_reproducibility(batch_size, vocab_size, dtype, device):
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+
+    out1 = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid1 = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+    out2 = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid2 = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out1, valid1, seed=123, offset=0, deterministic=True)
+    _random_sample(logits, out2, valid2, seed=123, offset=0, deterministic=True)
+
+    assert torch.equal(out1, out2), "same seed should give same output"
+    assert torch.equal(valid1, valid2)
+
+
+@pytest.mark.parametrize("batch_size, vocab_size", ((2, 32), (4, 64)))
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_output_valid(batch_size, vocab_size, dtype, device):
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, seed=42)
+
+    assert valid.all(), "all samples should be valid for normal inputs"
+    assert (out >= 0).all() and (out < vocab_size).all(), "sampled indices out of range"
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_topp_filtering(dtype, device):
+    batch_size, vocab_size = 4, 16
+    logits = torch.full((batch_size, vocab_size), -10.0, dtype=dtype, device=device)
+    logits[:, 0] = 10.0
+
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, top_p_val=0.5, seed=42)
+
+    assert (out == 0).all(), "top_p=0.5 should always pick the dominant token"
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_minp_filtering(dtype, device):
+    batch_size, vocab_size = 4, 16
+    logits = torch.full((batch_size, vocab_size), -10.0, dtype=dtype, device=device)
+    logits[:, 3] = 10.0
+
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, min_p_val=0.5, seed=42)
+
+    assert (out == 3).all(), "min_p=0.5 should always pick the dominant token"
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_1d_logits(dtype, device):
+    vocab_size = 32
+    logits = randn_strided((vocab_size,), None, dtype=dtype, device=device)
+    out = empty_strided((1,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((1,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, top_k_val=1, seed=42)
+
+    expected = _torch_argmax_sample(logits.unsqueeze(0), out.dtype)
+    assert torch.equal(out, expected)
+    assert valid.all()
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_seed_offset_reproducibility(dtype, device):
+    """Same seed+offset reproduces; a different offset must differ."""
+    batch_size, vocab_size = 4, 256
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+
+    out1 = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid1 = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+    out2 = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid2 = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+    out3 = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid3 = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    # Same seed + offset → must be identical
+    _random_sample(logits, out1, valid1, seed=1, offset=0)
+    _random_sample(logits, out2, valid2, seed=1, offset=0)
+    assert torch.equal(out1, out2), "same seed+offset should reproduce"
+
+    # Different offset → must be different (different RNG state)
+    _random_sample(logits, out3, valid3, seed=1, offset=999999)
+    assert not torch.equal(out1, out3), "different offset should produce different results"
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_int64_output(dtype, device):
+    batch_size, vocab_size = 2, 32
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+    out = empty_strided((batch_size,), None, dtype=torch.int64, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    _random_sample(logits, out, valid, top_k_val=1, seed=42)
+
+    expected = _torch_argmax_sample(logits)
+    assert out.dtype == torch.int64
+    assert torch.equal(out, expected)
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_per_batch_tensor_params(dtype, device):
+    """Per-batch tensor parameters (int64 top_k, float32 temperature) should work."""
+    batch_size, vocab_size = 4, 32
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    # top_k as an int64 tensor: batch 0 uses top_k=1 (greedy); the others use
+    # top_k=0 (no filter).
+    top_k_tensor = torch.tensor([1, 0, 0, 0], dtype=torch.int64, device=device)
+
+    _random_sample(logits, out, valid, top_k=top_k_tensor, seed=42)
+
+    # Batch 0 must be argmax (top_k=1).
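+    # Batches 1-3 are unfiltered (top_k=0), so the test only pins down
+    # batch 0's token id and checks validity for the rest.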
+    assert out[0].item() == torch.argmax(logits[0]).item()
+    assert valid.all()
+
+
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@_CPU_ONLY
+def test_per_batch_temperature_tensor(dtype, device):
+    """Per-batch float32 temperature tensor should work."""
+    batch_size, vocab_size = 4, 32
+    logits = randn_strided((batch_size, vocab_size), None, dtype=dtype, device=device)
+    out = empty_strided((batch_size,), None, dtype=torch.int32, device=device)
+    valid = empty_strided((batch_size,), None, dtype=torch.uint8, device=device)
+
+    # Very low temperature → near-deterministic for all batches.
+    temp_tensor = torch.full((batch_size,), 0.01, dtype=torch.float32, device=device)
+
+    _random_sample(logits, out, valid, temperature=temp_tensor, seed=42)
+
+    expected = _torch_argmax_sample(logits, out.dtype)
+    assert torch.equal(out, expected), "near-zero temperature should give argmax"
+    assert valid.all()

From 740b7d3c929ec694c5d5069a90cad736c7803f4b Mon Sep 17 00:00:00 2001
From: Li Baoming <1508269885@qq.com>
Date: Tue, 21 Apr 2026 10:33:29 +0800
Subject: [PATCH 2/2] Code format

---
 src/base/random_sample.h              | 46 ++++++++++++---------
 src/cpu/random_sample/random_sample.h | 26 ++++++--------
 tests/test_random_sample.py           |  8 +++--
 3 files changed, 37 insertions(+), 43 deletions(-)

diff --git a/src/base/random_sample.h b/src/base/random_sample.h
index e9066436..c1b31ae2 100644
--- a/src/base/random_sample.h
+++ b/src/base/random_sample.h
@@ -35,9 +35,8 @@ class RandomSample : public Operator {
                std::optional<Tensor> temperature, float temperature_val,
                std::optional<Tensor> top_k, int top_k_val,
                std::optional<Tensor> top_p, float top_p_val,
-               std::optional<Tensor> min_p, float min_p_val,
-               std::uint64_t seed, std::uint64_t offset,
-               bool deterministic)
+               std::optional<Tensor> min_p, float min_p_val, std::uint64_t seed,
+               std::uint64_t offset, bool deterministic)
       : logits_dtype_{logits.dtype()},
         out_dtype_{out.dtype()},
         ndim_{logits.ndim()},
@@ -70,34 +69,29 @@ class RandomSample : public Operator {
   // Simplified constructor: no filtering, default temperature.
   RandomSample(const Tensor logits, Tensor out, Tensor valid,
                std::uint64_t seed, std::uint64_t offset)
-      : RandomSample{logits, out, valid,
-                     std::nullopt, 1.0f,
-                     std::nullopt, 0,
-                     std::nullopt, 1.0f,
-                     std::nullopt, 0.0f,
-                     seed, offset, false} {}
+      : RandomSample{logits, out, valid, std::nullopt,
+                     1.0f,   std::nullopt, 0,    std::nullopt,
+                     1.0f,   std::nullopt, 0.0f, seed,
+                     offset, false} {}
 
   virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
                           std::optional<Tensor> temperature,
-                          float temperature_val,
-                          std::optional<Tensor> top_k, int top_k_val,
-                          std::optional<Tensor> top_p, float top_p_val,
-                          std::optional<Tensor> min_p, float min_p_val,
-                          std::uint64_t seed, std::uint64_t offset,
-                          bool deterministic) const = 0;
+                          float temperature_val, std::optional<Tensor> top_k,
+                          int top_k_val, std::optional<Tensor> top_p,
+                          float top_p_val, std::optional<Tensor> min_p,
+                          float min_p_val, std::uint64_t seed,
+                          std::uint64_t offset, bool deterministic) const = 0;
 
   virtual void operator()(const Tensor logits, Tensor out, Tensor valid,
                           std::uint64_t seed, std::uint64_t offset) const {
-    return operator()(logits, out, valid,
-                      temperature_, temperature_val_,
-                      top_k_, top_k_val_,
-                      top_p_, top_p_val_,
-                      min_p_, min_p_val_,
-                      seed, offset, deterministic_);
+    return operator()(logits, out, valid, temperature_, temperature_val_,
+                      top_k_, top_k_val_, top_p_, top_p_val_, min_p_,
+                      min_p_val_, seed, offset, deterministic_);
   }
 
  protected:
-  static void ValidateIntParam(std::optional<Tensor> t, Tensor::Size batch_size) {
+  static void ValidateIntParam(std::optional<Tensor> t,
+                               Tensor::Size batch_size) {
     if (!t.has_value()) return;
     const auto& tensor = *t;
     assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
@@ -107,7 +101,8 @@ class RandomSample : public Operator {
            "per-batch int param must be int32 or int64");
   }
 
-  static void ValidateFloatParam(std::optional<Tensor> t, Tensor::Size batch_size) {
+  static void ValidateFloatParam(std::optional<Tensor> t,
+                                 Tensor::Size batch_size) {
     if (!t.has_value()) return;
     const auto& tensor = *t;
     assert(tensor.ndim() == 1 && tensor.size(0) == batch_size &&
@@ -119,8 +114,9 @@ class RandomSample : public Operator {
            "per-batch float param must be float16/bfloat16/float32/float64");
   }
 
-  void ValidateParams(std::optional<Tensor> temperature, std::optional<Tensor> top_k,
-                      std::optional<Tensor> top_p, std::optional<Tensor> min_p) const {
+  void ValidateParams(std::optional<Tensor> temperature,
+                      std::optional<Tensor> top_k, std::optional<Tensor> top_p,
+                      std::optional<Tensor> min_p) const {
     ValidateFloatParam(temperature, batch_size_);
     ValidateIntParam(top_k, batch_size_);
     ValidateFloatParam(top_p, batch_size_);
diff --git a/src/cpu/random_sample/random_sample.h b/src/cpu/random_sample/random_sample.h
index 4b49c3db..8311ce2c 100644
--- a/src/cpu/random_sample/random_sample.h
+++ b/src/cpu/random_sample/random_sample.h
@@ -20,9 +20,8 @@ namespace infini::ops {
 
 template <>
-class Operator<kCpu, RandomSample>
-    : public RandomSample,
-      Caster<kCpu> {
+class Operator<kCpu, RandomSample> : public RandomSample,
+                                     Caster<kCpu> {
  public:
   using RandomSample::RandomSample;
 
@@ -33,8 +32,7 @@ class Operator<kCpu, RandomSample> : public RandomSample,
                   std::optional<Tensor> min_p, float min_p_val,
                   std::uint64_t seed, std::uint64_t offset,
                   bool deterministic) const override {
-    DispatchFunc<List<float, double, Float16, BFloat16>,
-                 List<std::int32_t, std::int64_t>>(
+    DispatchFunc<List<float, double, Float16, BFloat16>, List<std::int32_t, std::int64_t>>(
         {logits_dtype_, out_dtype_},
         [&](auto tag1, auto tag2) {
@@ -49,12 +47,9 @@ class Operator<kCpu, RandomSample> : public RandomSample,
 
   void operator()(const Tensor logits, Tensor out, Tensor valid,
                   std::uint64_t seed, std::uint64_t offset) const override {
-    return operator()(logits, out, valid,
-                      std::nullopt, temperature_val_,
-                      std::nullopt, top_k_val_,
-                      std::nullopt, top_p_val_,
-                      std::nullopt, min_p_val_,
-                      seed, offset, deterministic_);
+    return operator()(logits, out, valid, std::nullopt, temperature_val_,
+                      std::nullopt, top_k_val_, std::nullopt, top_p_val_,
+                      std::nullopt, min_p_val_, seed, offset, deterministic_);
   }
 
  private:
@@ -112,9 +107,8 @@ class Operator<kCpu, RandomSample> : public RandomSample,
               std::optional<Tensor> temperature, float temperature_val,
               std::optional<Tensor> top_k, int top_k_val,
               std::optional<Tensor> top_p, float top_p_val,
-              std::optional<Tensor> min_p, float min_p_val,
-              std::uint64_t seed, std::uint64_t offset,
-              bool deterministic) const {
+              std::optional<Tensor> min_p, float min_p_val, std::uint64_t seed,
+              std::uint64_t offset, bool deterministic) const {
     assert(valid.dtype() == DataType::kUInt8 &&
            "`RandomSample` requires uint8 valid tensor");
 
@@ -143,8 +137,8 @@ class Operator<kCpu, RandomSample> : public RandomSample,
 
       float sum = 0.f;
       for (Tensor::Size j = 0; j < vocab_size; ++j) {
-        float v = std::exp(
-            Cast<float>(logits_row[j * col_stride]) * inv_temp - max_val);
+        float v = std::exp(Cast<float>(logits_row[j * col_stride]) * inv_temp -
+                           max_val);
         probs[j] = v;
         sum += v;
       }
diff --git a/tests/test_random_sample.py b/tests/test_random_sample.py
index 9b8c6b47..655eef16 100644
--- a/tests/test_random_sample.py
+++ b/tests/test_random_sample.py
@@ -64,7 +64,9 @@ def test_greedy_topk1(batch_size, vocab_size, dtype, device):
     _random_sample(logits, out, valid, top_k_val=1, seed=42)
 
     expected = _torch_argmax_sample(logits, out.dtype)
-    assert torch.equal(out, expected), f"top_k=1 should give argmax, got {out}, expected {expected}"
+    assert torch.equal(out, expected), (
+        f"top_k=1 should give argmax, got {out}, expected {expected}"
+    )
     assert valid.all(), "all samples should be valid"
 
 
@@ -166,7 +168,9 @@ def test_seed_offset_reproducibility(dtype, device):
 
     # Different offset → must be different (different RNG state)
    _random_sample(logits, out3, valid3, seed=1, offset=999999)
-    assert not torch.equal(out1, out3), "different offset should produce different results"
+    assert not torch.equal(out1, out3), (
+        "different offset should produce different results"
+    )
 
 
 @pytest.mark.parametrize("dtype", (torch.float32,))