Skip to content

feat: Add cpu random sample op#68

Open
baominghelly wants to merge 2 commits intomasterfrom
feat/dev-random-sample-cpu
Open

feat: Add cpu random sample op#68
baominghelly wants to merge 2 commits intomasterfrom
feat/dev-random-sample-cpu

Conversation

@baominghelly
Copy link
Copy Markdown

总结

添加 RandomSample 算子的基类接口和 CPU 后端实现,用于 LLM 推理中的 token 采样。

改动

  • src/base/random_sample.h — 算子基类,定义采样参数接口(temperature、top_k、top_p、min_p),支持 per-batch 标量/张量双模式,包含参数校验
  • src/cpu/random_sample/random_sample.h — CPU 后端实现,支持 stride-based 非连续张量,多 dtype 分发(float16/32/64/bfloat16 logits + int32/int64 输出),确定性/非确定性执行模式
  • tests/test_random_sample.py — 14 个测试用例,覆盖 greedy decoding、可复现性、top_k/top_p/min_p 过滤、1D logits、per-batch tensor 参数、int64 输出等场景

测试计划

pytest tests/test_random_sample.py -v — CPU 后端 14/14 通过

pip install .[dev] && CUDA_VISIBLE_DEVICES=6 pytest tests/test_random_sample.py -v
Looking in indexes: https://mirrors.aliyun.com/pypi/simple/
Processing ./.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: pytest in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (9.0.2)
Requirement already satisfied: pytest-cov in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (7.1.0)
Requirement already satisfied: pytest-xdist in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (3.8.0)
Requirement already satisfied: ruff in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (0.15.9)
Requirement already satisfied: torch in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (2.10.0+cu128)
Requirement already satisfied: pyyaml in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from InfiniOps==0.1.0) (6.0.3)
Requirement already satisfied: exceptiongroup>=1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (1.3.1)
Requirement already satisfied: iniconfig>=1.0.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (2.3.0)
Requirement already satisfied: packaging>=22 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (26.0)
Requirement already satisfied: pluggy<2,>=1.5 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (1.6.0)
Requirement already satisfied: pygments>=2.7.2 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (2.20.0)
Requirement already satisfied: tomli>=1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest->InfiniOps==0.1.0) (2.4.1)
Requirement already satisfied: typing-extensions>=4.6.0 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from exceptiongroup>=1->pytest->InfiniOps==0.1.0) (4.15.0)
Requirement already satisfied: coverage>=7.10.6 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from coverage[toml]>=7.10.6->pytest-cov->InfiniOps==0.1.0) (7.13.5)
Requirement already satisfied: execnet>=2.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from pytest-xdist->InfiniOps==0.1.0) (2.1.2)
Requirement already satisfied: filelock in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (3.25.2)
Requirement already satisfied: sympy>=1.13.3 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (3.4.2)
Requirement already satisfied: jinja2 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (3.1.6)
Requirement already satisfied: fsspec>=0.8.5 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (2026.3.0)
Requirement already satisfied: cuda-bindings==12.9.4 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.9.4)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.93)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.90)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.90)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (11.3.3.83)
Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (10.3.9.90)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (11.7.3.90)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.5.8.93)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.4.5 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (3.4.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.90)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (12.8.93)
Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (1.13.1.3)
Requirement already satisfied: triton==3.6.0 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from torch->InfiniOps==0.1.0) (3.6.0)
Requirement already satisfied: cuda-pathfinder~=1.1 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from cuda-bindings==12.9.4->torch->InfiniOps==0.1.0) (1.5.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from sympy>=1.13.3->torch->InfiniOps==0.1.0) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/baoming/miniconda3/envs/infiniops/lib/python3.10/site-packages (from jinja2->torch->InfiniOps==0.1.0) (3.0.3)
Building wheels for collected packages: InfiniOps
  Building wheel for InfiniOps (pyproject.toml) ... done
  Created wheel for InfiniOps: filename=infiniops-0.1.0-cp310-cp310-linux_x86_64.whl size=457424 sha256=a3a8d94d1ca13f53c869a48058534d7f2b83d2250ee6afbd35d89e4333408e75
  Stored in directory: /tmp/pip-ephem-wheel-cache-0agv1omj/wheels/47/e9/84/43c5ddd5917ea3624f6ca758a125e9b600868be11c3e6071e9
Successfully built InfiniOps
Installing collected packages: InfiniOps
  Attempting uninstall: InfiniOps
    Found existing installation: InfiniOps 0.1.0
    Uninstalling InfiniOps-0.1.0:
      Successfully uninstalled InfiniOps-0.1.0
Successfully installed InfiniOps-0.1.0
========================================================================================================================================== test session starts ==========================================================================================================================================
platform linux -- Python 3.10.20, pytest-9.0.2, pluggy-1.6.0 -- /home/baoming/miniconda3/envs/infiniops/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/baoming/workplace/ops/InfiniOps
configfile: pyproject.toml
plugins: cov-7.1.0, anyio-4.13.0, xdist-3.8.0
collected 14 items                                                                                                                                                                                                                                                                                      

tests/test_random_sample.py::test_greedy_topk1[cpu-dtype0-1-16] PASSED                                                                                                                                                                                                                            [  7%]
tests/test_random_sample.py::test_greedy_topk1[cpu-dtype0-4-128] PASSED                                                                                                                                                                                                                           [ 14%]
tests/test_random_sample.py::test_greedy_topk1[cpu-dtype0-8-256] PASSED                                                                                                                                                                                                                           [ 21%]
tests/test_random_sample.py::test_reproducibility[cpu-dtype0-1-16] PASSED                                                                                                                                                                                                                         [ 28%]
tests/test_random_sample.py::test_reproducibility[cpu-dtype0-4-64] PASSED                                                                                                                                                                                                                         [ 35%]
tests/test_random_sample.py::test_output_valid[cpu-dtype0-2-32] PASSED                                                                                                                                                                                                                            [ 42%]
tests/test_random_sample.py::test_output_valid[cpu-dtype0-4-64] PASSED                                                                                                                                                                                                                            [ 50%]
tests/test_random_sample.py::test_topp_filtering[cpu-dtype0] PASSED                                                                                                                                                                                                                               [ 57%]
tests/test_random_sample.py::test_minp_filtering[cpu-dtype0] PASSED                                                                                                                                                                                                                               [ 64%]
tests/test_random_sample.py::test_1d_logits[cpu-dtype0] PASSED                                                                                                                                                                                                                                    [ 71%]
tests/test_random_sample.py::test_seed_offset_reproducibility[cpu-dtype0] PASSED                                                                                                                                                                                                                  [ 78%]
tests/test_random_sample.py::test_int64_output[cpu-dtype0] PASSED                                                                                                                                                                                                                                 [ 85%]
tests/test_random_sample.py::test_per_batch_tensor_params[cpu-dtype0] PASSED                                                                                                                                                                                                                      [ 92%]
tests/test_random_sample.py::test_per_batch_temperature_tensor[cpu-dtype0] PASSED                                                                                                                                                                                                                 [100%]

========================================================================================================================================== 14 passed in 0.17s ===========================================================================================================================================

@baominghelly baominghelly self-assigned this Apr 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant