diff --git a/.github/workflows/build_linux_wheel/action.yml b/.github/workflows/build_linux_wheel/action.yml index 1d62d1ae1c5..1e70c632035 100644 --- a/.github/workflows/build_linux_wheel/action.yml +++ b/.github/workflows/build_linux_wheel/action.yml @@ -69,12 +69,7 @@ runs: args: ${{ inputs.args }} before-script-linux: | set -e - apt install -y unzip - if [ $(uname -m) = "x86_64" ]; then - PROTOC_ARCH="x86_64" - else - PROTOC_ARCH="aarch_64" - fi - curl -L https://github.com/protocolbuffers/protobuf/releases/download/v24.4/protoc-24.4-linux-$PROTOC_ARCH.zip > /tmp/protoc.zip \ + yum install -y openssl-devel clang \ + && curl -L https://github.com/protocolbuffers/protobuf/releases/download/v24.4/protoc-24.4-linux-aarch_64.zip > /tmp/protoc.zip \ && unzip /tmp/protoc.zip -d /usr/local \ && rm /tmp/protoc.zip diff --git a/.github/workflows/bump-version/action.yml b/.github/workflows/bump-version/action.yml index ff2de1fcdc3..8d95117eeeb 100644 --- a/.github/workflows/bump-version/action.yml +++ b/.github/workflows/bump-version/action.yml @@ -24,19 +24,28 @@ runs: run: | cargo install cargo-workspaces --version 0.2.44 cargo ws version --no-git-commit -y --exact --force 'lance*' ${{ inputs.part }} + - name: Update python lockfile + working-directory: python + shell: bash + run: | + cargo update -p lance - name: Bump java version working-directory: java shell: bash run: | # Get current version current_version=$(mvn help:evaluate -Dexpression=project.version -q -DforceStdout) - current_version=${current_version%\%} + current_version=${current_version%%} + + base_version="${current_version%-*}" + if [[ "$current_version" == *-* ]]; then + suffix="-${current_version#*-}" + else + suffix="" + fi # Split the version into components using parameter expansion - major=${current_version%%.*} - minor=${current_version#*.} - minor=${minor%%.*} - patch=${current_version##*.} + IFS=. read major minor patch <<<"$base_version" case "${{ inputs.part }}" in patch) @@ -57,6 +66,6 @@ runs: ;; esac - new_version="${major}.${minor}.${patch}" + new_version="${major}.${minor}.${patch}${suffix}" mvn versions:set versions:commit -DnewVersion=$new_version diff --git a/.github/workflows/cargo-publish.yml b/.github/workflows/cargo-publish.yml index 8172b1e5845..1fed76e22ec 100644 --- a/.github/workflows/cargo-publish.yml +++ b/.github/workflows/cargo-publish.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: inputs: tag: - description: 'Tag to publish (e.g., v1.0.0)' + description: "Tag to publish (e.g., v1.0.0)" required: true type: string @@ -19,12 +19,13 @@ env: jobs: build: - runs-on: ubuntu-24.04 + # Needs additional disk space for the full build. + runs-on: ubuntu-2404-4x-x64 timeout-minutes: 60 env: # Need up-to-date compilers for kernels CC: clang-18 - CXX: clang-18 + CXX: clang++-18 defaults: run: working-directory: . @@ -53,5 +54,5 @@ jobs: - uses: albertlockett/publish-crates@v2.2 with: registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }} - args: '--all-features' + args: "--all-features" path: . diff --git a/.github/workflows/ci-benchmarks.yml b/.github/workflows/ci-benchmarks.yml index 90fc72af07c..bf6c4ee59ff 100644 --- a/.github/workflows/ci-benchmarks.yml +++ b/.github/workflows/ci-benchmarks.yml @@ -1,6 +1,7 @@ name: Run Regression Benchmarks on: + workflow_dispatch: push: branches: - main @@ -12,7 +13,7 @@ jobs: env: # Need up-to-date compilers for kernels CC: clang-18 - CXX: clang-18 + CXX: clang++-18 defaults: run: shell: bash diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 331a30ffd6f..8cd8fa8af00 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -6,12 +6,11 @@ on: pull_request: paths: - docs/** + - python/python/** - .github/workflows/docs-check.yml env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" + RUSTFLAGS: "-C debuginfo=0" # according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html # CI builds are faster with incremental disabled. CARGO_INCREMENTAL: "0" @@ -26,7 +25,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" cache: 'pip' cache-dependency-path: "docs/requirements.txt" - name: Install dependencies @@ -34,11 +33,21 @@ jobs: sudo apt install -y -qq doxygen pandoc - name: Build python wheel uses: ./.github/workflows/build_linux_wheel - - name: Build Python + - name: Free disk space working-directory: python run: | - python -m pip install $(ls target/wheels/*.whl) - python -m pip install -r ../docs/requirements.txt + sudo chown 1001:118 -R target + mv target/wheels/*.whl ./ + cargo clean + - name: Build Python + working-directory: docs + run: | + python -m pip install ../python/*.whl + python -m pip install -r requirements.txt + - name: Run test + working-directory: docs + run: | + make doctest - name: Build docs working-directory: docs run: | diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index f5f40b80ac8..4e22458bc21 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -19,9 +19,7 @@ concurrency: cancel-in-progress: true env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" + RUSTFLAGS: "-C debuginfo=0" # according to: https://matklad.github.io/2021/09/04/fast-rust-builds.html # CI builds are faster with incremental disabled. CARGO_INCREMENTAL: "0" @@ -47,10 +45,16 @@ jobs: sudo apt install -y -qq doxygen pandoc - name: Build python wheel uses: ./.github/workflows/build_linux_wheel + - name: Free disk space + working-directory: python + run: | + sudo chown 1001:118 -R target + mv target/wheels/*.whl ./ + cargo clean - name: Build Python working-directory: python run: | - python -m pip install $(ls target/wheels/*.whl) + python -m pip install ../python/*.whl python -m pip install -r ../docs/requirements.txt - name: Build docs working-directory: docs diff --git a/.github/workflows/duckdb.yml b/.github/workflows/duckdb.yml deleted file mode 100644 index a215ae84a63..00000000000 --- a/.github/workflows/duckdb.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: DuckDB Extension -on: - push: - branches: - - main - pull_request: - paths: - - integration/duckdb_lance/* - - .github/workflows/duckdb.yml - - ./rust/* - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - Linux: - runs-on: ubuntu-22.04 - timeout-minutes: 45 - defaults: - run: - working-directory: ./integration/duckdb_lance - steps: - - uses: actions/checkout@v4 - - name: Install dependencies - run: | - sudo apt update - sudo apt install -y protobuf-compiler libssl-dev - - name: Checkout submodules - run: | - git submodule init - git submodule update - - name: Make - run: make build - # - name: Upload Lance duckdb extension - # uses: actions/upload-artifact@v3 - # with: - # name: duckdb-ubuntu-extension - # path: integration/duckdb/build/lance.duckdb_extension - # retention-days: 1 - MacOS: - runs-on: macos-14 - timeout-minutes: 40 - defaults: - run: - working-directory: ./integration/duckdb_lance - steps: - - uses: actions/checkout@v4 - - name: Install dependencies - run: | - brew install protobuf - - name: Checkout submodules - run: | - git submodule init - git submodule update - - name: Build - run: make build - # - name: Upload Lance duckdb extension - # uses: actions/upload-artifact@v3 - # with: - # name: duckdb-intel-mac-extension - # path: integration/duckdb/build/lance.duckdb_extension - # retention-days: 1 - diff --git a/.github/workflows/file_verification/test_write_read.py b/.github/workflows/file_verification/test_write_read.py index 03c53da7274..4bdfa354a51 100644 --- a/.github/workflows/file_verification/test_write_read.py +++ b/.github/workflows/file_verification/test_write_read.py @@ -48,5 +48,5 @@ assert tab_lance == parquet_table print(f"Table read from Lance is the same as table read from Parquet for file: {file_path}") - except Exception as e: + except Exception: raise AssertionError(f"Table read from Lance is not the same as table read from Parquet for file: {file_path}") \ No newline at end of file diff --git a/.github/workflows/java-publish.yml b/.github/workflows/java-publish.yml index 47a4213e13e..b65118a8abc 100644 --- a/.github/workflows/java-publish.yml +++ b/.github/workflows/java-publish.yml @@ -12,7 +12,7 @@ jobs: macos-arm64: name: Build on MacOS Arm64 runs-on: macos-14 - timeout-minutes: 30 + timeout-minutes: 60 defaults: run: working-directory: ./java @@ -36,28 +36,68 @@ jobs: name: Build on Linux Arm64 runs-on: ubuntu-2404-8x-arm64 timeout-minutes: 60 - defaults: - run: - working-directory: ./java steps: - name: Checkout repository uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - toolchain: "stable" - cache-workspaces: "src/rust" - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - rustflags: "-C debuginfo=1" - - name: Install dependencies - run: | - sudo apt -y -qq update - sudo apt install -y protobuf-compiler libssl-dev pkg-config - - name: Build release + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Check glibc version outside docker + run: ldd --version + - name: Build and run in Ubuntu 20.04 container run: | - cargo build --release - cp ../target/release/liblance_jni.so liblance_jni.so + docker run --platform linux/arm64 -v ${{ github.workspace }}:/workspace -w /workspace debian:10 bash -c " + + set -ex + apt-get update + + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends --assume-yes \ + apt-transport-https \ + ca-certificates \ + curl \ + gpg \ + bash \ + less \ + openssl \ + libssl-dev \ + pkg-config \ + libsqlite3-dev \ + libsqlite3-0 \ + libreadline-dev \ + git \ + cmake \ + dh-autoreconf \ + clang \ + g++ \ + libc++-dev \ + libc++abi-dev \ + libprotobuf-dev \ + libncurses5-dev \ + libncursesw5-dev \ + libudev-dev \ + libhidapi-dev \ + zip \ + unzip + + # https://github.com/databendlabs/databend/issues/8035 + PROTOC_ZIP=protoc-3.15.0-linux-aarch_64.zip + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.15.0/\$PROTOC_ZIP + unzip -o \$PROTOC_ZIP -d /usr/local + rm -f \$PROTOC_ZIP + protoc --version + + curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable + source \$HOME/.cargo/env + + cd java + + # https://github.com/rustls/rustls/issues/1967 + export CC=clang + export CXX=clang++ + ldd --version + + cargo build --release + cp ../target/release/liblance_jni.so liblance_jni.so + " - uses: actions/upload-artifact@v4 with: name: liblance_jni_linux_aarch64.zip @@ -66,7 +106,7 @@ jobs: if-no-files-found: error linux-x86: runs-on: ubuntu-24.04 - timeout-minutes: 45 + timeout-minutes: 60 needs: [macos-arm64, linux-arm64] defaults: run: @@ -97,21 +137,142 @@ jobs: mkdir -p ./core/target/classes/nativelib/darwin-aarch64 ./core/target/classes/nativelib/linux-aarch64 cp ../liblance_jni_darwin_aarch64.zip/liblance_jni.dylib ./core/target/classes/nativelib/darwin-aarch64/liblance_jni.dylib cp ../liblance_jni_linux_aarch64.zip/liblance_jni.so ./core/target/classes/nativelib/linux-aarch64/liblance_jni.so - - name: Set github - run: | - git config --global user.email "Lance Github Runner" - git config --global user.name "dev+gha@lancedb.com" - - name: Dry run + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Check glibc version outside docker + run: ldd --version + - name: Build and run in Ubuntu 20.04 container (Dry Run) if: github.event_name == 'pull_request' run: | - mvn --batch-mode -DskipTests -Drust.release.build=true package - - name: Publish with Java 8 + docker run --platform linux/amd64 -v ${{ github.workspace }}:/workspace -w /workspace openjdk:8-jdk-buster bash -c " + set -ex + apt-get update + + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends --assume-yes \ + apt-transport-https \ + ca-certificates \ + curl \ + gpg \ + bash \ + less \ + openssl \ + libssl-dev \ + pkg-config \ + libsqlite3-dev \ + libsqlite3-0 \ + libreadline-dev \ + git \ + cmake \ + dh-autoreconf \ + clang \ + g++ \ + libc++-dev \ + libc++abi-dev \ + libprotobuf-dev \ + libncurses5-dev \ + libncursesw5-dev \ + libudev-dev \ + libhidapi-dev \ + zip \ + unzip \ + + # manually install maven, apt will use java11 + MAVEN_VERSION=3.9.6 + curl -OL https://dlcdn.apache.org/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz + tar -xzf apache-maven-3.9.9-bin.tar.gz + mv apache-maven-3.9.9 /opt/maven + ln -s /opt/maven/bin/mvn /usr/bin/mvn + + # https://github.com/databendlabs/databend/issues/8035 + PROTOC_ZIP=protoc-3.15.0-linux-x86_64.zip + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.15.0/\$PROTOC_ZIP + unzip -o \$PROTOC_ZIP -d /usr/local + rm -f \$PROTOC_ZIP + protoc --version + + # set Github + git config --global user.email \"Lance Github Runner\" + git config --global user.name \"dev+gha@lancedb.com\" + + curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable + source \$HOME/.cargo/env + + cd java + + # https://github.com/rustls/rustls/issues/1967 + export CC=clang + export CXX=clang++ + ldd --version + + mvn --batch-mode -DskipTests -Drust.release.build=true package + " + - name: Build and run in Ubuntu 20.04 container (Publish to Sonatype) if: github.event_name == 'release' run: | - echo "use-agent" >> ~/.gnupg/gpg.conf - echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf - export GPG_TTY=$(tty) - mvn --batch-mode -DskipTests -Drust.release.build=true -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh -P shade-jar - env: - SONATYPE_USER: ${{ secrets.SONATYPE_USER }} - SONATYPE_TOKEN: ${{ secrets.SONATYPE_TOKEN }} + docker run --platform linux/amd64 -v ${{ github.workspace }}:/workspace -w /workspace openjdk:8-jdk-buster bash -c " + set -ex + apt-get update + + DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends --assume-yes \ + apt-transport-https \ + ca-certificates \ + curl \ + gpg \ + bash \ + less \ + openssl \ + libssl-dev \ + pkg-config \ + libsqlite3-dev \ + libsqlite3-0 \ + libreadline-dev \ + git \ + cmake \ + dh-autoreconf \ + clang \ + g++ \ + libc++-dev \ + libc++abi-dev \ + libprotobuf-dev \ + libncurses5-dev \ + libncursesw5-dev \ + libudev-dev \ + libhidapi-dev \ + zip \ + unzip + + # manually install maven, apt will use java11 + MAVEN_VERSION=3.9.6 + curl -OL https://dlcdn.apache.org/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz + tar -xzf apache-maven-3.9.9-bin.tar.gz + mv apache-maven-3.9.9 /opt/maven + ln -s /opt/maven/bin/mvn /usr/bin/mvn + + # https://github.com/databendlabs/databend/issues/8035 + PROTOC_ZIP=protoc-3.15.0-linux-x86_64.zip + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v3.15.0/\$PROTOC_ZIP + unzip -o \$PROTOC_ZIP -d /usr/local + rm -f \$PROTOC_ZIP + protoc --version + + # set Github + git config --global user.email \"Lance Github Runner\" + git config --global user.name \"dev+gha@lancedb.com\" + + curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable + source \$HOME/.cargo/env + + cd java + + # https://github.com/rustls/rustls/issues/1967 + export CC=clang + export CXX=clang++ + ldd --version + + export SONATYPE_USER=${{ secrets.SONATYPE_USER }} + export SONATYPE_TOKEN=${{ secrets.SONATYPE_TOKEN }} + echo "use-agent" >> ~/.gnupg/gpg.conf + echo "pinentry-mode loopback" >> ~/.gnupg/gpg.conf + export GPG_TTY=$(tty) + mvn --batch-mode -DskipTests -Drust.release.build=true -DpushChanges=false -Dgpg.passphrase=${{ secrets.GPG_PASSPHRASE }} deploy -P deploy-to-ossrh -P shade-jar + " diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index aac412df337..0be50a24cb1 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -25,52 +25,91 @@ env: jobs: rust-clippy-fmt: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 name: Rust Clippy and Fmt Check - defaults: - run: - working-directory: ./java/core/lance-jni steps: - name: Checkout repository uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 with: - workspaces: java/core/lance-jni + workspaces: | + lance + java/core/lance-jni -> ../target/rust-maven-plugin/lance-jni - name: Install dependencies run: | sudo apt update sudo apt install -y protobuf-compiler libssl-dev + # pin the toolchain version to avoid surprises + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + - uses: rui314/setup-mold@v1 + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov - name: Run cargo fmt + working-directory: java/core/lance-jni run: cargo fmt --check - name: Rust Clippy + working-directory: java/core/lance-jni run: cargo clippy --all-targets -- -D warnings build-and-test-java: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + timeout-minutes: 60 strategy: matrix: java-version: [8, 11, 17] name: Build and Test with Java ${{ matrix.java-version }} - defaults: - run: - working-directory: ./java steps: - - name: Checkout repository - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - workspaces: java/core/lance-jni - name: Install dependencies run: | sudo apt update sudo apt install -y protobuf-compiler libssl-dev + # pin the toolchain version to avoid surprises + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + - uses: rui314/setup-mold@v1 + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - name: Checkout repository + uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + with: + workspaces: java/core/lance-jni -> ../target/rust-maven-plugin/lance-jni - name: Set up Java ${{ matrix.java-version }} uses: actions/setup-java@v4 with: distribution: temurin java-version: ${{ matrix.java-version }} cache: "maven" + - name: Running code style check with Java ${{ matrix.java-version }} + working-directory: java + run: | + if [ "${{ matrix.java-version }}" == "17" ]; then + export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ + -XX:+IgnoreUnrecognizedVMOptions \ + --add-opens=java.base/java.lang=ALL-UNNAMED \ + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ + --add-opens=java.base/java.io=ALL-UNNAMED \ + --add-opens=java.base/java.net=ALL-UNNAMED \ + --add-opens=java.base/java.nio=ALL-UNNAMED \ + --add-opens=java.base/java.util=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \ + --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \ + --add-opens=java.base/sun.security.action=ALL-UNNAMED \ + --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \ + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \ + -Djdk.reflect.useDirectMethodHandle=false \ + -Dio.netty.tryReflectionSetAccessible=true" + fi + mvn spotless:check - name: Running tests with Java ${{ matrix.java-version }} + working-directory: java run: | if [ "${{ matrix.java-version }}" == "17" ]; then export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ @@ -93,4 +132,4 @@ jobs: -Djdk.reflect.useDirectMethodHandle=false \ -Dio.netty.tryReflectionSetAccessible=true" fi - mvn clean install + mvn install diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 7e08b928388..7b8979cf144 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -15,15 +15,20 @@ jobs: - platform: x86_64 manylinux: "2_17" extra_args: "" + runner: ubuntu-22.04 - platform: x86_64 manylinux: "2_28" extra_args: "--features fp16kernels" + runner: ubuntu-22.04 - platform: aarch64 - manylinux: "2_24" + manylinux: "2_17" extra_args: "" - # We don't build fp16 kernels for aarch64, because it uses - # cross compilation image, which doesn't have a new enough compiler. - runs-on: "ubuntu-22.04" + runner: ubuntu-2404-4x-arm64 + - platform: aarch64 + manylinux: "2_28" + extra_args: "--features fp16kernels" + runner: ubuntu-2404-4x-arm64 + runs-on: ${{ matrix.config.runner }} steps: - uses: actions/checkout@v4 with: diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index fb677eab8e8..39475191bf7 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -27,6 +27,9 @@ env: RUSTFLAGS: "-C debuginfo=1" RUST_BACKTRACE: "1" CI: "true" + # Color output for pytest is off by default. + PYTEST_ADDOPTS: "--color=yes" + FORCE_COLOR: "1" jobs: lint: @@ -39,7 +42,7 @@ jobs: env: # Need up-to-date compilers for kernels CC: clang-18 - CXX: clang-18 + CXX: clang++-18 steps: - uses: actions/checkout@v4 with: @@ -54,20 +57,22 @@ jobs: workspaces: python - name: Install linting tools run: | - pip install ruff==0.4.1 maturin tensorflow tqdm ray[data] + pip install ruff==0.11.2 maturin tensorflow tqdm ray[data] pyright datasets polars[pyarrow,pandas] pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Lint Python run: | ruff format --check python ruff check python + pyright - name: Install dependencies run: | sudo apt update sudo apt install -y protobuf-compiler libssl-dev - name: Lint Rust run: | + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` cargo fmt --all -- --check - cargo clippy --locked --all-features --tests -- -D warnings + cargo clippy --locked --features ${ALL_FEATURES} --tests -- -D warnings - name: Build run: | python -m venv venv @@ -143,7 +148,7 @@ jobs: - uses: ./.github/workflows/build_linux_wheel with: arm-build: "true" - manylinux: "2_24" + manylinux: "2_28" - name: Install dependencies run: | sudo apt update -y -qq @@ -201,22 +206,6 @@ jobs: run: shell: bash working-directory: python - services: - minio: - image: lazybit/minio - ports: - - 9000:9000 - env: - MINIO_ACCESS_KEY: ACCESSKEY - MINIO_SECRET_KEY: SECRETKEY - options: --name=minio --health-cmd "curl http://localhost:9000/minio/health/live" - dynamodb-local: - image: amazon/dynamodb-local - ports: - - 8000:8000 - env: - AWS_ACCESS_KEY_ID: ACCESSKEY - AWS_SECRET_ACCESS_KEY: SECRETKEY steps: - uses: actions/checkout@v4 with: diff --git a/.github/workflows/run_integtests/action.yml b/.github/workflows/run_integtests/action.yml index 0c1d7e9dfb9..38115e49fea 100644 --- a/.github/workflows/run_integtests/action.yml +++ b/.github/workflows/run_integtests/action.yml @@ -9,6 +9,10 @@ runs: shell: bash run: | pip3 install $(ls target/wheels/pylance-*.whl)[tests,ray] + - name: Start localstack + shell: bash + run: | + docker compose -f docker-compose.yml up -d --wait - name: Run python tests shell: bash working-directory: python diff --git a/.github/workflows/rust-benchmark.yml b/.github/workflows/rust-benchmark.yml index dfcc14b7535..55e835dd693 100644 --- a/.github/workflows/rust-benchmark.yml +++ b/.github/workflows/rust-benchmark.yml @@ -29,7 +29,7 @@ env: jobs: Benchmark: - runs-on: [self-hosted, linux, x64] + runs-on: warp-ubuntu-latest-arm64-8x timeout-minutes: 120 steps: - name: Checkout diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 31b02b866a4..79b91b116f7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,6 +36,8 @@ jobs: - name: Check formatting run: cargo fmt -- --check clippy: + permissions: + checks: write runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -44,13 +46,27 @@ jobs: run: | sudo apt update sudo apt install -y protobuf-compiler libssl-dev - - name: Run clippy + - name: Get features run: | - cargo clippy --version - cargo clippy --locked --all-features --tests --benches -- -D warnings + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + echo "ALL_FEATURES=${ALL_FEATURES}" >> $GITHUB_ENV + - uses: auguwu/clippy-action@1.4.0 + with: + check-args: --locked --features ${{ env.ALL_FEATURES }} --tests --benches + token: ${{secrets.GITHUB_TOKEN}} + deny: warnings + cargo-deny: + name: Check Rust dependencies (cargo-deny) + runs-on: ubuntu-24.04 + steps: + - uses: actions/checkout@v4 + - uses: EmbarkStudios/cargo-deny-action@v2 + with: + log-level: warn + command: check linux-build: runs-on: "ubuntu-24.04" - timeout-minutes: 45 + timeout-minutes: 60 strategy: matrix: toolchain: @@ -59,7 +75,7 @@ jobs: env: # Need up-to-date compilers for kernels CC: clang - CXX: clang + CXX: clang++ steps: - uses: actions/checkout@v4 # pin the toolchain version to avoid surprises @@ -73,21 +89,25 @@ jobs: sudo apt update sudo apt install -y protobuf-compiler libssl-dev rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} - - name: Start DynamoDB local for tests - run: | - docker run -d -e AWS_ACCESS_KEY_ID=DUMMYKEY -e AWS_SECRET_ACCESS_KEY=DUMMYKEY -p 8000:8000 amazon/dynamodb-local + - name: Start DynamodDB and S3 + run: docker compose -f docker-compose.yml up -d --wait - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: Run tests if: ${{ matrix.toolchain == 'stable' }} run: | - cargo llvm-cov --locked --workspace --codecov --output-path coverage.codecov --all-features + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo llvm-cov --locked --workspace --codecov --output-path coverage.codecov --features ${ALL_FEATURES} - name: Build tests (nightly) - run: cargo test --locked --all-features --workspace --no-run + if: ${{ matrix.toolchain != 'stable' }} + run: | + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo test --locked --features ${ALL_FEATURES} --workspace --no-run - name: Run tests (nightly) if: ${{ matrix.toolchain != 'stable' }} run: | - cargo test --all-features --workspace + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo test --features ${ALL_FEATURES} --workspace - name: Upload coverage to Codecov if: ${{ matrix.toolchain == 'stable' }} uses: codecov/codecov-action@v4 @@ -99,7 +119,7 @@ jobs: fail_ci_if_error: false linux-arm: runs-on: ubuntu-2404-4x-arm64 - timeout-minutes: 45 + timeout-minutes: 75 steps: - uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 @@ -113,20 +133,21 @@ jobs: sudo apt install -y protobuf-compiler libssl-dev pkg-config - name: Build tests run: | - cargo test --locked --all-features --no-run - - name: Start DynamoDB local for tests - run: | - docker run -d -e AWS_ACCESS_KEY_ID=DUMMYKEY -e AWS_SECRET_ACCESS_KEY=DUMMYKEY -p 8000:8000 amazon/dynamodb-local + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo test --locked --features ${ALL_FEATURES} --no-run + - name: Start DynamodDB and S3 + run: docker compose -f docker-compose.yml up -d --wait - name: Run tests run: | - cargo test --locked --all-features + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo test --locked --features ${ALL_FEATURES} build-no-lock: runs-on: ubuntu-24.04 timeout-minutes: 30 env: # Need up-to-date compilers for kernels CC: clang - CXX: clang + CXX: clang++ steps: - uses: actions/checkout@v4 # Remote cargo.lock to force a fresh build @@ -139,7 +160,9 @@ jobs: sudo apt update sudo apt install -y protobuf-compiler libssl-dev - name: Build all - run: cargo build --benches --all-features --tests + run: | + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo build --benches --features ${ALL_FEATURES} --tests mac-build: runs-on: "macos-14" timeout-minutes: 45 @@ -150,7 +173,7 @@ jobs: - nightly defaults: run: - working-directory: ./rust/lance + working-directory: ./rust steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 @@ -165,16 +188,19 @@ jobs: run: | rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} - name: Build tests - run: cargo test --locked --all-features --no-run + run: | + cargo test --locked --features fp16kernels,cli,tensorflow,dynamodb,substrait --no-run - name: Run tests - run: cargo test --all-features + run: | + cargo test --features fp16kernels,cli,tensorflow,dynamodb,substrait - name: Check benchmarks - run: cargo check --benches --all-features + run: | + cargo check --benches --features fp16kernels,cli,tensorflow,dynamodb,substrait windows-build: runs-on: windows-latest defaults: run: - working-directory: rust/lance + working-directory: rust steps: - uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 @@ -199,11 +225,11 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: - msrv: ["1.78.0"] # This should match up with rust-version in Cargo.toml + msrv: ["1.82.0"] # This should match up with rust-version in Cargo.toml env: # Need up-to-date compilers for kernels CC: clang - CXX: clang + CXX: clang++ steps: - uses: actions/checkout@v4 with: @@ -218,4 +244,6 @@ jobs: with: toolchain: ${{ matrix.msrv }} - name: cargo +${{ matrix.msrv }} check - run: cargo check --workspace --tests --benches --all-features + run: | + ALL_FEATURES=`cargo metadata --format-version=1 --no-deps | jq -r '.packages[] | .features | keys | .[]' | grep -v protoc | sort | uniq | paste -s -d "," -` + cargo check --workspace --tests --benches --features ${ALL_FEATURES} diff --git a/.gitignore b/.gitignore index a70b512b409..a612a80a321 100644 --- a/.gitignore +++ b/.gitignore @@ -67,11 +67,6 @@ docs/api/python **/.ipynb_checkpoints/ docs/notebooks - -integration/duckdb/*-build -integration/duckdb/lance.duckdb_extension.*.zip - -notebooks/lance.duckdb_extension notebooks/sift notebooks/image_data/data benchmarks/sift/sift @@ -97,4 +92,5 @@ target python/venv test_data/venv -**/*.profraw \ No newline at end of file +**/*.profraw +*.lance diff --git a/.gitmodules b/.gitmodules index 05d79fc7c9a..e69de29bb2d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +0,0 @@ -[submodule "integration/duckdb_lance/duckdb"] - path = integration/duckdb_lance/duckdb - url = https://github.com/duckdb/duckdb.git -[submodule "integration/duckdb_lance/duckdb-ext/duckdb"] - path = integration/duckdb_lance/duckdb-ext/duckdb - url = https://github.com/duckdb/duckdb.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a49e64a867d..09c956152fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.2 + rev: v0.11.2 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/.typos.toml b/.typos.toml index 9b142594f9f..1535c0a746f 100644 --- a/.typos.toml +++ b/.typos.toml @@ -1,3 +1,6 @@ +[default] +extend-ignore-re = ["(?Rm)^.*(#|//)\\s*spellchecker:disable-line$"] + [default.extend-words] DNE = "DNE" arange = "arange" @@ -7,4 +10,5 @@ abd = "abd" afe = "afe" [files] -extend-exclude = ["notebooks/*.ipynb"] \ No newline at end of file +extend-exclude = ["notebooks/*.ipynb"] +# If a line ends with # or // and has spellchecker:disable-line, ignore it diff --git a/Cargo.lock b/Cargo.lock index 40107d0766a..f7374ff1934 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,18 +4,24 @@ version = 3 [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "adler32" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" @@ -25,10 +31,10 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -42,18 +48,18 @@ dependencies = [ [[package]] name = "aligned-vec" -version = "0.6.1" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e0966165eaf052580bd70eb1b32cb3d6245774c0104d1b2793e9650bf83b52a" +checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b" dependencies = [ "equator", ] [[package]] name = "all_asserts" -version = "2.3.1" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca77caf0ca1057c274cda103cda1363d892b7cad5f2e646afde4df0697bea100" +checksum = "514ce16346f9fc96702fd52f2ae7e383b185516ee6f556efd7c3176be8fe7bea" [[package]] name = "alloc-no-stdlib" @@ -72,9 +78,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android-tzdata" @@ -99,9 +105,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -114,43 +120,44 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "once_cell", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" [[package]] name = "approx" @@ -169,21 +176,21 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" [[package]] name = "arrayref" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" +checksum = "b5ec52ba94edeed950e4a41f75d35376df196e8cb04437f7280a5aa49f20f796" dependencies = [ "arrow-arith", "arrow-array", @@ -202,24 +209,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" +checksum = "8fc766fdacaf804cb10c7c70580254fcdb5d55cdfda2bc57b02baf5223a3af9e" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half", "num", ] [[package]] name = "arrow-array" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" +checksum = "a12fcdb3f1d03f69d3ec26ac67645a8fe3f878d77b5ebb0b15d64a116c212985" dependencies = [ "ahash", "arrow-buffer", @@ -228,15 +234,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown", + "hashbrown 0.15.2", "num", ] [[package]] name = "arrow-buffer" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" +checksum = "263f4801ff1839ef53ebd06f99a56cecd1dbaf314ec893d93168e2e860e0291c" dependencies = [ "bytes", "half", @@ -245,9 +251,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" +checksum = "ede6175fbc039dfc946a61c1b6d42fd682fcecf5ab5d148fbe7667705798cac9" dependencies = [ "arrow-array", "arrow-buffer", @@ -266,28 +272,25 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c13c36dc5ddf8c128df19bab27898eea64bf9da2b555ec1cd17a8ff57fba9ec2" +checksum = "1644877d8bc9a0ef022d9153dc29375c2bda244c39aec05a91d0e87ccf77995f" dependencies = [ "arrow-array", - "arrow-buffer", "arrow-cast", - "arrow-data", "arrow-schema", "chrono", "csv", "csv-core", "lazy_static", - "lexical-core", "regex", ] [[package]] name = "arrow-data" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" +checksum = "61cfdd7d99b4ff618f167e548b2411e5dd2c98c0ddebedd7df433d34c20a4429" dependencies = [ "arrow-buffer", "arrow-schema", @@ -297,13 +300,12 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e786e1cdd952205d9a8afc69397b317cfbb6e0095e445c69cda7e8da5c1eeb0f" +checksum = "62ff528658b521e33905334723b795ee56b393dbe9cf76c8b1f64b648c65a60c" dependencies = [ "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-schema", "flatbuffers", @@ -313,9 +315,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb22284c5a2a01d73cebfd88a33511a3234ab45d66086b2ca2d1228c3498e445" +checksum = "0ee5b4ca98a7fb2efb9ab3309a5d1c88b5116997ff93f3147efdc1062a6158e9" dependencies = [ "arrow-array", "arrow-buffer", @@ -326,33 +328,32 @@ dependencies = [ "half", "indexmap", "lexical-core", + "memchr", "num", "serde", "serde_json", + "simdutf8", ] [[package]] name = "arrow-ord" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" +checksum = "f0a3334a743bd2a1479dbc635540617a3923b4b2f6870f37357339e6b5363c21" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "half", - "num", ] [[package]] name = "arrow-row" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" +checksum = "8d1d7a7291d2c5107e92140f75257a99343956871f3d3ab33a7b41532f79cb68" dependencies = [ - "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -362,18 +363,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" +checksum = "39cfaf5e440be44db5413b75b72c2a87c1f8f0627117d110264048f2969b99e9" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", ] [[package]] name = "arrow-select" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" +checksum = "69efcd706420e52cd44f5c4358d279801993846d1c2a8e52111853d61d55a619" dependencies = [ "ahash", "arrow-array", @@ -385,9 +386,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" +checksum = "a21546b337ab304a32cfc0770f671db7411787586b45b78b4593ae78e64e2b03" dependencies = [ "arrow-array", "arrow-buffer", @@ -397,7 +398,7 @@ dependencies = [ "memchr", "num", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -423,34 +424,16 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-compression" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa" -dependencies = [ - "bzip2", - "flate2", - "futures-core", - "futures-io", - "memchr", - "pin-project-lite", - "tokio", - "xz2", - "zstd", - "zstd-safe", -] - [[package]] name = "async-executor" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7ebdfa2ebdab6b1760375fa7d6f382b9f486eac35fc994625a00e89280bdbb7" +checksum = "30ca9a001c1e8ba5149f91a74362376cc6bc5b919d92d988668657bd570bdcec" dependencies = [ "async-task", "concurrent-queue", - "fastrand 2.1.0", - "futures-lite 2.3.0", + "fastrand", + "futures-lite", "slab", ] @@ -462,59 +445,30 @@ checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c" dependencies = [ "async-channel 2.3.1", "async-executor", - "async-io 2.3.3", - "async-lock 3.4.0", + "async-io", + "async-lock", "blocking", - "futures-lite 2.3.0", + "futures-lite", "once_cell", ] [[package]] name = "async-io" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "cfg-if", - "concurrent-queue", - "futures-lite 1.13.0", - "log", - "parking", - "polling 2.8.0", - "rustix 0.37.27", - "slab", - "socket2 0.4.10", - "waker-fn", -] - -[[package]] -name = "async-io" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6baa8f0178795da0e71bc42c9e5d13261aac7ee549853162e66a241ba17964" +checksum = "43a2b323ccce0a1d90b449fd71f2a06ca7faa7c54c2751f06c9bd851fc061059" dependencies = [ - "async-lock 3.4.0", + "async-lock", "cfg-if", "concurrent-queue", "futures-io", - "futures-lite 2.3.0", + "futures-lite", "parking", - "polling 3.7.2", - "rustix 0.38.34", + "polling", + "rustix 0.38.44", "slab", "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-lock" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" -dependencies = [ - "event-listener 2.5.3", + "windows-sys 0.59.0", ] [[package]] @@ -523,7 +477,7 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 5.3.1", + "event-listener 5.4.0", "event-listener-strategy", "pin-project-lite", ] @@ -545,24 +499,24 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "async-std" -version = "1.12.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d" +checksum = "730294c1c08c2e0f85759590518f6333f0d5a0a766a27d519c1b244c3dfd8a24" dependencies = [ "async-channel 1.9.0", "async-global-executor", - "async-io 1.13.0", - "async-lock 2.8.0", + "async-io", + "async-lock", "crossbeam-utils", "futures-channel", "futures-core", "futures-io", - "futures-lite 1.13.0", + "futures-lite", "gloo-timers", "kv-log-macro", "log", @@ -582,13 +536,13 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.81" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -614,15 +568,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.5" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e95816a168520d72c0e7680c405a5a8c1fb6a035b4bc4b9d7b0de8e1a941697" +checksum = "8c39646d1a6b51240a1a23bb57ea4eebede7e16fbc237fdc876980233dcecb4f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -637,9 +591,9 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand", "hex", - "http 0.2.12", + "http 1.3.1", "ring", "time", "tokio", @@ -650,9 +604,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +checksum = "4471bef4c22a06d2c7a1b6492493d3fdf24a805323109d6874f9c94d5906ac14" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -660,22 +614,46 @@ dependencies = [ "zeroize", ] +[[package]] +name = "aws-lc-rs" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f7720b74ed28ca77f90769a71fd8c637a0137f6fae4ae947e1050229cff57f" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "aws-runtime" -version = "1.4.2" +version = "1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2424565416eef55906f9f8cece2072b6b6a76075e3ff81483ebe938a89a4c05f" +checksum = "0aff45ffe35196e593ea3b9dd65b320e51e2dda95aff4390bc459e461d09c6ad" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", + "aws-smithy-eventstream", "aws-smithy-http", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand", "http 0.2.12", "http-body 0.4.6", "once_cell", @@ -687,32 +665,67 @@ dependencies = [ [[package]] name = "aws-sdk-dynamodb" -version = "1.44.0" +version = "1.71.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e04d98940e69f94525e47f5dda2e28919b81c229a8d25c941be31104c6a4afa8" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-s3" +version = "1.82.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecfba3908f1ecc5f05beacfd44ac86c25be29bf070bb2e32a0d4c423858e13bd" +checksum = "e6eab2900764411ab01c8e91a76fd11a63b4e12bc3da97d9e14a0ce1343d86d3" dependencies = [ "aws-credential-types", "aws-runtime", + "aws-sigv4", "aws-smithy-async", + "aws-smithy-checksums", + "aws-smithy-eventstream", "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", + "aws-smithy-xml", "aws-types", "bytes", - "fastrand 2.1.0", + "fastrand", + "hex", + "hmac", "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "lru", "once_cell", + "percent-encoding", "regex-lite", + "sha2", "tracing", + "url", ] [[package]] name = "aws-sdk-sso" -version = "1.41.0" +version = "1.64.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af0a3f676cba2c079c9563acc9233998c8951cdbe38629a0bef3c8c1b02f3658" +checksum = "02d4bdb0e5f80f0689e61c77ab678b2b9304af329616af38aef5b6b967b8e736" dependencies = [ "aws-credential-types", "aws-runtime", @@ -724,6 +737,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -732,9 +746,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.42.0" +version = "1.65.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91b6a04495547162cf52b075e3c15a17ab6608bf9c5785d3e5a5509b3f09f5c" +checksum = "acbbb3ce8da257aedbccdcb1aadafbbb6a5fe9adf445db0e1ea897bdc7e22d08" dependencies = [ "aws-credential-types", "aws-runtime", @@ -746,6 +760,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -754,9 +769,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.41.0" +version = "1.65.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99c56bcd6a56cab7933980a54148b476a5a69a7694e3874d9aa2a566f150447d" +checksum = "96a78a8f50a1630db757b60f679c8226a8a70ee2ab5f5e6e51dc67f6c61c7cfd" dependencies = [ "aws-credential-types", "aws-runtime", @@ -769,6 +784,7 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -777,50 +793,91 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.3" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5df1b0fa6be58efe9d4ccc257df0a53b89cd8909e86591a13ca54817c87517be" +checksum = "69d03c3c05ff80d54ff860fe38c726f6f494c639ae975203a101335f223386db" dependencies = [ "aws-credential-types", + "aws-smithy-eventstream", "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", + "crypto-bigint 0.5.5", "form_urlencoded", "hex", "hmac", "http 0.2.12", - "http 1.1.0", + "http 1.3.1", "once_cell", + "p256", "percent-encoding", + "ring", "sha2", + "subtle", "time", "tracing", + "zeroize", ] [[package]] name = "aws-smithy-async" -version = "1.2.1" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" +checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" dependencies = [ "futures-util", "pin-project-lite", "tokio", ] +[[package]] +name = "aws-smithy-checksums" +version = "0.63.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65d21e1ba6f2cdec92044f904356a19f5ad86961acf015741106cdfafd747c0" +dependencies = [ + "aws-smithy-http", + "aws-smithy-types", + "bytes", + "crc32c", + "crc32fast", + "crc64fast-nvme", + "hex", + "http 0.2.12", + "http-body 0.4.6", + "md-5", + "pin-project-lite", + "sha1", + "sha2", + "tracing", +] + +[[package]] +name = "aws-smithy-eventstream" +version = "0.60.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c45d3dddac16c5c59d553ece225a88870cf81b7b813c9cc17b78cf4685eac7a" +dependencies = [ + "aws-smithy-types", + "bytes", + "crc32fast", +] + [[package]] name = "aws-smithy-http" -version = "0.60.10" +version = "0.62.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01dbcb6e2588fd64cfb6d7529661b06466419e4c54ed1c62d6510d2d0350a728" +checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166" dependencies = [ + "aws-smithy-eventstream", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "once_cell", "percent-encoding", @@ -829,15 +886,53 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws-smithy-http-client" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8aff1159006441d02e57204bf57a1b890ba68bedb6904ffd2873c1c4c11c546b" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2 0.4.8", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.5", + "hyper-util", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.25", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tower", + "tracing", +] + [[package]] name = "aws-smithy-json" -version = "0.60.7" +version = "0.61.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" +checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-observability" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445d065e76bc1ef54963db400319f1dd3ebb3e0a74af20f7f7630625b0cc7cc0" +dependencies = [ + "aws-smithy-runtime-api", + "once_cell", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -850,42 +945,40 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "0152749e17ce4d1b47c7747bdfec09dac1ccafdcbc741ebf9daa2a373356730f" dependencies = [ "aws-smithy-async", "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", - "fastrand 2.1.0", - "h2 0.3.26", + "fastrand", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "httparse", - "hyper 0.14.30", - "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", "pin-utils", - "rustls 0.21.12", "tokio", "tracing", ] [[package]] name = "aws-smithy-runtime-api" -version = "1.7.2" +version = "1.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +checksum = "3da37cf5d57011cb1753456518ec76e31691f1f474b73934a284eb2a1c76510f" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -894,16 +987,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.4" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273dcdfd762fae3e1650b8024624e7cd50e484e37abdab73a7a706188ad34543" +checksum = "836155caafba616c0ff9b07944324785de2ab016141c3550bd1c07882f8cee8f" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -920,18 +1013,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.8" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d123fbc2a4adc3c301652ba8e149bf4bc1d1725affb9784eb20c953ace06bf55" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.3" +version = "1.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" +checksum = "3873f8deed8927ce8d04487630dc9ff73193bab64742a61d050e57a68dec4125" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -943,19 +1036,25 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] +[[package]] +name = "base16ct" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" + [[package]] name = "base64" version = "0.21.7" @@ -978,20 +1077,71 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" + +[[package]] +name = "bigdecimal" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.100", + "which", +] + [[package]] name = "bit-set" -version = "0.5.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ "bit-vec", ] [[package]] name = "bit-vec" -version = "0.6.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitflags" @@ -1001,9 +1151,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "bitpacking" @@ -1037,9 +1187,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.3" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210" +checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" dependencies = [ "arrayref", "arrayvec", @@ -1066,7 +1216,7 @@ dependencies = [ "async-channel 2.3.1", "async-task", "futures-io", - "futures-lite 2.3.0", + "futures-lite", "piper", ] @@ -1083,13 +1233,13 @@ dependencies = [ [[package]] name = "brotli" -version = "6.0.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor 4.0.1", + "brotli-decompressor 4.0.2", ] [[package]] @@ -1104,9 +1254,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.1" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1114,15 +1264,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" [[package]] name = "byteorder" @@ -1132,9 +1282,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "bytes-utils" @@ -1146,27 +1296,6 @@ dependencies = [ "either", ] -[[package]] -name = "bzip2" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" -dependencies = [ - "bzip2-sys", - "libc", -] - -[[package]] -name = "bzip2-sys" -version = "0.1.11+1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "cast" version = "0.3.0" @@ -1175,50 +1304,75 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.7" +version = "1.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] -name = "census" -version = "0.4.2" +name = "cedarwood" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4c707c6a209cbe82d10abd08e1ea8995e9ea937d2550646e02798948992be0" - +checksum = "6d910bedd62c24733263d0bed247460853c9d22e8956bd4cd964302095e04e90" +dependencies = [ + "smallvec", +] + +[[package]] +name = "census" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4c707c6a209cbe82d10abd08e1ea8995e9ea937d2550646e02798948992be0" + [[package]] name = "cesu8" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "chrono-tz" -version = "0.9.0" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +checksum = "efdce149c370f133a071ca8ef6ea340b7b88748ab0810097a9e2976eaa34b4f3" dependencies = [ "chrono", "chrono-tz-build", @@ -1227,12 +1381,11 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +checksum = "8f10f8c9340e31fc120ff885fcdb54a0b48e474bbd77cab557f0c30a3e569402" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] @@ -1263,11 +1416,22 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" -version = "4.5.13" +version = "4.5.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fbb260a053428790f3de475e304ff84cdbc4face759ea7a3e64c1edd938a7fc" +checksum = "d8aa86934b44c19c50f87cc2790e19f54f7a67aedb64101c2e1a2e5ecfb73944" dependencies = [ "clap_builder", "clap_derive", @@ -1275,9 +1439,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.13" +version = "4.5.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64b17d7ea74e9f833c7dbf2cbe4fb12ff26783eda4782a8975b72f895c9b4d99" +checksum = "2414dbb2dd0695280da6ea9261e327479e9d37b0630f6b53ba2a11c60c679fd9" dependencies = [ "anstream", "anstyle", @@ -1287,27 +1451,36 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "combine" @@ -1321,12 +1494,11 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" dependencies = [ - "strum", - "strum_macros", + "unicode-segmentation", "unicode-width", ] @@ -1339,6 +1511,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-random" version = "0.1.18" @@ -1354,16 +1532,16 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "convert_case" @@ -1384,26 +1562,45 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] [[package]] name = "cpp_demangle" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8227005286ec39567949b33df9896bcadfa6051bccca2488129f108ca23119" +checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d" dependencies = [ "cfg-if", ] [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1423,6 +1620,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -1432,6 +1638,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crc64fast-nvme" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4955638f00a809894c947f85a024020a20815b65a5eea633798ea7924edab2b3" +dependencies = [ + "crc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -1472,18 +1687,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -1500,24 +1715,46 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + +[[package]] +name = "crypto-bigint" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef2b4b23cddf68b89b8f8069890e8c270d54e2d5fe1b143820234805e4cb17ef" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-bigint" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] [[package]] name = "crypto-common" @@ -1531,9 +1768,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -1543,26 +1780,54 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" dependencies = [ "memchr", ] [[package]] -name = "dashmap" -version = "5.5.3" +name = "darling" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "cfg-if", - "hashbrown", - "lock_api", - "once_cell", - "parking_lot_core", + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.100", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "dary_heap" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" + [[package]] name = "dashmap" version = "6.1.0" @@ -1571,7 +1836,7 @@ checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", - "hashbrown", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -1579,150 +1844,217 @@ dependencies = [ [[package]] name = "datafusion" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb" +checksum = "914e6f9525599579abbd90b0f7a55afcaaaa40350b9e9ed52563f126dfe45fd3" dependencies = [ - "ahash", "arrow", - "arrow-array", "arrow-ipc", "arrow-schema", - "async-compression", "async-trait", "bytes", - "bzip2", "chrono", - "dashmap 6.1.0", "datafusion-catalog", + "datafusion-catalog-listing", "datafusion-common", "datafusion-common-runtime", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-nested", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", - "flate2", "futures", - "glob", - "half", - "hashbrown", - "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "num_cpus", "object_store", "parking_lot", "parquet", - "paste", - "pin-project-lite", - "rand", + "rand 0.8.5", + "regex", "sqlparser", "tempfile", "tokio", - "tokio-util", "url", "uuid", - "xz2", - "zstd", ] [[package]] name = "datafusion-catalog" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13b3cfbd84c6003594ae1972314e3df303a27ce8ce755fcea3240c90f4c0529" +checksum = "998a6549e6ee4ee3980e05590b2960446a56b343ea30199ef38acd0e0b9036e2" dependencies = [ - "arrow-schema", + "arrow", + "async-trait", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", + "datafusion-sql", + "futures", + "itertools 0.14.0", + "log", + "parking_lot", +] + +[[package]] +name = "datafusion-catalog-listing" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ac10096a5b3c0d8a227176c0e543606860842e943594ccddb45cf42a526e43" +dependencies = [ + "arrow", "async-trait", + "datafusion-catalog", "datafusion-common", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", + "futures", + "log", + "object_store", + "tokio", ] [[package]] name = "datafusion-common" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c" +checksum = "1f53d7ec508e1b3f68bd301cee3f649834fad51eff9240d898a4b2614cfd0a7a" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", - "arrow-schema", - "chrono", + "arrow-ipc", + "base64 0.22.1", "half", - "hashbrown", - "instant", + "hashbrown 0.14.5", + "indexmap", "libc", - "num_cpus", + "log", "object_store", "parquet", + "paste", "sqlparser", + "tokio", + "web-time", ] [[package]] name = "datafusion-common-runtime" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7496d1f664179f6ce3a5cbef6566056ccaf3ea4aa72cc455f80e62c1dd86b1" +checksum = "e0fcf41523b22e14cc349b01526e8b9f59206653037f2949a4adbfde5f8cb668" dependencies = [ + "log", "tokio", ] [[package]] -name = "datafusion-execution" -version = "41.0.0" +name = "datafusion-datasource" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799e70968c815b611116951e3dd876aef04bf217da31b72eec01ee6a959336a1" +checksum = "cf7f37ad8b6e88b46c7eeab3236147d32ea64b823544f498455a8d9042839c92" dependencies = [ "arrow", + "async-trait", + "bytes", "chrono", - "dashmap 6.1.0", + "datafusion-catalog", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "futures", + "glob", + "itertools 0.14.0", + "log", + "object_store", + "rand 0.8.5", + "tokio", + "url", +] + +[[package]] +name = "datafusion-doc" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7db7a0239fd060f359dc56c6e7db726abaa92babaed2fb2e91c3a8b2fff8b256" + +[[package]] +name = "datafusion-execution" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0938f9e5b6bc5782be4111cdfb70c02b7b5451bf34fd57e4de062a7f7c4e31f1" +dependencies = [ + "arrow", + "dashmap", "datafusion-common", "datafusion-expr", "futures", - "hashbrown", "log", "object_store", "parking_lot", - "rand", + "rand 0.8.5", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2" +checksum = "b36c28b00b00019a8695ad7f1a53ee1673487b90322ecbd604e2cf32894eb14f" dependencies = [ - "ahash", "arrow", - "arrow-array", - "arrow-buffer", "chrono", "datafusion-common", + "datafusion-doc", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", + "datafusion-functions-window-common", + "datafusion-physical-expr-common", + "indexmap", "paste", "serde_json", "sqlparser", - "strum", - "strum_macros", +] + +[[package]] +name = "datafusion-expr-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18f0a851a436c5a2139189eb4617a54e6a9ccb9edc96c4b3c83b3bb7c58b950e" +dependencies = [ + "arrow", + "datafusion-common", + "indexmap", + "itertools 0.14.0", + "paste", ] [[package]] name = "datafusion-functions" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8e481cf34d2a444bd8fa09b65945f0ce83dc92df8665b761505b3d9f351bebb" +checksum = "e3196e37d7b65469fb79fee4f05e5bb58a456831035f9a38aa5919aeb3298d40" dependencies = [ "arrow", "arrow-buffer", @@ -1731,14 +2063,16 @@ dependencies = [ "blake3", "chrono", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", - "hashbrown", + "datafusion-expr-common", + "datafusion-macros", "hex", - "itertools 0.12.1", + "itertools 0.14.0", "log", "md-5", - "rand", + "rand 0.8.5", "regex", "sha2", "unicode-segmentation", @@ -1747,130 +2081,193 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb" +checksum = "adfc2d074d5ee4d9354fdcc9283d5b2b9037849237ddecb8942a29144b77ca05" dependencies = [ "ahash", "arrow", - "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate-common", + "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", + "half", "log", "paste", - "sqlparser", +] + +[[package]] +name = "datafusion-functions-aggregate-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cbceba0f98d921309a9121b702bcd49289d383684cccabf9a92cda1602f3bbb" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr-common", + "datafusion-physical-expr-common", ] [[package]] name = "datafusion-functions-nested" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1474552cc824e8c9c88177d454db5781d4b66757d4aca75719306b8343a5e8d" +checksum = "170e27ce4baa27113ddf5f77f1a7ec484b0dbeda0c7abbd4bad3fc609c8ab71a" dependencies = [ "arrow", - "arrow-array", - "arrow-buffer", "arrow-ord", - "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "itertools 0.12.1", + "datafusion-macros", + "datafusion-physical-expr-common", + "itertools 0.14.0", "log", "paste", - "rand", ] [[package]] -name = "datafusion-optimizer" -version = "41.0.0" +name = "datafusion-functions-table" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791ff56f55608bc542d1ea7a68a64bdc86a9413f5a381d06a39fd49c2a3ab906" +checksum = "7d3a06a7f0817ded87b026a437e7e51de7f59d48173b0a4e803aa896a7bd6bb5" dependencies = [ "arrow", "async-trait", + "datafusion-catalog", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-plan", + "parking_lot", + "paste", +] + +[[package]] +name = "datafusion-functions-window" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6c608b66496a1e05e3d196131eb9bebea579eed1f59e88d962baf3dda853bc6" +dependencies = [ + "datafusion-common", + "datafusion-doc", + "datafusion-expr", + "datafusion-functions-window-common", + "datafusion-macros", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "log", + "paste", +] + +[[package]] +name = "datafusion-functions-window-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da2f9d83348957b4ad0cd87b5cb9445f2651863a36592fe5484d43b49a5f8d82" +dependencies = [ + "datafusion-common", + "datafusion-physical-expr-common", +] + +[[package]] +name = "datafusion-macros" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4800e1ff7ecf8f310887e9b54c9c444b8e215ccbc7b21c2f244cfae373b1ece7" +dependencies = [ + "datafusion-expr", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "datafusion-optimizer" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "971c51c54cd309001376fae752fb15a6b41750b6d1552345c46afbfb6458801b" +dependencies = [ + "arrow", "chrono", "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "paste", - "regex-syntax 0.8.4", + "regex", + "regex-syntax 0.8.5", ] [[package]] name = "datafusion-physical-expr" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4" +checksum = "e1447c2c6bc8674a16be4786b4abf528c302803fafa186aa6275692570e64d85" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", - "arrow-ord", - "arrow-schema", - "arrow-string", - "base64 0.22.1", - "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown", - "hex", + "hashbrown 0.14.5", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", "paste", - "petgraph", - "regex", + "petgraph 0.7.1", ] [[package]] name = "datafusion-physical-expr-common" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6" +checksum = "69f8c25dcd069073a75b3d2840a79d0f81e64bdd2c05f2d3d18939afb36a7dcb" dependencies = [ "ahash", "arrow", "datafusion-common", - "datafusion-expr", - "hashbrown", - "rand", + "datafusion-expr-common", + "hashbrown 0.14.5", + "itertools 0.14.0", ] [[package]] name = "datafusion-physical-optimizer" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb9c78f308e050f5004671039786a925c3fee83b90004e9fcfd328d7febdcc0" +checksum = "68da5266b5b9847c11d1b3404ee96b1d423814e1973e1ad3789131e5ec912763" dependencies = [ + "arrow", "datafusion-common", "datafusion-execution", + "datafusion-expr", + "datafusion-expr-common", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", + "itertools 0.14.0", + "log", ] [[package]] name = "datafusion-physical-plan" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe" +checksum = "88cc160df00e413e370b3b259c8ea7bfbebc134d32de16325950e9e923846b7f" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", "arrow-ord", "arrow-schema", "async-trait", @@ -1879,54 +2276,52 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", - "hashbrown", + "hashbrown 0.14.5", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "once_cell", "parking_lot", "pin-project-lite", - "rand", "tokio", ] [[package]] name = "datafusion-sql" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45d0180711165fe94015d7c4123eb3e1cf5fb60b1506453200b8d1ce666bef0" +checksum = "325a212b67b677c0eb91447bf9a11b630f9fc4f62d8e5d145bf859f5a6b29e64" dependencies = [ "arrow", - "arrow-array", - "arrow-schema", + "bigdecimal", "datafusion-common", "datafusion-expr", + "indexmap", "log", "regex", "sqlparser", - "strum", ] [[package]] name = "datafusion-substrait" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0a0055aa98246c79f98f0d03df11f16cb7adc87818d02d4413e3f3cdadbbee" +checksum = "2c2be3226a683e02cff65181e66e62eba9f812ed0e9b7ec8fe11ac8dabf1a73f" dependencies = [ - "arrow-buffer", "async-recursion", + "async-trait", "chrono", "datafusion", - "itertools 0.12.1", + "itertools 0.14.0", "object_store", "pbjson-types", - "prost", - "substrait 0.36.0", + "prost 0.13.5", + "substrait 0.53.2", + "tokio", "url", ] @@ -1959,16 +2354,57 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.100", +] + [[package]] name = "diff" version = "0.1.13" @@ -2008,10 +2444,15 @@ dependencies = [ ] [[package]] -name = "doc-comment" -version = "0.3.3" +name = "displaydoc" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] [[package]] name = "downcast" @@ -2025,86 +2466,195 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" + +[[package]] +name = "ecdsa" +version = "0.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" +dependencies = [ + "der", + "elliptic-curve", + "rfc6979", + "signature", +] [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] -name = "env_filter" -version = "0.1.2" +name = "elliptic-curve" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" dependencies = [ - "log", + "base16ct", + "crypto-bigint 0.4.9", + "der", + "digest", + "ff", + "generic-array", + "group", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", ] [[package]] -name = "env_logger" -version = "0.10.2" +name = "encoding" +version = "0.2.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" +dependencies = [ + "encoding-index-japanese", + "encoding-index-korean", + "encoding-index-simpchinese", + "encoding-index-singlebyte", + "encoding-index-tradchinese", +] + +[[package]] +name = "encoding-index-japanese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-korean" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-simpchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-singlebyte" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-tradchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding_index_tests" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "encoding_rs_io" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3c5651fb62ab8aa3103998dade57efdd028544bd300516baa31840c252a83" +dependencies = [ + "encoding_rs", +] + +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ - "humantime", - "is-terminal", "log", "regex", - "termcolor", ] [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", "env_filter", + "jiff", "log", ] [[package]] name = "equator" -version = "0.2.2" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" dependencies = [ "equator-macro", ] [[package]] name = "equator-macro" -version = "0.2.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2126,9 +2676,9 @@ dependencies = [ [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2137,45 +2687,46 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.5.2" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener 5.3.1", + "event-listener 5.4.0", "pin-project-lite", ] [[package]] name = "fastdivide" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59668941c55e5c186b8b58c391629af56774ec768f73c08bbcd56f09348eb00b" +checksum = "9afc2bd4d5a73106dd53d10d73d3401c2f32730ba2c0b93ddb888a8983680471" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "fastrand" -version = "2.1.0" +name = "ff" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] [[package]] name = "filetime" -version = "0.2.23" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.4.1", - "windows-sys 0.52.0", + "libredox", + "windows-sys 0.59.0", ] [[package]] @@ -2196,11 +2747,17 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" -version = "24.3.25" +version = "24.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" +checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" dependencies = [ "bitflags 1.3.2", "rustc_version", @@ -2208,9 +2765,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.31" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" +checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", "miniz_oxide", @@ -2222,6 +2779,27 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2233,9 +2811,9 @@ dependencies = [ [[package]] name = "fragile" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" +checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619" [[package]] name = "fs4" @@ -2243,22 +2821,37 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7e180ac76c23b45e767bd7ae9579bc0bb458618c4bc71835926e098e61d15f8" dependencies = [ - "rustix 0.38.34", + "rustix 0.38.44", "windows-sys 0.52.0", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsst" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "lance-datagen", - "rand", + "rand 0.8.5", "rand_xoshiro", "test-log", "tokio", ] +[[package]] +name = "fst" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" +dependencies = [ + "utf8-ranges", +] + [[package]] name = "funty" version = "2.0.0" @@ -2267,9 +2860,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -2282,9 +2875,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2292,15 +2885,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -2309,32 +2902,17 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-lite" -version = "1.13.0" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" -dependencies = [ - "fastrand 1.9.0", - "futures-core", - "futures-io", - "memchr", - "parking", - "pin-project-lite", - "waker-fn", -] +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" -version = "2.3.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" dependencies = [ - "fastrand 2.1.0", + "fastrand", "futures-core", "futures-io", "parking", @@ -2343,26 +2921,26 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2372,9 +2950,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -2388,6 +2966,28 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "generator" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bd114ceda131d3b1d665eba35788690ad37f5916457286b32ab6fd3c438dd" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2407,27 +3007,41 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", "wasm-bindgen", ] [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "gloo-timers" -version = "0.2.6" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" dependencies = [ "futures-channel", "futures-core", @@ -2435,6 +3049,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "group" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "h2" version = "0.3.26" @@ -2456,16 +3081,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.1.0", + "http 1.3.1", "indexmap", "slab", "tokio", @@ -2475,9 +3100,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" dependencies = [ "cfg-if", "crunchy", @@ -2495,10 +3120,15 @@ dependencies = [ ] [[package]] -name = "heck" -version = "0.4.1" +name = "hashbrown" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -2518,6 +3148,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hermit-abi" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" + [[package]] name = "hex" version = "0.4.3" @@ -2533,6 +3169,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "hostname" version = "0.3.1" @@ -2563,9 +3208,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2590,27 +3235,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.1.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.9.4" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -2620,15 +3265,15 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" dependencies = [ "bytes", "futures-channel", @@ -2641,7 +3286,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.7", + "socket2", "tokio", "tower-service", "tracing", @@ -2650,15 +3295,15 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", - "http 1.1.0", + "h2 0.4.8", + "http 1.3.1", "http-body 1.0.1", "httparse", "itoa", @@ -2676,7 +3321,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.32", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2686,38 +3331,54 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.1.0", - "hyper 1.4.1", + "http 1.3.1", + "hyper 1.6.0", "hyper-util", - "rustls 0.23.12", - "rustls-native-certs 0.8.0", + "rustls 0.23.25", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", "tower-service", ] [[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.6.0", + "libc", "pin-project-lite", - "socket2 0.5.7", + "socket2", "tokio", - "tower", "tower-service", "tracing", ] @@ -2733,16 +3394,17 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", - "windows-core", + "windows-core 0.61.0", ] [[package]] @@ -2754,24 +3416,182 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "include-flate" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df49c16750695486c1f34de05da5b7438096156466e7f76c38fcdf285cf0113e" +dependencies = [ + "include-flate-codegen", + "lazy_static", + "libflate", +] + +[[package]] +name = "include-flate-codegen" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c5b246c6261be723b85c61ecf87804e8ea4a35cb68be0ff282ed84b95ffe7d7" +dependencies = [ + "libflate", + "proc-macro2", + "quote", + "syn 2.0.100", ] [[package]] name = "indexmap" -version = "2.3.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", ] [[package]] @@ -2816,32 +3636,21 @@ version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d762194228a2f1c11063e46e32e5acb96e66e906382b9eb5441f2e0504bbd5a" -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.9", - "libc", - "windows-sys 0.48.0", -] - [[package]] name = "ipnet" -version = "2.9.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.12" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi 0.5.0", "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2886,11 +3695,68 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jieba-macros" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c676b32a471d3cfae8dac2ad2f8334cd52e53377733cca8c1fb0a5062fec192" +dependencies = [ + "phf_codegen", +] + +[[package]] +name = "jieba-rs" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "6d1bcad6332969e4d48ee568d430e14ee6dea70740c2549d005d87677ebefb0c" +dependencies = [ + "cedarwood", + "fxhash", + "include-flate", + "jieba-macros", + "lazy_static", + "phf", + "regex", +] + +[[package]] +name = "jiff" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f33145a5cbea837164362c7bd596106eb7c5198f97d1ba6f6ebb3223952e488" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde", +] + +[[package]] +name = "jiff-static" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43ce13c40ec6956157a3635d97a1ee2df323b263f09ea14165131289cb0f5c19" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] [[package]] name = "jni" @@ -2903,7 +3769,7 @@ dependencies = [ "combine", "jni-sys", "log", - "thiserror", + "thiserror 1.0.69", "walkdir", "windows-sys 0.45.0", ] @@ -2916,22 +3782,33 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.2", "libc", ] [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] +[[package]] +name = "kanaria" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0f9d9652540055ac4fded998a73aca97d965899077ab1212587437da44196ff" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "kv-log-macro" version = "1.0.7" @@ -2943,7 +3820,7 @@ dependencies = [ [[package]] name = "lance" -version = "0.20.0" +version = "0.26.2" dependencies = [ "all_asserts", "approx", @@ -2951,6 +3828,7 @@ dependencies = [ "arrow-arith", "arrow-array", "arrow-buffer", + "arrow-ipc", "arrow-ord", "arrow-row", "arrow-schema", @@ -2958,22 +3836,27 @@ dependencies = [ "async-recursion", "async-trait", "async_cell", + "aws-config", "aws-credential-types", "aws-sdk-dynamodb", + "aws-sdk-s3", "byteorder", "bytes", "chrono", "clap", "criterion", - "dashmap 5.5.3", + "dashmap", "datafusion", + "datafusion-expr", "datafusion-functions", "datafusion-physical-expr", "deepsize", "dirs", - "env_logger 0.10.2", + "either", + "env_logger", "futures", "half", + "humantime", "itertools 0.13.0", "lance-arrow", "lance-core", @@ -2998,10 +3881,10 @@ dependencies = [ "pin-project", "pprof", "pretty_assertions", - "prost", - "prost-build", - "prost-types", - "rand", + "prost 0.12.6", + "prost 0.13.5", + "prost-types 0.13.5", + "rand 0.8.5", "random_word", "roaring", "rstest", @@ -3021,7 +3904,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-buffer", @@ -3029,15 +3912,16 @@ dependencies = [ "arrow-data", "arrow-schema", "arrow-select", - "getrandom", + "bytes", + "getrandom 0.2.15", "half", "num-traits", - "rand", + "rand 0.8.5", ] [[package]] name = "lance-core" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-buffer", @@ -3061,8 +3945,8 @@ dependencies = [ "object_store", "pin-project", "proptest", - "prost", - "rand", + "prost 0.13.5", + "rand 0.8.5", "roaring", "serde_json", "snafu", @@ -3076,7 +3960,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -3096,15 +3980,18 @@ dependencies = [ "lance-datagen", "lazy_static", "log", - "prost", + "pin-project", + "prost 0.13.5", "snafu", "substrait-expr", + "tempfile", "tokio", + "tracing", ] [[package]] name = "lance-datagen" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -3115,13 +4002,13 @@ dependencies = [ "futures", "hex", "pprof", - "rand", + "rand 0.8.5", "rand_xoshiro", ] [[package]] name = "lance-encoding" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrayref", "arrow", @@ -3147,13 +4034,15 @@ dependencies = [ "lance-testing", "lazy_static", "log", + "lz4", "num-traits", "paste", "pprof", - "prost", - "prost-build", - "prost-types", - "rand", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "protobuf-src", + "rand 0.8.5", "rand_xoshiro", "rstest", "seq-macro", @@ -3167,7 +4056,7 @@ dependencies = [ [[package]] name = "lance-encoding-datafusion" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-buffer", @@ -3188,10 +4077,11 @@ dependencies = [ "lance-io", "log", "pprof", - "prost", - "prost-build", - "prost-types", - "rand", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "protobuf-src", + "rand 0.8.5", "snafu", "test-log", "tokio", @@ -3199,7 +4089,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-arith", "arrow-array", @@ -3227,10 +4117,11 @@ dependencies = [ "pprof", "pretty_assertions", "proptest", - "prost", - "prost-build", - "prost-types", - "rand", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "protobuf-src", + "rand 0.8.5", "roaring", "snafu", "tempfile", @@ -3241,7 +4132,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "0.20.0" +version = "0.26.2" dependencies = [ "approx", "arrow", @@ -3262,9 +4153,13 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "deepsize", + "dirs", + "env_logger", + "fst", "futures", "half", "itertools 0.13.0", + "jieba-rs", "lance-arrow", "lance-core", "lance-datafusion", @@ -3276,14 +4171,17 @@ dependencies = [ "lance-table", "lance-testing", "lazy_static", + "lindera", + "lindera-tantivy", "log", "moka", "num-traits", "object_store", "pprof", - "prost", - "prost-build", - "rand", + "prost 0.13.5", + "prost-build 0.13.5", + "protobuf-src", + "rand 0.8.5", "random_word", "rayon", "roaring", @@ -3300,7 +4198,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-arith", @@ -3331,9 +4229,9 @@ dependencies = [ "path_abs", "pin-project", "pprof", - "prost", - "prost-build", - "rand", + "prost 0.13.5", + "rand 0.8.5", + "rstest", "shellexpand", "snafu", "tempfile", @@ -3345,19 +4243,22 @@ dependencies = [ [[package]] name = "lance-jni" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-schema", "datafusion", "jni", "lance", + "lance-core", "lance-datafusion", "lance-encoding", + "lance-file", "lance-index", "lance-io", "lance-linalg", "lazy_static", + "object_store", "serde", "serde_json", "snafu", @@ -3366,7 +4267,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "0.20.0" +version = "0.26.2" dependencies = [ "approx", "arrow-arith", @@ -3387,7 +4288,7 @@ dependencies = [ "num-traits", "pprof", "proptest", - "rand", + "rand 0.8.5", "rayon", "tokio", "tracing", @@ -3395,7 +4296,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -3422,10 +4323,11 @@ dependencies = [ "pprof", "pretty_assertions", "proptest", - "prost", - "prost-build", - "prost-types", - "rand", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "protobuf-src", + "rand 0.8.5", "rangemap", "roaring", "serde", @@ -3439,22 +4341,22 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "0.20.0" +version = "0.26.2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "lance-testing" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-schema", "lance-arrow", "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -3483,6 +4385,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "levenshtein_automata" version = "0.2.1" @@ -3491,9 +4399,9 @@ checksum = "0c2cdeb66e45e9f36bfad5bbdb4d2384e70936afbee843c6f6543f0c551ebb25" [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -3504,9 +4412,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -3515,9 +4423,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -3525,18 +4433,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -3544,48 +4452,150 @@ dependencies = [ ] [[package]] -name = "lexical-write-integer" -version = "0.8.5" +name = "lexical-write-integer" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "libc" +version = "0.2.171" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" + +[[package]] +name = "libflate" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" +dependencies = [ + "core2", + "hashbrown 0.14.5", + "rle-decode-fast", +] + +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets 0.48.5", +] + +[[package]] +name = "libm" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" + +[[package]] +name = "libredox" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "lexical-util", - "static_assertions", + "bitflags 2.9.0", + "libc", + "redox_syscall", ] [[package]] -name = "libc" -version = "0.2.155" +name = "lindera" +version = "0.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "fff887f4b98539fb5f879ede50e17eb7eaafa5622c252cffe8280f42cafc6b7d" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "kanaria", + "lindera-dictionary", + "once_cell", + "regex", + "serde", + "serde_json", + "serde_yaml", + "strum", + "strum_macros", + "unicode-blocks", + "unicode-normalization", + "unicode-segmentation", + "yada", +] [[package]] -name = "libm" -version = "0.2.8" +name = "lindera-dictionary" +version = "0.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "ec716483ceb95aa84ac262cb766eef314b24257c343ca230daa71f856a278fe4" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "derive_builder", + "encoding", + "encoding_rs", + "encoding_rs_io", + "flate2", + "glob", + "log", + "once_cell", + "reqwest", + "serde", + "tar", + "thiserror 2.0.12", + "yada", +] [[package]] -name = "libredox" -version = "0.1.3" +name = "lindera-tantivy" +version = "0.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "261c87882a909fd17db4dd797e4dc2aac3992bdbbb4e2900d1362a1e0746266f" dependencies = [ - "bitflags 2.6.0", - "libc", + "lindera", + "tantivy", + "tantivy-tokenizer-api", ] [[package]] name = "linux-raw-sys" -version = "0.3.8" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" + +[[package]] +name = "litemap" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "lock_api" @@ -3599,20 +4609,52 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" dependencies = [ "value-bag", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" -version = "0.12.4" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.2", +] + +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" dependencies = [ - "hashbrown", + "cc", + "libc", ] [[package]] @@ -3678,9 +4720,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmap2" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" dependencies = [ "libc", ] @@ -3699,22 +4741,21 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" dependencies = [ - "adler", + "adler2", ] [[package]] name = "mio" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -3750,30 +4791,28 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "moka" -version = "0.12.8" +version = "0.12.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" dependencies = [ - "async-lock 3.4.0", - "async-trait", + "async-lock", "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", - "event-listener 5.3.1", + "event-listener 5.4.0", "futures-util", - "once_cell", + "loom", "parking_lot", - "quanta", + "portable-atomic", "rustc_version", "smallvec", "tagptr", - "thiserror", - "triomphe", + "thiserror 1.0.69", "uuid", ] @@ -3789,6 +4828,23 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nix" version = "0.26.4" @@ -3931,35 +4987,36 @@ dependencies = [ [[package]] name = "object" -version = "0.36.2" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.10.2" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" +checksum = "3cfccb68961a56facde1163f9319e0d15743352344e7808a11795fb99698dcaf" dependencies = [ "async-trait", "base64 0.22.1", "bytes", "chrono", "futures", + "httparse", "humantime", - "hyper 1.4.1", + "hyper 1.6.0", "itertools 0.13.0", "md-5", "parking_lot", "percent-encoding", - "quick-xml 0.36.1", - "rand", + "quick-xml 0.37.4", + "rand 0.8.5", "reqwest", "ring", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "serde", "serde_json", "snafu", @@ -3971,27 +5028,65 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oneshot" -version = "0.1.8" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" +checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "openssl" +version = "0.10.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" +dependencies = [ + "bitflags 2.9.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] [[package]] name = "option-ext" @@ -4010,9 +5105,9 @@ dependencies = [ [[package]] name = "outref" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "overload" @@ -4029,11 +5124,22 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "p256" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" +dependencies = [ + "ecdsa", + "elliptic-curve", + "sha2", +] + [[package]] name = "parking" -version = "2.2.0" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" @@ -4053,16 +5159,16 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.3", + "redox_syscall", "smallvec", "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "52.2.0" +version = "54.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e977b9066b4d3b03555c22bdc442f3fadebd96a39111249113087d0edb2691cd" +checksum = "bfb15796ac6f56b429fd99e33ba133783ad75b27c36b4b5ce06f1f82cc97754e" dependencies = [ "ahash", "arrow-array", @@ -4073,25 +5179,25 @@ dependencies = [ "arrow-schema", "arrow-select", "base64 0.22.1", - "brotli 6.0.0", + "brotli 7.0.0", "bytes", "chrono", "flate2", "futures", "half", - "hashbrown", + "hashbrown 0.15.2", "lz4_flex", "num", "num-bigint", "object_store", "paste", "seq-macro", + "simdutf8", "snap", "thrift", "tokio", "twox-hash", "zstd", - "zstd-sys", ] [[package]] @@ -4123,9 +5229,9 @@ dependencies = [ [[package]] name = "pbjson" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1030c719b0ec2a2d25a5df729d6cff1acf3cc230bf766f4f97833591f7577b90" +checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" dependencies = [ "base64 0.21.7", "serde", @@ -4133,28 +5239,28 @@ dependencies = [ [[package]] name = "pbjson-build" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2580e33f2292d34be285c5bc3dba5259542b083cfad6037b6d70345f24dcb735" +checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" dependencies = [ - "heck 0.4.1", - "itertools 0.11.0", - "prost", - "prost-types", + "heck", + "itertools 0.13.0", + "prost 0.13.5", + "prost-types 0.13.5", ] [[package]] name = "pbjson-types" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f596653ba4ac51bdecbb4ef6773bc7f56042dc13927910de1684ad3d32aa12" +checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" dependencies = [ "bytes", "chrono", "pbjson", "pbjson-build", - "prost", - "prost-build", + "prost 0.13.5", + "prost-build 0.13.5", "serde", ] @@ -4176,24 +5282,34 @@ version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ - "fixedbitset", + "fixedbitset 0.4.2", + "indexmap", +] + +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset 0.5.7", "indexmap", ] [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -4201,48 +5317,48 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project" -version = "1.1.5" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.5" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -4252,26 +5368,36 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "piper" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1d5c74c9876f070d3e8fd503d748c7d974c3e48da8f41350fa5222ef9b4391" +checksum = "96c8c490f422ef9a4efd2cb5b42b76c8613d7e7dfc1caf667b8a3350a5acc066" dependencies = [ "atomic-waker", - "fastrand 2.1.0", + "fastrand", "futures-io", ] +[[package]] +name = "pkcs8" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" dependencies = [ "num-traits", "plotters-backend", @@ -4282,48 +5408,47 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] name = "plotters-svg" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ "plotters-backend", ] [[package]] name = "polling" -version = "2.8.0" +version = "3.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +checksum = "a604568c3202727d1507653cb121dbd627a58684eb09a820fd746bee38b4442f" dependencies = [ - "autocfg", - "bitflags 1.3.2", "cfg-if", "concurrent-queue", - "libc", - "log", + "hermit-abi 0.4.0", "pin-project-lite", - "windows-sys 0.48.0", + "rustix 0.38.44", + "tracing", + "windows-sys 0.59.0", ] [[package]] -name = "polling" -version = "3.7.2" +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3ed00ed3fbf728b5816498ecd316d1716eecaced9c0c8d2c5a6740ca214985b" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" dependencies = [ - "cfg-if", - "concurrent-queue", - "hermit-abi 0.4.0", - "pin-project-lite", - "rustix 0.38.34", - "tracing", - "windows-sys 0.52.0", + "portable-atomic", ] [[package]] @@ -4352,23 +5477,23 @@ dependencies = [ "smallvec", "symbolic-demangle", "tempfile", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy", + "zerocopy 0.8.24", ] [[package]] name = "predicates" -version = "3.1.2" +version = "3.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" dependencies = [ "anstyle", "predicates-core", @@ -4376,15 +5501,15 @@ dependencies = [ [[package]] name = "predicates-core" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" [[package]] name = "predicates-tree" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" dependencies = [ "predicates-core", "termtree", @@ -4392,9 +5517,9 @@ dependencies = [ [[package]] name = "pretty_assertions" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ "diff", "yansi", @@ -4402,47 +5527,47 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ "toml_edit", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] [[package]] name = "proptest" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c2511913b88df1637da85cc8d96ec8e43a3f8bb8ccb71ee1ac240d6f3df58d" +checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" dependencies = [ "bit-set", "bit-vec", - "bitflags 2.6.0", + "bitflags 2.9.0", "lazy_static", "num-traits", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "rand_xorshift", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "rusty-fork", "tempfile", "unarray", @@ -4455,7 +5580,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.12.6", +] + +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive 0.13.5", ] [[package]] @@ -4465,17 +5600,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.5.0", + "heck", "itertools 0.12.1", "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.6.5", + "prettyplease", + "prost 0.12.6", + "prost-types 0.12.6", + "regex", + "syn 2.0.100", + "tempfile", +] + +[[package]] +name = "prost-build" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +dependencies = [ + "heck", + "itertools 0.14.0", + "log", + "multimap", + "once_cell", + "petgraph 0.7.1", "prettyplease", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", "regex", - "syn 2.0.87", + "syn 2.0.100", "tempfile", ] @@ -4489,7 +5644,20 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.100", ] [[package]] @@ -4498,22 +5666,34 @@ version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" dependencies = [ - "prost", + "prost 0.12.6", ] [[package]] -name = "quanta" -version = "0.12.3" +name = "prost-types" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", + "prost 0.13.5", +] + +[[package]] +name = "protobuf-src" +version = "2.1.1+27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6217c3504da19b85a3a4b2e9a5183d635822d83507ba0986624b5c05b83bfc40" +dependencies = [ + "cmake", +] + +[[package]] +name = "psm" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +dependencies = [ + "cc", ] [[package]] @@ -4533,9 +5713,9 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.36.1" +version = "0.37.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96a05e2e8efddfa51a84ca47cec303fac86c8541b686d37cac5efc0e094417bc" +checksum = "a4ce8c88de324ff838700f36fb6ab86c96df0e3c4ab6ef3a9b2044465cce1369" dependencies = [ "memchr", "serde", @@ -4543,61 +5723,73 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "c3bd15a6f2967aef83887dcb9fec0014580467e33720d073560cf015a5683012" dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", - "rustls 0.23.12", - "socket2 0.5.7", - "thiserror", + "rustc-hash 2.1.1", + "rustls 0.23.25", + "socket2", + "thiserror 2.0.12", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "b820744eb4dc9b57a3398183639c511b5a26d2ed702cedd3febaa1393caa22cc" dependencies = [ "bytes", - "rand", + "getrandom 0.3.2", + "rand 0.9.0", "ring", - "rustc-hash 2.0.0", - "rustls 0.23.12", + "rustc-hash 2.1.1", + "rustls 0.23.25", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.12", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.4" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285" +checksum = "541d0f57c6ec747a90738a52741d3221f7960e8ac2f0ff4b1a63680e033b4ab5" dependencies = [ + "cfg_aliases", "libc", "once_cell", - "socket2 0.5.7", + "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "radium" version = "0.7.0" @@ -4611,18 +5803,39 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", + "zerocopy 0.8.24", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -4631,7 +5844,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.2", ] [[package]] @@ -4641,7 +5863,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -4650,7 +5872,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -4659,7 +5881,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -4672,7 +5894,7 @@ dependencies = [ "brotli 3.5.0", "once_cell", "paste", - "rand", + "rand 0.8.5", "unicase", ] @@ -4682,15 +5904,6 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f60fcc7d6849342eff22c4350c8b9a989ee8ceabc4b481253e8946b9fe83d684" -[[package]] -name = "raw-cpuid" -version = "11.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" -dependencies = [ - "bitflags 2.6.0", -] - [[package]] name = "rayon" version = "1.10.0" @@ -4712,44 +5925,55 @@ dependencies = [ ] [[package]] -name = "redox_syscall" -version = "0.4.1" +name = "recursive" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" dependencies = [ - "bitflags 1.3.2", + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.100", ] [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", ] [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "regex" -version = "1.10.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -4763,13 +5987,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -4786,27 +6010,17 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" - -[[package]] -name = "regress" -version = "0.8.0" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f5f39ba4513916c1b2657b72af6ec671f091cd637992f58d0ede5cae4e5dea0" -dependencies = [ - "hashbrown", - "memchr", -] +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "regress" -version = "0.9.1" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eae2a1ebfecc58aff952ef8ccd364329abe627762f5bf09ff42eb9d98522479" +checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" dependencies = [ - "hashbrown", + "hashbrown 0.15.2", "memchr", ] @@ -4818,40 +6032,46 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.7" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", - "h2 0.4.6", - "http 1.1.0", + "h2 0.4.8", + "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", - "hyper-rustls 0.27.3", + "hyper 1.6.0", + "hyper-rustls 0.27.5", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.12", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.1.3", + "rustls 0.23.25", + "rustls-native-certs 0.8.1", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", + "system-configuration", "tokio", - "tokio-rustls 0.26.0", + "tokio-native-tls", + "tokio-rustls 0.26.2", "tokio-util", + "tower", "tower-service", "url", "wasm-bindgen", @@ -4861,35 +6081,51 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "rfc6979" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7743f17af12fa0b03b803ba12cd6a8d9483a587e89c69445e3909655c0b9fabb" +dependencies = [ + "crypto-bigint 0.4.9", + "hmac", + "zeroize", +] + [[package]] name = "rgb" -version = "0.8.47" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12bc8d2f72df26a5d3178022df33720fbede0d31d82c7291662eff89836994d" +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" dependencies = [ "bytemuck", ] [[package]] name = "ring" -version = "0.17.8" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", - "spin", "untrusted", "windows-sys 0.52.0", ] +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + [[package]] name = "roaring" -version = "0.10.6" +version = "0.10.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +checksum = "19e8d2cfa184d94d0726d650a9f4a1be7f9b76ac9fdb954219878dc00c1c1e7b" dependencies = [ "bytemuck", "byteorder", @@ -4921,7 +6157,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.87", + "syn 2.0.100", "unicode-ident", ] @@ -4949,9 +6185,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustc_version" @@ -4964,29 +6200,28 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.27" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.9.0", "errno", - "io-lifetimes", "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", ] [[package]] name = "rustix" -version = "0.38.34" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", "errno", "libc", - "linux-raw-sys 0.4.14", - "windows-sys 0.52.0", + "linux-raw-sys 0.9.3", + "windows-sys 0.59.0", ] [[package]] @@ -5003,15 +6238,16 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.12" +version = "0.23.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.102.6", + "rustls-webpki 0.103.1", "subtle", "zeroize", ] @@ -5025,33 +6261,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", -] - -[[package]] -name = "rustls-native-certs" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.1.3", - "rustls-pki-types", - "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.2.0", ] [[package]] @@ -5065,19 +6287,21 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.3" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.1", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.8.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -5091,10 +6315,11 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.102.6" +version = "0.103.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -5102,9 +6327,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "rusty-fork" @@ -5120,9 +6345,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -5135,18 +6360,18 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "schemars" -version = "0.8.21" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ "dyn-clone", "schemars_derive", @@ -5156,16 +6381,22 @@ dependencies = [ [[package]] name = "schemars_derive" -version = "0.8.21" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.87", + "syn 2.0.100", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -5182,14 +6413,41 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sec1" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", - "core-foundation", + "bitflags 2.9.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.9.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -5197,9 +6455,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.1" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -5207,37 +6465,37 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" dependencies = [ "serde", ] [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.204" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5248,14 +6506,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "serde_json" -version = "1.0.122" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -5265,14 +6523,14 @@ dependencies = [ [[package]] name = "serde_tokenstream" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8790a7c3fe883e443eaa2af6f705952bc5d6e8671a220b9335c8cae92c037e74" +checksum = "64060d864397305347a78851c51588fd283767e7e7589829e8121d65512340f1" dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5300,6 +6558,17 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -5329,6 +6598,12 @@ dependencies = [ "dirs", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -5338,11 +6613,27 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "1.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "sketches-ddsketch" @@ -5364,30 +6655,29 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.2" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snafu" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ - "doc-comment", "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.100", ] [[package]] @@ -5398,49 +6688,44 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.4.10" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" +checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", - "winapi", + "windows-sys 0.52.0", ] [[package]] -name = "socket2" -version = "0.5.7" +name = "spki" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" dependencies = [ - "libc", - "windows-sys 0.52.0", + "base64ct", + "der", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "c66e3b7374ad4a6af849b08b3e7a6eda0edbd82f0fd59b57e22671bf16979899" dependencies = [ "log", + "recursive", "sqlparser_derive", ] [[package]] name = "sqlparser_derive" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5449,6 +6734,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601f9201feb9b09c00266478bf459952b9ef9a6b94edb2f21eba14ab681a60a9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -5494,97 +6792,99 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] name = "substrait" -version = "0.29.4" +version = "0.50.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6c402018947957c4c7f2af49304f5cd8a948858686bf958d519cf0aa644790" +checksum = "b1772d041c37cc7e6477733c76b2acf4ee36bd52b2ae4d9ea0ec9c87d003db32" dependencies = [ - "heck 0.5.0", + "heck", "prettyplease", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "regress", "schemars", "semver", "serde", "serde_json", "serde_yaml", - "syn 2.0.87", - "typify 0.0.16", + "syn 2.0.100", + "typify 0.2.0", "walkdir", ] [[package]] name = "substrait" -version = "0.36.0" +version = "0.53.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1ee6e584c8bf37104b7eb51c25eae07a9321b0e01379bec3b7c462d2f42afbf" +checksum = "6fac3d70185423235f37b889764e184b81a5af4bb7c95833396ee9bd92577e1b" dependencies = [ - "heck 0.5.0", + "heck", "pbjson", "pbjson-build", "pbjson-types", "prettyplease", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "regress", "schemars", "semver", "serde", "serde_json", "serde_yaml", - "syn 2.0.87", - "typify 0.1.0", + "syn 2.0.100", + "typify 0.3.0", "walkdir", ] [[package]] name = "substrait-expr" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9a8b8cc82442b391b67e7c195f0d3de35838bb78b115468d28076ec54dd4577" +checksum = "9d091cf06bc7808bd81eb01f5f5b77b2b14288bb022501a2dcad78633c65262f" dependencies = [ "once_cell", - "prost", - "substrait 0.29.4", + "prost 0.13.5", + "substrait 0.50.4", "substrait-expr-funcgen", "substrait-expr-macros", - "thiserror", + "thiserror 2.0.12", ] [[package]] name = "substrait-expr-funcgen" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96a5fb5bfa1ff743bdc1c259c46fde88d1ef8129c68ff7e7d876f907d67dbff7" +checksum = "bee762399b891e8c84b9777e67a4c3193bc499c176c18d22f39341df61166092" dependencies = [ "convert_case", "prettyplease", "proc-macro2", "quote", "serde_yaml", - "substrait 0.29.4", - "syn 2.0.87", - "thiserror", + "substrait 0.50.4", + "syn 2.0.100", + "thiserror 2.0.12", ] [[package]] name = "substrait-expr-macros" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919e5b5c5495d18dffb0b8369d74a143c893cbfb98b4337cecb31f3f9bcc112b" +checksum = "0e42af5525699cb9924c8fdd3aa233d2b067efde29f68c00090ca0c8eada8269" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5595,9 +6895,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.10.0" +version = "12.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16629323a4ec5268ad23a575110a724ad4544aae623451de600c747bf87b36cf" +checksum = "66135c8273581acaab470356f808a1c74a707fe7ec24728af019d7247e089e71" dependencies = [ "debugid", "memmap2", @@ -5607,9 +6907,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.10.0" +version = "12.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c043a45f08f41187414592b3ceb53fb0687da57209cc77401767fb69d5b596" +checksum = "42bcacd080282a72e795864660b148392af7babd75691d5ae9a3b77e29c98c77" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -5629,9 +6929,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.87" +version = "2.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" dependencies = [ "proc-macro2", "quote", @@ -5640,13 +6940,45 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.9.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tagptr" version = "0.2.0" @@ -5698,7 +7030,7 @@ dependencies = [ "tantivy-stacker", "tantivy-tokenizer-api", "tempfile", - "thiserror", + "thiserror 1.0.69", "time", "uuid", "winapi", @@ -5749,7 +7081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d60769b80ad7953d8a7b2c70cdfe722bbcdcac6bccc8ac934c40c034d866fc18" dependencies = [ "byteorder", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "utf8-ranges", ] @@ -5802,9 +7134,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tar" -version = "0.4.41" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" dependencies = [ "filetime", "libc", @@ -5813,52 +7145,43 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.11.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fcd239983515c23a32fb82099f97d0b11b8c72f654ed659363a95c3dad7a53" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" dependencies = [ - "cfg-if", - "fastrand 2.1.0", + "fastrand", + "getrandom 0.3.2", "once_cell", - "rustix 0.38.34", - "windows-sys 0.52.0", -] - -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", + "rustix 1.0.5", + "windows-sys 0.59.0", ] [[package]] name = "termtree" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" [[package]] name = "test-log" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93" +checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" dependencies = [ - "env_logger 0.11.5", + "env_logger", "test-log-macros", "tracing-subscriber", ] [[package]] name = "test-log-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" +checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5883,31 +7206,51 @@ dependencies = [ "num-traits", "once_cell", "pin-project", - "prost", - "prost-build", + "prost 0.12.6", + "prost-build 0.12.6", "tar", - "thiserror", + "thiserror 1.0.69", "ureq", ] [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ - "thiserror-impl", + "proc-macro2", + "quote", + "syn 2.0.100", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -5933,9 +7276,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -5948,15 +7291,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -5971,6 +7314,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -5983,9 +7336,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] @@ -5998,9 +7351,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" dependencies = [ "backtrace", "bytes", @@ -6008,20 +7361,30 @@ dependencies = [ "mio", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.7", + "socket2", "tokio-macros", "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", ] [[package]] @@ -6036,20 +7399,19 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls 0.23.12", - "rustls-pki-types", + "rustls 0.23.25", "tokio", ] [[package]] name = "tokio-stream" -version = "0.1.15" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -6058,9 +7420,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" dependencies = [ "bytes", "futures-core", @@ -6077,9 +7439,9 @@ checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" dependencies = [ "indexmap", "toml_datetime", @@ -6088,14 +7450,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.4.13" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "pin-project", "pin-project-lite", + "sync_wrapper", "tokio", "tower-layer", "tower-service", @@ -6103,21 +7465,21 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -6126,13 +7488,13 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", ] [[package]] @@ -6148,9 +7510,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -6169,9 +7531,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", @@ -6185,12 +7547,6 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "triomphe" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" - [[package]] name = "try-lock" version = "0.2.5" @@ -6209,89 +7565,92 @@ dependencies = [ [[package]] name = "typenum" -version = "1.17.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.0.16" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c61e9db210bbff218e6535c664b37ec47da449169b98e7866d0580d0db75529" +checksum = "b4c644dda9862f0fef3a570d8ddb3c2cfb1d5ac824a1f2ddfa7bc8f071a5ad8a" dependencies = [ - "typify-impl 0.0.16", - "typify-macro 0.0.16", + "typify-impl 0.2.0", + "typify-macro 0.2.0", ] [[package]] name = "typify" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb6beec125971dda80a086f90b4a70f60f222990ce4d63ad0fc140492f53444" +checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" dependencies = [ - "typify-impl 0.1.0", - "typify-macro 0.1.0", + "typify-impl 0.3.0", + "typify-macro 0.3.0", ] [[package]] name = "typify-impl" -version = "0.0.16" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95e32f38493804f88e2dc7a5412eccd872ea5452b4db9b0a77de4df180f2a87e" +checksum = "d59ab345b6c0d8ae9500b9ff334a4c7c0d316c1c628dc55726b95887eb8dbd11" dependencies = [ - "heck 0.4.1", + "heck", "log", "proc-macro2", "quote", - "regress 0.8.0", + "regress", "schemars", + "semver", + "serde", "serde_json", - "syn 2.0.87", - "thiserror", + "syn 2.0.100", + "thiserror 1.0.69", "unicode-ident", ] [[package]] name = "typify-impl" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93bbb24e990654aff858d80fee8114f4322f7d7a1b1ecb45129e2fcb0d0ad5ae" +checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" dependencies = [ - "heck 0.5.0", + "heck", "log", "proc-macro2", "quote", - "regress 0.9.1", + "regress", "schemars", "semver", "serde", "serde_json", - "syn 2.0.87", - "thiserror", + "syn 2.0.100", + "thiserror 2.0.12", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.0.16" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc09508b72f63d521d68e42c7f172c7416d67986df44b3c7d1f7f9963948ed32" +checksum = "785e2cdcef0df8160fdd762ed548a637aaec1e83704fdbc14da0df66013ee8d0" dependencies = [ "proc-macro2", "quote", "schemars", + "semver", "serde", "serde_json", "serde_tokenstream", - "syn 2.0.87", - "typify-impl 0.0.16", + "syn 2.0.100", + "typify-impl 0.2.0", ] [[package]] name = "typify-macro" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8e6491896e955692d68361c68db2b263e3bec317ec0b684e0e2fa882fb6e31e" +checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" dependencies = [ "proc-macro2", "quote", @@ -6300,8 +7659,8 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.87", - "typify-impl 0.1.0", + "syn 2.0.100", + "typify-impl 0.3.0", ] [[package]] @@ -6312,45 +7671,42 @@ checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" [[package]] name = "unicase" -version = "2.7.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" -dependencies = [ - "version_check", -] +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] -name = "unicode-bidi" -version = "0.3.15" +name = "unicode-blocks" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "6b12e05d9e06373163a9bb6bb8c263c261b396643a99445fe6b9811fd376581b" [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unsafe-libyaml" @@ -6366,15 +7722,15 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.10.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" dependencies = [ "base64 0.22.1", "flate2", "log", "once_cell", - "rustls 0.23.12", + "rustls 0.23.25", "rustls-pki-types", "url", "webpki-roots", @@ -6382,9 +7738,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -6397,12 +7753,24 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + [[package]] name = "utf8-ranges" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcfc827f90e53a02eaef5e535ee14266c1d569214c6aa70133a624d8a3164ba" +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -6411,25 +7779,33 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ - "getrandom", + "getrandom 0.3.2", + "js-sys", "serde", + "wasm-bindgen", ] [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "value-bag" -version = "1.9.0" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "943ce29a8a743eb10d6082545d861b24f9d1b160b7d741e0f2cdf726bec909c5" + +[[package]] +name = "vcpkg" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" @@ -6445,19 +7821,13 @@ checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" [[package]] name = "wait-timeout" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" dependencies = [ "libc", ] -[[package]] -name = "waker-fn" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317211a0dc0ceedd78fb2ca9a44aed3d7b9b26f81870d485c07122b4350673b7" - [[package]] name = "walkdir" version = "2.5.0" @@ -6483,48 +7853,59 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", + "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6532,28 +7913,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-streams" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -6564,9 +7948,19 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -6574,13 +7968,25 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.3" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "winapi" version = "0.3.9" @@ -6612,24 +8018,101 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" -version = "0.52.0" +version = "0.58.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" dependencies = [ + "windows-implement 0.58.0", + "windows-interface 0.58.0", + "windows-result 0.2.0", + "windows-strings 0.1.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" +dependencies = [ + "windows-implement 0.60.0", + "windows-interface 0.59.1", + "windows-link", + "windows-result 0.3.2", + "windows-strings 0.4.0", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-link" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" + [[package]] name = "windows-registry" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-result", - "windows-strings", - "windows-targets 0.52.6", + "windows-result 0.3.2", + "windows-strings 0.3.1", + "windows-targets 0.53.0", ] [[package]] @@ -6641,16 +8124,43 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-strings" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ - "windows-result", + "windows-result 0.2.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows-strings" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -6726,13 +8236,29 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -6751,6 +8277,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -6769,6 +8301,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -6787,12 +8325,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -6811,6 +8361,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -6829,6 +8385,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -6847,6 +8409,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -6865,15 +8433,42 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" -version = "0.6.20" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags 2.9.0", +] + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "wyz" version = "0.5.1" @@ -6885,13 +8480,12 @@ dependencies = [ [[package]] name = "xattr" -version = "1.3.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" dependencies = [ "libc", - "linux-raw-sys 0.4.14", - "rustix 0.38.34", + "rustix 1.0.5", ] [[package]] @@ -6901,19 +8495,40 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] -name = "xz2" -version = "0.1.7" +name = "yada" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aed111bd9e48a802518765906cbdadf0b45afb72b9c81ab049a3b86252adffdd" + +[[package]] +name = "yansi" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ - "lzma-sys", + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", ] [[package]] -name = "yansi" -version = "0.5.1" +name = "yoke-derive" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "synstructure", +] [[package]] name = "zerocopy" @@ -6921,8 +8536,16 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +dependencies = [ + "zerocopy-derive 0.8.24", ] [[package]] @@ -6933,7 +8556,39 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.100", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "synstructure", ] [[package]] @@ -6942,11 +8597,33 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "zstd" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" dependencies = [ "zstd-safe", ] @@ -6962,9 +8639,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.12+zstd.1.5.6" +version = "2.0.13+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index e4f3174669e..d9a16d80976 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.20.0" +version = "0.26.2" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -41,43 +41,44 @@ categories = [ "development-tools", "science", ] -rust-version = "1.78" +rust-version = "1.82.0" [workspace.dependencies] -lance = { version = "=0.20.0", path = "./rust/lance" } -lance-arrow = { version = "=0.20.0", path = "./rust/lance-arrow" } -lance-core = { version = "=0.20.0", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.20.0", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.20.0", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.20.0", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.20.0", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.20.0", path = "./rust/lance-file" } -lance-index = { version = "=0.20.0", path = "./rust/lance-index" } -lance-io = { version = "=0.20.0", path = "./rust/lance-io" } -lance-jni = { version = "=0.20.0", path = "./java/core/lance-jni" } -lance-linalg = { version = "=0.20.0", path = "./rust/lance-linalg" } -lance-table = { version = "=0.20.0", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.20.0", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.20.0", path = "./rust/lance-testing" } +lance = { version = "=0.26.2", path = "./rust/lance" } +lance-arrow = { version = "=0.26.2", path = "./rust/lance-arrow" } +lance-core = { version = "=0.26.2", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.26.2", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.26.2", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.26.2", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.26.2", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.26.2", path = "./rust/lance-file" } +lance-index = { version = "=0.26.2", path = "./rust/lance-index" } +lance-io = { version = "=0.26.2", path = "./rust/lance-io" } +lance-jni = { version = "=0.26.2", path = "./java/core/lance-jni" } +lance-linalg = { version = "=0.26.2", path = "./rust/lance-linalg" } +lance-table = { version = "=0.26.2", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.26.2", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.26.2", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow -arrow = { version = "52.2", optional = false, features = ["prettyprint"] } -arrow-arith = "52.2" -arrow-array = "52.2" -arrow-buffer = "52.2" -arrow-cast = "52.2" -arrow-data = "52.2" -arrow-ipc = { version = "52.2", features = ["zstd"] } -arrow-ord = "52.2" -arrow-row = "52.2" -arrow-schema = "52.2" -arrow-select = "52.2" +arrow = { version = "54.1", optional = false, features = ["prettyprint"] } +arrow-arith = "54.1" +arrow-array = "54.1" +arrow-buffer = "54.1" +arrow-cast = "54.1" +arrow-data = "54.1" +arrow-ipc = { version = "54.1", features = ["zstd"] } +arrow-ord = "54.1" +arrow-row = "54.1" +arrow-schema = "54.1" +arrow-select = "54.1" async-recursion = "1.0" async-trait = "0.1" aws-config = "1.2.0" aws-credential-types = "1.2.0" aws-sdk-dynamodb = "1.38.0" -half = { "version" = "2.4.1", default-features = false, features = [ +aws-sdk-s3 = "1.38.0" +half = { "version" = "2.1", default-features = false, features = [ "num-traits", "std", ] } @@ -85,7 +86,7 @@ bitvec = "1" bytes = "1.4" byteorder = "1.5" clap = { version = "4", features = ["derive"] } -chrono = { version = "0.4.25", default-features = false, features = [ +chrono = { version = "0.4.40", default-features = false, features = [ "std", "now", ] } @@ -95,27 +96,33 @@ criterion = { version = "0.5", features = [ "html_reports", ] } crossbeam-queue = "0.3" -datafusion = { version = "41.0", default-features = false, features = [ +datafusion = { version = "46.0", default-features = false, features = [ "nested_expressions", "regex_expressions", "unicode_expressions", + "crypto_expressions", + "encoding_expressions", + "datetime_expressions", + "string_expressions", ] } -datafusion-common = "41.0" -datafusion-functions = { version = "41.0", features = ["regex_expressions"] } -datafusion-sql = "41.0" -datafusion-expr = "41.0" -datafusion-execution = "41.0" -datafusion-optimizer = "41.0" -datafusion-physical-expr = { version = "41.0", features = [ - "regex_expressions", -] } +datafusion-common = "46.0" +datafusion-functions = { version = "46.0", features = ["regex_expressions"] } +datafusion-sql = "46.0" +datafusion-expr = "46.0" +datafusion-execution = "46.0" +datafusion-optimizer = "46.0" +datafusion-physical-expr = { version = "46.0" } deepsize = "0.2.0" +dirs = "5.0.0" either = "1.0" -fsst = { version = "=0.20.0", path = "./rust/lance-encoding/src/compression_algo/fsst" } +fst = { version = "0.4.7", features = ["levenshtein"] } +fsst = { version = "=0.26.2", path = "./rust/lance-encoding/src/compression_algo/fsst" } futures = "0.3" http = "1.1.0" +humantime = "2.2.0" hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } itertools = "0.13" +jieba-rs = { version = "0.7", default-features = false } lazy_static = "1" log = "0.4" mockall = { version = "0.13.1" } @@ -123,15 +130,15 @@ mock_instant = { version = "0.3.1", features = ["sync"] } moka = { version = "0.12", features = ["future", "sync"] } num-traits = "0.2" # Set min to prevent use of versions with CVE-2024-41178 -object_store = { version = "0.10.2" } -parquet = "52.0" +object_store = { version = "0.11.0" } +parquet = "54.1" pin-project = "1.0" path_abs = "0.5" pprof = { version = "0.14.0", features = ["flamegraph", "criterion"] } proptest = "1.3.1" -prost = "0.12.2" -prost-build = "0.12.2" -prost-types = "0.12.2" +prost = "0.13.2" +prost-build = "0.13.2" +prost-types = "0.13.2" rand = { version = "0.8.3", features = ["small_rng"] } rangemap = { version = "1.0" } rayon = "1.10" @@ -141,8 +148,10 @@ rustc_version = "0.4" serde = { version = "^1" } serde_json = { version = "1" } shellexpand = "3.0" -snafu = "0.7.5" +snafu = "0.8" tantivy = { version = "0.22.0", features = ["stopwords"] } +lindera = { version = "0.38.1" } +lindera-tantivy = { version = "0.38.1" } tempfile = "3" test-log = { version = "0.2.15" } tokio = { version = "1.23", features = [ diff --git a/LICENSE b/LICENSE index 79de57d6670..1cb5bd54bdb 100644 --- a/LICENSE +++ b/LICENSE @@ -226,3 +226,212 @@ under the MIT license: SOFTWARE. https://github.com/pola-rs/polars/blob/main/LICENSE + +-------------------------------------------------------------------------------- + +This project includes code from apache spark project, which is licensed +under the Apache license: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +https://github.com/apache/spark/blob/master/LICENSE \ No newline at end of file diff --git a/README.md b/README.md index aed61800ac8..eea4d4180a0 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,13 @@ Lance Logo -**Modern columnar data format for ML. Convert from Parquet in 2-lines of code for 100x faster random access, a vector index, data versioning, and more.
** -**Compatible with pandas, DuckDB, Polars, and pyarrow with more integrations on the way.** +**Modern columnar data format for ML. Convert from Parquet in 2-lines of code for 100x faster random access, zero-cost schema evolution, rich secondary indices, versioning, and more.
** +**Compatible with Pandas, DuckDB, Polars, Pyarrow, and Ray with more integrations on the way.** Documentation • Blog • Discord • -Twitter +X [CI]: https://github.com/lancedb/lance/actions/workflows/rust.yml [CI Badge]: https://github.com/lancedb/lance/actions/workflows/rust.yml/badge.svg @@ -44,7 +44,7 @@ The key features of Lance include: * **Zero-copy, automatic versioning:** manage versions of your data without needing extra infrastructure. -* **Ecosystem integrations:** Apache Arrow, Pandas, Polars, DuckDB and more on the way. +* **Ecosystem integrations:** Apache Arrow, Pandas, Polars, DuckDB, Ray, Spark and more on the way. > [!TIP] > Lance is in active development and we welcome contributions. Please see our [contributing guide](docs/contributing.rst) for more information. @@ -66,7 +66,7 @@ pip install --pre --extra-index-url https://pypi.fury.io/lancedb/ pylance > [!TIP] > Preview releases are released more often than full releases and contain the > latest features and bug fixes. They receive the same level of testing as full releases. -> We guarantee they will remain published and available for download for at +> We guarantee they will remain published and available for download for at > least 6 months. When you want to pin to a specific version, prefer a stable release. **Converting to Lance** @@ -164,11 +164,12 @@ rs = [dataset.to_table(nearest={"column": "vector", "k": 10, "q": q}) ## Directory structure -| Directory | Description | -|--------------------|--------------------------| -| [rust](./rust) | Core Rust implementation | -| [python](./python) | Python bindings (pyo3) | -| [docs](./docs) | Documentation source | +| Directory | Description | +|--------------------|-------------------------------------------| +| [rust](./rust) | Core Rust implementation | +| [python](./python) | Python bindings (PyO3) | +| [java](./java) | Java bindings (JNI) and Spark integration | +| [docs](./docs) | Documentation source | ## What makes Lance different @@ -185,8 +186,8 @@ Support both CPUs (``x86_64`` and ``arm``) and GPU (``Nvidia (cuda)`` and ``Appl **Fast updates** (ROADMAP): Updates will be supported via write-ahead logs. -**Rich secondary indices** (ROADMAP): - - Inverted index for fuzzy search over many label / annotation fields. +**Rich secondary indices**: Support `BTree`, `Bitmap`, `Full text search`, `Label list`, +`NGrams`, and more. ## Benchmarks @@ -252,11 +253,16 @@ A comparison of different data formats in each stage of ML development cycle. Lance is currently used in production by: * [LanceDB](https://github.com/lancedb/lancedb), a serverless, low-latency vector database for ML applications +* [LanceDB Enterprise](https://docs.lancedb.com/enterprise/introduction), hyperscale LanceDB with enterprise SLA. +* Leading multimodal Gen AI companies for training over petabyte-scale multimodal data. * Self-driving car company for large-scale storage, retrieval and processing of multi-modal data. * E-commerce company for billion-scale+ vector personalized search. * and more. -## Presentations and Talks +## Presentations, Blogs and Talks +* [Designing a Table Format for ML Workloads](https://blog.lancedb.com/designing-a-table-format-for-ml-workloads/), Feb 2025. +* [Transforming Multimodal Data Management with LanceDB, Ray Summit](https://www.youtube.com/watch?v=xmTFEzAh8ho), Oct 2024. +* [Lance v2: A columnar container format for modern data](https://blog.lancedb.com/lance-v2/), Apr 2024. * [Lance Deep Dive](https://drive.google.com/file/d/1Orh9rK0Mpj9zN_gnQF1eJJFpAc6lStGm/view?usp=drive_link). July 2023. * [Lance: A New Columnar Data Format](https://docs.google.com/presentation/d/1a4nAiQAkPDBtOfXFpPg7lbeDAxcNDVKgoUkw3cUs2rE/edit#slide=id.p), [Scipy 2022, Austin, TX](https://www.scipy2022.scipy.org/posters). July, 2022. diff --git a/ci/check_versions.py b/ci/check_versions.py index d42062a2553..fe6023153f2 100644 --- a/ci/check_versions.py +++ b/ci/check_versions.py @@ -11,7 +11,12 @@ def get_versions(): """ Gets the current version in both python/Cargo.toml and Cargo.toml files. """ - import tomllib + try: + # Python 3.11+ + import tomllib + except ImportError: + # Python 3.6-3.10 use tomli + import tomli as tomllib with open("python/Cargo.toml", "rb") as file: pylance_version = tomllib.load(file)["package"]["version"] diff --git a/deny.toml b/deny.toml new file mode 100644 index 00000000000..27ae38cbc15 --- /dev/null +++ b/deny.toml @@ -0,0 +1,252 @@ +# This template contains all of the possible sections and their default values + +# Note that all fields that take a lint level have these possible values: +# * deny - An error will be produced and the check will fail +# * warn - A warning will be produced, but the check will not fail +# * allow - No warning or error will be produced, though in some cases a note +# will be + +# The values provided in this template are the default values that will be used +# when any section or field is not specified in your own configuration + +# Root options + +# The graph table configures how the dependency graph is constructed and thus +# which crates the checks are performed against +[graph] +# If 1 or more target triples (and optionally, target_features) are specified, +# only the specified targets will be checked when running `cargo deny check`. +# This means, if a particular package is only ever used as a target specific +# dependency, such as, for example, the `nix` crate only being used via the +# `target_family = "unix"` configuration, that only having windows targets in +# this list would mean the nix crate, as well as any of its exclusive +# dependencies not shared by any other crates, would be ignored, as the target +# list here is effectively saying which targets you are building for. +targets = [ + # The triple can be any string, but only the target triples built in to + # rustc (as of 1.40) can be checked against actual config expressions + #"x86_64-unknown-linux-musl", + # You can also specify which target_features you promise are enabled for a + # particular target. target_features are currently not validated against + # the actual valid features supported by the target architecture. + #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, + "x86_64-unknown-linux-gnu", + "aarch64-unknown-linux-gnu", + "x86_64-apple-darwin", + "aarch64-apple-darwin", + "x86_64-pc-windows-gnu", + "x86_64-pc-windows-msvc", +] +# When creating the dependency graph used as the source of truth when checks are +# executed, this field can be used to prune crates from the graph, removing them +# from the view of cargo-deny. This is an extremely heavy hammer, as if a crate +# is pruned from the graph, all of its dependencies will also be pruned unless +# they are connected to another crate in the graph that hasn't been pruned, +# so it should be used with care. The identifiers are [Package ID Specifications] +# (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html) +#exclude = [] +# If true, metadata will be collected with `--all-features`. Note that this can't +# be toggled off if true, if you want to conditionally enable `--all-features` it +# is recommended to pass `--all-features` on the cmd line instead +all-features = true +# If true, metadata will be collected with `--no-default-features`. The same +# caveat with `all-features` applies +no-default-features = false +# If set, these feature will be enabled when collecting metadata. If `--features` +# is specified on the cmd line they will take precedence over this option. +#features = [] + +# The output table provides options for how/if diagnostics are outputted +[output] +# When outputting inclusion graphs in diagnostics that include features, this +# option can be used to specify the depth at which feature edges will be added. +# This option is included since the graphs can be quite large and the addition +# of features from the crate(s) to all of the graph roots can be far too verbose. +# This option can be overridden via `--feature-depth` on the cmd line +feature-depth = 1 + +# This section is considered when running `cargo deny check advisories` +# More documentation for the advisories section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html +[advisories] +# The path where the advisory databases are cloned/fetched into +#db-path = "$CARGO_HOME/advisory-dbs" +# The url(s) of the advisory databases to use +#db-urls = ["https://github.com/rustsec/advisory-db"] +# A list of advisory IDs to ignore. Note that ignored advisories will still +# output a note when they are encountered. +ignore = [ + #"RUSTSEC-0000-0000", + #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, + #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish + #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, + { id = "RUSTSEC-2021-0153", reason = "`encoding` is used by lindera" }, + { id = "RUSTSEC-2024-0384", reason = "`instant` is used by tantivy" }, + { id = "RUSTSEC-2024-0436", reason = "`paste` is used by datafusion" }, +] +# If this is true, then cargo deny will use the git executable to fetch advisory database. +# If this is false, then it uses a built-in git library. +# Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support. +# See Git Authentication for more information about setting up git authentication. +#git-fetch-with-cli = true + +# This section is considered when running `cargo deny check licenses` +# More documentation for the licenses section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html +[licenses] +# List of explicitly allowed licenses +# See https://spdx.org/licenses/ for list of possible licenses +# [possible values: any SPDX 3.11 short identifier (+ optional exception)]. +allow = [ + "MIT", + "Apache-2.0", + "Unicode-3.0", + "MPL-2.0", + "ISC", + "BSD-2-Clause", + "BSD-3-Clause", + "0BSD", + "OpenSSL", + "Zlib", + "CC0-1.0", +] +# The confidence threshold for detecting a license from license text. +# The higher the value, the more closely the license text must be to the +# canonical license text of a valid SPDX license file. +# [possible values: any between 0.0 and 1.0]. +confidence-threshold = 0.8 +# Allow 1 or more licenses on a per-crate basis, so that particular licenses +# aren't accepted for every possible crate as with the normal allow list +exceptions = [ + # Each entry is the crate and version constraint, and its specific allow + # list + #{ allow = ["Zlib"], crate = "adler32" }, +] + +# Some crates don't have (easily) machine readable licensing information, +# adding a clarification entry for it allows you to manually specify the +# licensing information +[[licenses.clarify]] +# The package spec the clarification applies to +crate = "ring" +# The SPDX expression for the license requirements of the crate +expression = "MIT AND ISC AND OpenSSL" +# One or more files in the crate's source used as the "source of truth" for +# the license expression. If the contents match, the clarification will be used +# when running the license check, otherwise the clarification will be ignored +# and the crate will be checked normally, which may produce warnings or errors +# depending on the rest of your configuration +license-files = [ + # Each entry is a crate relative path, and the (opaque) hash of its contents + { path = "LICENSE", hash = 0xbd0eed23 }, +] + +[licenses.private] +# If true, ignores workspace crates that aren't published, or are only +# published to private registries. +# To see how to mark a crate as unpublished (to the official registry), +# visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field. +ignore = false +# One or more private registries that you might publish crates to, if a crate +# is only published to private registries, and ignore is true, the crate will +# not have its license(s) checked +registries = [ + #"https://sekretz.com/registry +] + +# This section is considered when running `cargo deny check bans`. +# More documentation about the 'bans' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html +[bans] +# Lint level for when multiple versions of the same crate are detected +multiple-versions = "warn" +# Lint level for when a crate version requirement is `*` +wildcards = "allow" +# The graph highlighting used when creating dotgraphs for crates +# with multiple versions +# * lowest-version - The path to the lowest versioned duplicate is highlighted +# * simplest-path - The path to the version with the fewest edges is highlighted +# * all - Both lowest-version and simplest-path are used +highlight = "all" +# The default lint level for `default` features for crates that are members of +# the workspace that is being checked. This can be overridden by allowing/denying +# `default` on a crate-by-crate basis if desired. +workspace-default-features = "allow" +# The default lint level for `default` features for external crates that are not +# members of the workspace. This can be overridden by allowing/denying `default` +# on a crate-by-crate basis if desired. +external-default-features = "allow" +# List of crates that are allowed. Use with care! +allow = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, +] +# List of crates to deny +deny = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" }, + # Wrapper crates can optionally be specified to allow the crate when it + # is a direct dependency of the otherwise banned crate + #{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] }, +] + +# List of features to allow/deny +# Each entry the name of a crate and a version range. If version is +# not specified, all versions will be matched. +#[[bans.features]] +#crate = "reqwest" +# Features to not allow +#deny = ["json"] +# Features to allow +#allow = [ +# "rustls", +# "__rustls", +# "__tls", +# "hyper-rustls", +# "rustls", +# "rustls-pemfile", +# "rustls-tls-webpki-roots", +# "tokio-rustls", +# "webpki-roots", +#] +# If true, the allowed features must exactly match the enabled feature set. If +# this is set there is no point setting `deny` +#exact = true + +# Certain crates/versions that will be skipped when doing duplicate detection. +skip = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" }, +] +# Similarly to `skip` allows you to skip certain crates during duplicate +# detection. Unlike skip, it also includes the entire tree of transitive +# dependencies starting at the specified crate, up to a certain depth, which is +# by default infinite. +skip-tree = [ + #"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies + #{ crate = "ansi_term@0.11.0", depth = 20 }, +] + +# This section is considered when running `cargo deny check sources`. +# More documentation about the 'sources' section can be found here: +# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html +[sources] +# Lint level for what to happen when a crate from a crate registry that is not +# in the allow list is encountered +unknown-registry = "warn" +# Lint level for what to happen when a crate from a git repository that is not +# in the allow list is encountered +unknown-git = "warn" +# List of URLs for allowed crate registries. Defaults to the crates.io index +# if not specified. If it is specified but empty, no registries are allowed. +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +# List of URLs for allowed Git repositories +allow-git = [] + +[sources.allow-org] +# github.com organizations to allow git sources for +github = [] +# gitlab.com organizations to allow git sources for +gitlab = [] +# bitbucket.org organizations to allow git sources for +bitbucket = [] diff --git a/docker-compose.yml b/docker-compose.yml index a55b31cd23c..6b87efc58dd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,21 +1,17 @@ version: "3.9" services: - minio: - image: lazybit/minio + localstack: + image: localstack/localstack:4.0 ports: - - 9000:9000 + - 4566:4566 environment: - - MINIO_ACCESS_KEY=ACCESSKEY - - MINIO_SECRET_KEY=SECRETKEY + - SERVICES=s3,dynamodb,kms + - DOCKER_HOST=unix:///var/run/docker.sock + # Note: localstack doesn't actually validate these. + - AWS_ACCESS_KEY_ID=ACCESS_KEY + - AWS_SECRET_ACCESS_KEY=SECRET_KEY healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:9000/minio/health/live" ] + test: [ "CMD", "curl", "-s", "http://localhost:4566/_localstack/health" ] interval: 5s retries: 3 start_period: 10s - dynamodb-local: - image: amazon/dynamodb-local - ports: - - 8000:8000 - environment: - - AWS_ACCESS_KEY_ID=ACCESSKEY - - AWS_SECRET_ACCESS_KEY=SECRETKEY diff --git a/docs/_static/blob.png b/docs/_static/blob.png new file mode 100644 index 00000000000..74d31b964a8 Binary files /dev/null and b/docs/_static/blob.png differ diff --git a/docs/_static/distributed_append.png b/docs/_static/distributed_append.png new file mode 100644 index 00000000000..af681c7a33a Binary files /dev/null and b/docs/_static/distributed_append.png differ diff --git a/docs/api/api.rst b/docs/api/api.rst index 4a6667f10a0..c657a2017df 100644 --- a/docs/api/api.rst +++ b/docs/api/api.rst @@ -2,6 +2,7 @@ APIs ---- .. toctree:: + :maxdepth: 1 - Rust - Python <./python/modules> + Rust + Python <./python.rst> \ No newline at end of file diff --git a/docs/api/py_modules.rst b/docs/api/py_modules.rst new file mode 100644 index 00000000000..8471f8c4600 --- /dev/null +++ b/docs/api/py_modules.rst @@ -0,0 +1,12 @@ + +.. automodule:: lance + :members: + :undoc-members: + +.. automodule:: lance.dataset + :members: + :undoc-members: + +.. automodule:: lance.fragment + :members: + :undoc-members: diff --git a/docs/api/python.rst b/docs/api/python.rst new file mode 100644 index 00000000000..b9a2d5edca2 --- /dev/null +++ b/docs/api/python.rst @@ -0,0 +1,68 @@ +Python APIs +=========== + +``Lance`` is a columnar format that is specifically designed for efficient +multi-modal data processing. + +Lance Dataset +------------- + +The core of Lance is the ``LanceDataset`` class. User can open a dataset by using +:py:meth:`lance.dataset`. + +.. autofunction:: lance.dataset + :noindex: + +Basic IOs +~~~~~~~~~ + +The following functions are used to read and write data in Lance format. + +.. automethod:: lance.dataset.LanceDataset.insert + :noindex: +.. automethod:: lance.dataset.LanceDataset.scanner + :noindex: +.. automethod:: lance.dataset.LanceDataset.to_batches + :noindex: +.. automethod:: lance.dataset.LanceDataset.to_table + :noindex: + +Random Access +~~~~~~~~~~~~~ + +Lance stands out with its super fast random access, unlike other columnar formats. + +.. automethod:: lance.dataset.LanceDataset.take + :noindex: +.. automethod:: lance.dataset.LanceDataset.take_blobs + :noindex: + + +Schema Evolution +~~~~~~~~~~~~~~~~ + +Lance supports schema evolution, which means that you can add new columns to the dataset +cheaply. + +.. automethod:: lance.dataset.LanceDataset.add_columns + :noindex: +.. automethod:: lance.dataset.LanceDataset.drop_columns + :noindex: + + +Indexing and Searching +~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: lance.dataset.LanceDataset.create_index + :noindex: +.. automethod:: lance.dataset.LanceDataset.create_scalar_index + :noindex: +.. automethod:: lance.dataset.LanceDataset.drop_index + :noindex: +.. automethod:: lance.dataset.LanceDataset.scanner + :noindex: + +API Reference +------------- + +More information can be found in the :doc:`API reference <./py_modules>`. diff --git a/docs/arrays.rst b/docs/arrays.rst index 184c7bccded..230d3c8f6b0 100644 --- a/docs/arrays.rst +++ b/docs/arrays.rst @@ -14,61 +14,58 @@ a 32-bit float: ~1e-38 to ~1e38. By comparison, a 16-bit float has a range of ~5.96e-8 to 65504. Lance provides an Arrow extension array (:class:`lance.arrow.BFloat16Array`) -and a Pandas extension array (:class:`lance.pandas.BFloat16Dtype`) for BFloat16. +and a Pandas extension array (:class:`~lance._arrow.PandasBFloat16Type`) for BFloat16. These are compatible with the `ml_dtypes `_ bfloat16 NumPy extension array. If you are using Pandas, you can use the `lance.bfloat16` dtype string to create the array: -.. testcode:: +.. doctest:: - import pandas as pd - import lance.arrow - - series = pd.Series([1.1, 2.1, 3.4], dtype="lance.bfloat16") - series - -.. testoutput:: + >>> import lance.arrow + >>> pd.Series([1.1, 2.1, 3.4], dtype="lance.bfloat16") 0 1.1015625 1 2.09375 2 3.40625 dtype: lance.bfloat16 -To create an an arrow array, use the :func:`lance.arrow.bfloat16_array` function: +To create an Arrow array, use the :func:`lance.arrow.bfloat16_array` function: -.. testcode:: +.. code-block:: python - from lance.arrow import bfloat16_array + >>> from lance.arrow import bfloat16_array - array = bfloat16_array([1.1, 2.1, 3.4]) - array - -.. testoutput:: + >>> bfloat16_array([1.1, 2.1, 3.4]) + + [ + 1.1015625, + 2.09375, + 3.40625 + ] - - [1.1015625, 2.09375, 3.40625] Finally, if you have a pre-existing NumPy array, you can convert it into either: -.. testcode:: - - import numpy as np - from ml_dtypes import bfloat16 - from lance.arrow import PandasBFloat16Array, BFloat16Array +.. doctest:: - np_array = np.array([1.1, 2.1, 3.4], dtype=bfloat16) - PandasBFloat16Array.from_numpy(np_array) - BFloat16Array.from_numpy(np_array) + >>> import numpy as np + >>> from ml_dtypes import bfloat16 + >>> from lance.arrow import PandasBFloat16Array, BFloat16Array -.. testoutput:: - + >>> np_array = np.array([1.1, 2.1, 3.4], dtype=bfloat16) + >>> PandasBFloat16Array.from_numpy(np_array) [1.1015625, 2.09375, 3.40625] Length: 3, dtype: lance.bfloat16 - - [1.1015625, 2.09375, 3.40625] + >>> BFloat16Array.from_numpy(np_array) + + [ + 1.1015625, + 2.09375, + 3.40625 + ] When reading, these can be converted back to to the NumPy bfloat16 dtype using each array class's ``to_numpy`` method. @@ -86,25 +83,23 @@ with a list of URIs represented by either :py:class:`pyarrow.StringArray` or an iterable that yields strings. Note that the URIs are not strongly validated and images are not read into memory automatically. -.. testcode:: - - from lance.arrow import ImageURIArray +.. doctest:: - ImageURIArray.from_uris([ - "/tmp/image1.jpg", - "file:///tmp/image2.jpg", - "s3://example/image3.jpg" - ]) + >>> from lance.arrow import ImageURIArray -.. testoutput:: + >>> ImageURIArray.from_uris([ + ... "/tmp/image1.jpg", + ... "file:///tmp/image2.jpg", + ... "s3://example/image3.jpg" + ... ]) + + ['/tmp/image1.jpg', 'file:///tmp/image2.jpg', 's3://example/image3.jpg'] - - ['/tmp/image1.jpg', 'file:///tmp/image2.jpg', 's3://example/image2.jpg'] :func:`lance.arrow.ImageURIArray.read_uris` will read images into memory and return them as a new :class:`lance.arrow.EncodedImageArray` object. -.. testcode:: +.. code-block:: python from lance.arrow import ImageURIArray @@ -139,7 +134,7 @@ function parameter. If decoder is not provided it will attempt to use `Pillow`_ and `tensorflow`_ in that order. If neither library or custom decoder is available an exception will be raised. -.. testcode:: +.. code-block:: python from lance.arrow import ImageURIArray @@ -185,30 +180,20 @@ If encoder is not provided it will attempt to use `tensorflow`_ and `Pillow`_ in that order. Default encoders will encode to PNG. If neither library is available it will raise an exception. -.. testcode:: - - from lance.arrow import ImageURIArray - - def jpeg_encoder(images): - import tensorflow as tf +.. testsetup:: - encoded_images = ( - tf.io.encode_jpeg(x).numpy() for x in tf.convert_to_tensor(images) - ) - return pa.array(encoded_images, type=pa.binary()) + image_uri = os.path.abspath(os.path.join(os.path.dirname(__name__), "_static", "icon.png")) - uris = [os.path.join(os.path.dirname(__file__), "images/1.png")] - tensor_images = ImageURIArray.from_uris(uris).read_uris().to_tensor() - print(tensor_images.to_encoded()) - print(tensor_images.to_encoded(jpeg_encoder)) +.. doctest:: -.. testoutput:: + >>> from lance.arrow import ImageURIArray + >>> uris = [image_uri] + >>> tensor_images = ImageURIArray.from_uris(uris).read_uris().to_tensor() + >>> tensor_images.to_encoded() - [b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00...'] - - [b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01...'] - + [... + b'\x89PNG\r\n\x1a...' .. _tensorflow: https://www.tensorflow.org/api_docs/python/tf/io/encode_png .. _Pillow: https://pillow.readthedocs.io/en/stable/ \ No newline at end of file diff --git a/docs/blob.rst b/docs/blob.rst new file mode 100644 index 00000000000..c3587230e69 --- /dev/null +++ b/docs/blob.rst @@ -0,0 +1,41 @@ +Blob As Files +============= + +Unlike other data formats, large multimodal data is a first-class citizen in the Lance columnar format. +Lance provides a high-level API to store and retrieve large binary objects (blobs) in Lance datasets. + +.. image:: _static/blob.png + :scale: 50% + +Lance serves large binary data using :py:class:`lance.BlobFile`, which +is a file-like object that lazily reads large binary objects. + +To fetch blobs from a Lance dataset, you can use :py:meth:`lance.dataset.LanceDataset.take_blobs`. + +For example, it's easy to use `BlobFile` to extract frames from a video file without +loading the entire video into memory. + +.. code-block:: python + + # pip install av pylance + + import av + import lance + + ds = lance.dataset("./youtube.lance") + start_time, end_time = 500, 1000 + blobs = ds.take_blobs([5], "video") + with av.open(blobs[0]) as container: + stream = container.streams.video[0] + stream.codec_context.skip_frame = "NONKEY" + + start_time = start_time / stream.time_base + start_time = start_time.as_integer_ratio()[0] + end_time = end_time / stream.time_base + container.seek(start_time, stream=stream) + + for frame in container.decode(stream): + if frame.time > end_time: + break + display(frame.to_image()) + clear_output(wait=True) diff --git a/docs/conf.py b/docs/conf.py index b2cb7fb8ddd..5e69d0e4e03 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,25 +1,16 @@ # Configuration file for the Sphinx documentation builder. -import shutil - - -def run_apidoc(_): - from sphinx.ext.apidoc import main - - shutil.rmtree("api/python", ignore_errors=True) - main(["-f", "-o", "api/python", "../python/python/lance"]) - - -def setup(app): - app.connect("builder-inited", run_apidoc) - +import sys +import os # -- Project information ----------------------------------------------------- project = "Lance" -copyright = "2024, Lance Developer" +copyright = "%Y, Lance Developer" author = "Lance Developer" +sys.path.insert(0, os.path.abspath("../python")) + # -- General configuration --------------------------------------------------- @@ -27,12 +18,14 @@ def setup(app): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - "sphinx.ext.napoleon", "breathe", + "sphinx_immaterial", + "sphinx_immaterial.apidoc.python.apigen", "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.githubpages", "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", ] napoleon_google_docstring = False @@ -50,10 +43,30 @@ def setup(app): # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +intersphinx_mapping = { + "numpy": ("https://numpy.org/doc/stable/", None), + "pyarrow": ("https://arrow.apache.org/docs/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "ray": ("https://docs.ray.io/en/latest/", None), +} + +python_apigen_modules = { + "lance": "api/python/", +} +object_description_options = [ + ( + "py:.*", + dict( + include_object_type_in_xref_tooltip=False, + include_in_toc=False, + include_fields_in_toc=False, + ), + ), +] # -- Options for HTML output ------------------------------------------------- -html_theme = "piccolo_theme" +html_theme = "sphinx_immaterial" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -61,9 +74,52 @@ def setup(app): html_static_path = ["_static"] html_favicon = "_static/favicon_64x64.png" -# html_logo = "_static/high-res-icon.png" +html_logo = "_static/high-res-icon.png" html_theme_options = { - "source_url": "https://github.com/lancedb/lance", - "source_icon": "github", + "icon": { + "repo": "fontawesome/brands/github", + "edit": "material/file-edit-outline", + }, + "site_url": "https://github.com/lancedb/lance", + "repo_url": "https://github.com/lancedb/lance", + "repo_name": "Lance", + "features": [ + "navigation.expand", + # "navigation.tabs", + "content.tabs.link", + "content.code.copy", + ], + "navigation_depth": 4, + "social": [ + { + "icon": "fontawesome/brands/github", + "link": "https://github.com/lancedb/lance", + "name": "Source on github.com", + }, + { + "icon": "fontawesome/brands/python", + "link": "https://pypi.org/project/pylance/", + }, + { + "icon": "fontawesome/brands/discord", + "link": "https://discord.gg/zMM32dvNtd", + }, + ], } -html_css_files = ["custom.css"] + + +# -- doctest configuration --------------------------------------------------- + +doctest_global_setup = """ +import os +import shutil +from typing import Iterator + +import lance +import pyarrow as pa +import numpy as np +import pandas as pd +""" + +# Only test code examples in rst files +doctest_test_doctest_blocks = "" diff --git a/docs/contributing.rst b/docs/contributing.rst index a7e6d72206a..ec6114c169e 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -127,13 +127,6 @@ Example Notebooks Example notebooks are under `examples`. These are standalone notebooks you should be able to download and run. -DuckDB Extension -~~~~~~~~~~~~~~~~ - -In python, Lance integrates with DuckDB via Apache Arrow. Outside of python, the highly experimental duckdb extension for Lance -lives under `integration/duckdb_lance`. This uses the DuckDB `Rust extension framework `_. -The main code lives under `integration/duckdb_lance/src`. Follow the integration README for more details. - Benchmarks ~~~~~~~~~~ diff --git a/docs/distributed_write.rst b/docs/distributed_write.rst new file mode 100644 index 00000000000..1ac44b6b094 --- /dev/null +++ b/docs/distributed_write.rst @@ -0,0 +1,234 @@ +Distributed Write +================= + +.. warning:: + + Lance provides out-of-the-box :doc:`Ray <./integrations/ray>` and + `Spark `_ integrations. + + This page is intended for users who wish to perform distributed operations in a custom manner, + i.e. using `slurm` or `Kubernetes` without the Lance integration. + +Overview +-------- + +The :doc:`Lance format ` is designed to support parallel writing across multiple distributed workers. +A distributed write operation can be performed by two phases: + +#. **Parallel Writes**: Generate new :py:class:`~lance.LanceFragment` in parallel across multiple workers. +#. **Commit**: Collect all the :class:`~lance.FragmentMetadata` and commit into a single dataset in + a single :py:class:`~lance.LanceOperation`. + +.. image:: ./_static/distributed_append.png + +Write new data +--------------- + +Writing or appending new data is straightforward with :py:func:`~lance.fragment.write_fragments`. + +.. testsetup:: new_data + + shutil.rmtree("./dist_write", ignore_errors=True) + +.. testcode:: new_data + + import json + from lance.fragment import write_fragments + + # Run on each worker + data_uri = "./dist_write" + schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.string()), + ]) + + # Run on worker 1 + data1 = { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + fragments_1 = write_fragments(data1, data_uri, schema=schema) + print("Worker 1: ", fragments_1) + + # Run on worker 2 + data2 = { + "a": [4, 5, 6], + "b": ["u", "v", "w"], + } + fragments_2 = write_fragments(data2, data_uri, schema=schema) + print("Worker 2: ", fragments_2) + +.. testoutput:: new_data + + Worker 1: [FragmentMetadata(id=0, files=...)] + Worker 2: [FragmentMetadata(id=0, files=...)] + + +Now, use :meth:`lance.fragment.FragmentMetadata.to_json` to serialize the fragment metadata, +and collect all serialized metadata on a single worker to execute the final commit operation. + +.. testsetup:: + + from lance.fragment import write_fragments + + shutil.rmtree("./dist_write", ignore_errors=True) + data_uri = "./dist_write" + schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.string()), + ]) + + data1 = { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + fragments_1 = write_fragments(data1, data_uri, schema=schema) + data2 = { + "a": [4, 5, 6], + "b": ["u", "v", "w"], + } + fragments_2 = write_fragments(data2, data_uri, schema=schema) + +.. testcode:: new_data + + import json + from lance import FragmentMetadata, LanceOperation + + # Serialize Fragments into JSON data + fragments_json1 = [json.dumps(fragment.to_json()) for fragment in fragments_1] + fragments_json2 = [json.dumps(fragment.to_json()) for fragment in fragments_2] + + # On one worker, collect all fragments + all_fragments = [FragmentMetadata.from_json(f) for f in \ + fragments_json1 + fragments_json2] + + # Commit the fragments into a single dataset + # Use LanceOperation.Overwrite to overwrite the dataset or create new dataset. + op = lance.LanceOperation.Overwrite(schema, all_fragments) + read_version = 0 # Because it is empty at the time. + lance.LanceDataset.commit( + data_uri, + op, + read_version=read_version, + ) + + # We can read the dataset using the Lance API: + dataset = lance.dataset(data_uri) + assert len(dataset.get_fragments()) == 2 + assert dataset.version == 1 + print(dataset.to_table().to_pandas()) + +.. testoutput:: new_data + + a b + 0 1 x + 1 2 y + 2 3 z + 3 4 u + 4 5 v + 5 6 w + +Append data +------------ + +Appending additional data follows a similar process. Use :py:class:`lance.LanceOperation.Append` to commit the new fragments, +ensuring that the ``read_version`` is set to the current dataset's version. + +.. code-block:: python + :emphasize-lines: 2,4,5 + + ds = lance.dataset(data_uri) + read_version = ds.version + + op = lance.LanceOperation.Append(schema, all_fragments) + lance.LanceDataset.commit( + data_uri, + op, + read_version=read_version, + ) + +Add New Columns +--------------- + +`Lance Format excels at operations such as adding columns <./format.rst>`_. +Thanks to its two-dimensional layout +(`see this blog post `_), +adding new columns is highly efficient since it avoids copying the existing data files. +Instead, the process simply creates new data files and links them to the existing dataset +using metadata-only operations. + +.. testsetup:: add_columns + + import pyarrow as pa + import pyarrow.dataset as ds + import lance + + shutil.rmtree("./add_columns_example", ignore_errors=True) + + schema = pa.schema([ + ("name", pa.string()), + ("age", pa.int32()), + ]) + tbl = pa.Table.from_pydict({ + "name": ["alice", "bob", "charlie"], + "age": [25, 33, 44], + }, schema=schema) + dataset = lance.write_dataset(tbl, "./add_columns_example") + + tbl = pa.Table.from_pydict({ + "name": ["craig", "dave", "eve"], + "age": [55, 66, 77], + }, schema=schema) + dataset = dataset.insert(tbl) + +.. testcode:: add_columns + + from pyarrow import RecordBatch + import pyarrow.compute as pc + + from lance import LanceFragment, LanceOperation + + dataset = lance.dataset("./add_columns_example") + assert len(dataset.get_fragments()) == 2 + assert dataset.to_table().combine_chunks() == pa.Table.from_pydict({ + "name": ["alice", "bob", "charlie", "craig", "dave", "eve"], + "age": [25, 33, 44, 55, 66, 77], + }, schema=schema) + + + def name_len(names: RecordBatch) -> RecordBatch: + return RecordBatch.from_arrays( + [pc.utf8_length(names["name"])], + ["name_len"], + ) + + # On Worker 1 + frag1 = dataset.get_fragments()[0] + new_fragment1, new_schema = frag1.merge_columns(name_len, ["name"]) + + # On Worker 2 + frag2 = dataset.get_fragments()[1] + new_fragment2, _ = frag2.merge_columns(name_len, ["name"]) + + # On Worker 3 - Commit + all_fragments = [new_fragment1, new_fragment2] + op = lance.LanceOperation.Merge(all_fragments, schema=new_schema) + lance.LanceDataset.commit( + "./add_columns_example", + op, + read_version=dataset.version, + ) + + # Verify dataset + dataset = lance.dataset("./add_columns_example") + print(dataset.to_table().to_pandas()) + +.. testoutput:: add_columns + + name age name_len + 0 alice 25 5 + 1 bob 33 3 + 2 charlie 44 7 + 3 craig 55 5 + 4 dave 66 4 + 5 eve 77 3 diff --git a/docs/examples/examples.rst b/docs/examples/examples.rst index 35200523fe8..b502794bf4f 100644 --- a/docs/examples/examples.rst +++ b/docs/examples/examples.rst @@ -9,4 +9,5 @@ Examples Reading and writing a Lance dataset in Rust <./write_read_dataset.rst> Creating Multi-Modal datasets using Lance <./flickr8k_dataset_creation.rst> Training Multi-Modal models using a Lance dataset <./clip_training.rst> - Deep Learning Artefact Management using Lance <./artefact_management.rst> \ No newline at end of file + Deep Learning Artefact Management using Lance <./artefact_management.rst> + Reading and writing a Lance dataset via Spark DataSource <./spark_datasource_example.rst> \ No newline at end of file diff --git a/docs/examples/spark_datasource_example.rst b/docs/examples/spark_datasource_example.rst new file mode 100644 index 00000000000..14a295a32a5 --- /dev/null +++ b/docs/examples/spark_datasource_example.rst @@ -0,0 +1,111 @@ +Writing and Reading a Dataset Using Spark +========================================= + +.. attention:: + The Spark connector is currently an experimental feature undergoing rapid iteration. + +In this example, we will read a local ``iris.csv`` file and write it as a Lance dataset using Apache Spark, then demonstrate how to query the dataset. + +Preparing the Environment and Raw Dataset +----------------------------------------- + +Download the Spark binary package from the `official website `_. We recommend downloading Spark 3.5+ for Scala 2.12 (as the Spark connector currently only supports Scala 2.12). + +You can directly download Spark 3.5.1 using this `link `_. + +Prepare the dataset by downloading `iris.csv `_ to your local machine. + +Create a Scala file named ``iris_to_lance_via_spark_shell.scala`` and open it. + +Reading the Raw Dataset and Writing to a Lance Dataset +------------------------------------------------------- + +Add necessary imports and create a Spark session: + +.. code-block:: scala + + import org.apache.spark.sql.types.{StructType, StructField, DoubleType, StringType} + import org.apache.spark.sql.{SparkSession, DataFrame} + import com.lancedb.lance.spark.{LanceConfig, LanceDataSource} + + val spark = SparkSession.builder() + .appName("Iris CSV to Lance Converter") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate() + +Specifying your input and output path: + +.. code-block:: scala + + val irisPath = "/path/to/your/input/iris.csv" + val outputPath = "/path/to/your/output/iris.lance" + +Reading the ``iris.csv`` via the following snippet: + +.. code-block:: scala + + val rawDF = spark.read + .option("header", "true") + .option("inferSchema", "true") + .csv(irisPath) + + rawDF.printSchema() + +Preparing the lance schema and write a lance dataset: + +.. code-block:: scala + + val lanceSchema = new StructType() + .add(StructField("sepal_length", DoubleType)) + .add(StructField("sepal_width", DoubleType)) + .add(StructField("petal_length", DoubleType)) + .add(StructField("petal_width", DoubleType)) + .add(StructField("species", StringType)) + + val lanceDF = spark.createDataFrame(rawDF.rdd, lanceSchema) + + lanceDF.write + .format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, outputPath) + .save() + +Reading a Lance dataset +----------------------- + +After writing the dataset, we can read it back and examine its properties: + +.. code-block:: scala + + val lanceDF = spark.read + .format("lance") + .option(LanceConfig.CONFIG_DATASET_URI, outputPath) + .load() + + println(s"The total count: ${lanceDF.count()}") + lanceDF.printSchema() + println("\n The top 5 data:") + lanceDF.show(5, truncate = false) + + println("\n Species distribution statistics:") + lanceDF.groupBy("species").count().show() + +First, we open the dataset and count the total rows. Then we print the dataset schema. Finally, we analyze the species distribution statistics. + +Running the Spark Application +----------------------------- + +To execute the application, download these dependencies: + +* lance-core JAR: Core Rust Spark binding exposing Lance features to Java (available `here `_) +* lance-spark JAR: Spark connector for reading/writing Lance format (available `here `_) +* jar-jni JAR: Load JNI dependencies embedded within a JAR file (available `here `_) +* arrow-c-data JAR: Java implementation of C Data Interface (available `here `_) +* arrow-dataset JAR: Java implementation of Arrow Dataset API/Framework (available `here `_) + +Place these JARs in the ``${SPARK_HOME}/jars`` directory, then run: + +.. code-block:: bash + + ./bin/spark-shell --jars ./jars/lance-core-0.23.0.jar,./jars/lance-spark-0.23.0.jar,./jars/jar-jni-1.1.1.jar,./jars/arrow-c-data-12.0.1.jar,./jars/arrow-dataset-12.0.1.jar -i ./iris_to_lance_via_spark_shell.scala + +It should be work! Have fun! diff --git a/docs/format.rst b/docs/format.rst index 13cfbc27127..b2e9b5237e1 100644 --- a/docs/format.rst +++ b/docs/format.rst @@ -1,7 +1,7 @@ Lance Formats ============= -The Lance project includes both a table format and a file format. Lance typically refers +The Lance format is both a table format and a file format. Lance typically refers to tables as "datasets". A Lance dataset is designed to efficiently handle secondary indices, fast ingestion and modification of data, and a rich set of schema evolution features. @@ -31,7 +31,7 @@ Fragments ~~~~~~~~~ ``DataFragment`` represents a chunk of data in the dataset. Itself includes one or more ``DataFile``, -where each ``DataFile`` can contain several columns in the chunk of data. It also may include a +where each ``DataFile`` can contain several columns in the chunk of data. It also may include a ``DeletionFile``, which is explained in a later section. .. literalinclude:: ../protos/table.proto @@ -86,7 +86,7 @@ and/or performance. However, older software versions may not be able to read ne In addition, the latest version of the file format (next) is unstable and should not be used for production use cases. Breaking changes could be made to unstable encodings and -that would mean that files written with these encodings are no longer readable by any +that would mean that files written with these encodings are no longer readable by any newer versions of Lance. The ``next`` version should only be used for experimentation and benchmarking upcoming features. @@ -95,7 +95,7 @@ The following values are supported: .. list-table:: File Versions :widths: 20 20 20 40 :header-rows: 1 - + * - Version - Minimal Lance Version - Maximum Lance Version @@ -206,7 +206,7 @@ Feature Flags As the file format and dataset evolve, new feature flags are added to the format. There are two separate fields for checking for feature flags, depending on whether you are trying to read or write the table. Readers should check the -``reader_feature_flags`` to see if there are any flag it is not aware of. Writers +``reader_feature_flags`` to see if there are any flag it is not aware of. Writers should check ``writer_feature_flags``. If either sees a flag they don't know, they should return an "unsupported" error on any read or write operation. @@ -286,7 +286,7 @@ deleted for some fragment. For a given version of the dataset, each fragment can have up to one deletion file. Fragments that have no deleted rows have no deletion file. -Readers should filter out row ids contained in these deletion files during a +Readers should filter out row ids contained in these deletion files during a scan or ANN search. Deletion files come in two flavors: @@ -319,7 +319,7 @@ collisions. The suffix is determined by the file type (``.arrow`` for Arrow file :start-at: // Deletion File :end-at: } // DeletionFile -Deletes can be materialized by re-writing data files with the deleted rows +Deletes can be materialized by re-writing data files with the deleted rows removed. However, this invalidates row indices and thus the ANN indices, which can be expensive to recompute. @@ -388,7 +388,7 @@ The commit process is as follows: fails because another writer has already committed, go back to step 3. When checking whether two transactions conflict, be conservative. If the -transaction file is missing, assume it conflicts. If the transaction file +transaction file is missing, assume it conflicts. If the transaction file has an unknown operation, assume it conflicts. .. _external-manifest-store: @@ -555,7 +555,7 @@ The row id values for a fragment are stored in a ``RowIdSequence`` protobuf message. This is described in the `protos/rowids.proto`_ file. Row id sequences are just arrays of u64 values, which have representations optimized for the common case where they are sorted and possibly contiguous. For example, a new -fragment will have a row id sequence that is just a simple range, so it is +fragment will have a row id sequence that is just a simple range, so it is stored as a ``start`` and ``end`` value. These sequence messages are either stored inline in the fragment metadata, or diff --git a/docs/index.rst b/docs/index.rst index 28c96053ce0..0bd7ec8261f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -2,20 +2,20 @@ .. image:: _static/lance_logo.png :width: 400 -Lance: modern columnar data format for ML -====================================================================================== +Lance: modern columnar format for ML workloads +============================================== -`Lance` is a columnar data format that is easy and fast to version, query and train on. +`Lance` is a columnar format that is easy and fast to version, query and train on. It’s designed to be used with images, videos, 3D point clouds, audio and of course tabular data. It supports any POSIX file systems, and cloud storage like AWS S3 and Google Cloud Storage. The key features of Lance include: * **High-performance random access:** 100x faster than Parquet. -* **Vector search:** find nearest neighbors in under 1 millisecond and combine OLAP-queries with vector search. +* **Zero-copy schema evolution:** add and drop columns without copying the entire dataset. -* **Zero-copy, automatic versioning:** manage versions of your data automatically, and reduce redundancy with zero-copy logic built-in. +* **Vector search:** find nearest neighbors in under 1 millisecond and combine OLAP-queries with vector search. * **Ecosystem integrations:** Apache-Arrow, DuckDB and more on the way. @@ -39,14 +39,38 @@ Preview releases receive the same level of testing as regular releases. .. toctree:: - :maxdepth: 1 + :caption: Introduction + :maxdepth: 2 Quickstart <./notebooks/quickstart> - ./read_and_write - Lance Formats <./format> - Arrays <./arrays> - Integrations <./integrations/integrations> + ./introduction/read_and_write + ./introduction/schema_evolution + +.. toctree:: + :caption: Advanced Usage + :maxdepth: 1 + + Lance Format Spec <./format> + Blob API <./blob> + ./tags + Object Store Configuration <./object_store> + Distributed Write <./distributed_write> Performance Guide <./performance> + Tokenizer <./tokenizer> + Extension Arrays <./arrays> + +.. toctree:: + :caption: Integrations + + Huggingface <./integrations/huggingface> + Tensorflow <./integrations/tensorflow> + PyTorch <./integrations/pytorch> + Ray <./integrations/ray> + Spark <./integrations/spark> + +.. toctree:: + :maxdepth: 1 + API References <./api/api> Contributor Guide <./contributing> Examples <./examples/examples> diff --git a/docs/integrations/integrations.rst b/docs/integrations/integrations.rst deleted file mode 100644 index ecba04181f3..00000000000 --- a/docs/integrations/integrations.rst +++ /dev/null @@ -1,10 +0,0 @@ -Integrations ------------- - -.. toctree:: - :maxdepth: 2 - - Huggingface <./huggingface> - Tensorflow <./tensorflow> - PyTorch <./pytorch> - Ray <./ray> diff --git a/docs/integrations/ray.rst b/docs/integrations/ray.rst index e5c3adab4b5..724fe2473c8 100644 --- a/docs/integrations/ray.rst +++ b/docs/integrations/ray.rst @@ -1,27 +1,50 @@ Lance â¤ï¸ Ray -------------------- -Ray effortlessly scale up ML workload to large distributed compute environment. +`Ray `_ effortlessly scale up ML workload to large distributed +compute environment. -`Ray Data `_ can be directly written in Lance format by using the -:class:`lance.ray.sink.LanceDatasink` class. For example: +Lance format is one of the official `Ray data sources `_: -.. code-block:: bash +* Lance Data Source :py:meth:`ray.data.read_lance` +* Lance Data Sink :py:meth:`ray.data.Dataste.write_lance` - pip install pylance[ray] +.. testsetup:: + shutil.rmtree("./alice_bob_and_charlie.lance", ignore_errors=True) -``Ray Data Dataset`` can be written to Lance format using the following code: - -.. code-block:: python +.. testcode:: import ray - from lance.ray.sink import LanceDatasink + import pandas as pd ray.init() - sink = LanceDatasink("s3://bucket/to/data.lance") - ray.data.range(10).map( - lambda x: {"id": x["id"], "str": f"str-{x['id']}"} - ).write_datasink(sink) - + data = [ + {"id": 1, "name": "alice"}, + {"id": 2, "name": "bob"}, + {"id": 3, "name": "charlie"} + ] + ray.data.from_items(data).write_lance("./alice_bob_and_charlie.lance") + + # It can be read via lance directly + df = ( + lance. + dataset("./alice_bob_and_charlie.lance") + .to_table() + .to_pandas() + .sort_values(by=["id"]) + .reset_index(drop=True) + ) + assert df.equals(pd.DataFrame(data)), "{} != {}".format( + df, pd.DataFrame(data) + ) + + # Or via Ray.data.read_lance + ray_df = ( + ray.data.read_lance("./alice_bob_and_charlie.lance") + .to_pandas() + .sort_values(by=["id"]) + .reset_index(drop=True) + ) + assert df.equals(ray_df) diff --git a/docs/integrations/spark.rst b/docs/integrations/spark.rst new file mode 100644 index 00000000000..3a9ad123bf7 --- /dev/null +++ b/docs/integrations/spark.rst @@ -0,0 +1,122 @@ +Lance â¤ï¸ Spark +-------------------- + +Lance can be used as a third party datasource of ``_ + +.. warning:: + This feature is experimental and the APIs may change in the future. + +Build from source code +~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + git clone https://github.com/lancedb/lance.git + cd lance/java + mvn clean package -DskipTests -Drust.release.build=true + +After building the code, the spark related jars are under path :class:`lance/java/spark/target/jars/` + +.. code-block:: shell + + arrow-c-data-15.0.0.jar + arrow-dataset-15.0.0.jar + jar-jni-1.1.1.jar + lance-core-0.25.0-SNAPSHOT.jar + lance-spark-0.25.0-SNAPSHOT.jar + + + +Download the pre-build jars +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +If you did not want to get jars from source, you can download these five jars from maven repo. + +.. code-block:: bash + + wget https://repo1.maven.org/maven2/com/lancedb/lance-core/0.23.0/lance-core-0.23.0.jar + wget https://repo1.maven.org/maven2/com/lancedb/lance-spark/0.23.0/lance-spark-0.23.0.jar + wget https://repo1.maven.org/maven2/org/questdb/jar-jni/1.1.1/jar-jni-1.1.1.jar + wget https://repo1.maven.org/maven2/org/apache/arrow/arrow-c-data/12.0.1/arrow-c-data-12.0.1.jar + wget https://repo1.maven.org/maven2/org/apache/arrow/arrow-dataset/12.0.1/arrow-dataset-12.0.1.jar + +Configurations for Lance Spark Connector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +There are some configurations you have to set in :class:`spark-defaults.conf` to enable lance datasource. + +.. code-block:: text + + spark.sql.catalog.lance com.lancedb.lance.spark.LanceCatalog + +This config define the `LanceCatalog` and then the spark will treat lance as a datasource. + +If dealing with lance dataset stored in object store, these configurations should be set: + +.. code-block:: text + + spark.sql.catalog.lance.access_key_id {your object store ak} + spark.sql.catalog.lance.secret_access_key {your object store sk} + spark.sql.catalog.lance.aws_region {your object store region(optional)} + spark.sql.catalog.lance.aws_endpoint {your object store aws_endpoint which should be in virtual host style} + spark.sql.catalog.lance.virtual_hosted_style_request true + + +Startup the Spark Shell +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: shell + + bin/spark-shell --master "local[56]" --jars "/path_of_code/lance/java/spark/target/jars/*.jar" + + +Use :class:`--jars` to involve the related jars we build or downloaded. + +.. note:: + Spark shell console use :class:`scala` language not :class:`python` + +Using Spark Shell to manipulate lance dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* Write a new dataset named :class:`test.lance` + +.. code-block:: scala + + val df = Seq( + ("Alice", 1), + ("Bob", 2) + ).toDF("name", "id") + df.write.format("lance").option("path","./test.lance").save() + +* Overwrite the :class:`test.lance` dataset + +.. code-block:: scala + + val df = Seq( + ("Alice", 3), + ("Bob", 4) + ).toDF("name", "id") + df.write.format("lance").option("path","./test.lance").mode("overwrite").save() + +* Append Data into the :class:`test.lance` dataset + +.. code-block:: scala + + val df = Seq( + ("Chris", 5), + ("Derek", 6) + ).toDF("name", "id") + df.write.format("lance").option("path","./test.lance").mode("append").save() + +* Use spark data frame to read the :class:`test.lance` dataset + +.. code-block:: scala + + val data = spark.read.format("lance").option("path", "./test.lance").load(); + data.show() + +* Register data frame as table and use sql to query :class:`test.lance` dataset + +.. code-block:: scala + + data.createOrReplaceTempView("lance_table") + spark.sql("select id, count(*) from lance_table group by id order by id").show() + diff --git a/docs/introduction/read_and_write.rst b/docs/introduction/read_and_write.rst new file mode 100644 index 00000000000..f0002854b4b --- /dev/null +++ b/docs/introduction/read_and_write.rst @@ -0,0 +1,554 @@ +Read and Write Data +=================== + +Writing Lance Dataset +--------------------- + +If you're familiar with `Apache PyArrow `_, +you'll find that creating a Lance dataset is straightforward. +Begin by writing a :py:class:`pyarrow.Table` using the :py:meth:`lance.write_dataset` function. + +.. testsetup:: + + shutil.rmtree("./alice_and_bob.lance", ignore_errors=True) + +.. doctest:: + + >>> import lance + >>> import pyarrow as pa + + >>> table = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + ... {"name": "Bob", "age": 30}]) + >>> ds = lance.write_dataset(table, "./alice_and_bob.lance") + +If the dataset is too large to fully load into memory, you can stream data using :py:meth:`lance.write_dataset` +also supports :py:class:`~typing.Iterator` of :py:class:`pyarrow.RecordBatch` es. +You will need to provide a :py:class:`pyarrow.Schema` for the dataset in this case. + +.. testsetup:: rst_generator + + shutil.rmtree("./alice_and_bob.lance", ignore_errors=True) + +.. doctest:: rst_generator + + >>> def producer() -> Iterator[pa.RecordBatch]: + ... """An iterator of RecordBatches.""" + ... yield pa.RecordBatch.from_pylist([{"name": "Alice", "age": 20}]) + ... yield pa.RecordBatch.from_pylist([{"name": "Bob", "age": 30}]) + + >>> schema = pa.schema([ + ... ("name", pa.string()), + ... ("age", pa.int32()), + ... ]) + + >>> ds = lance.write_dataset(producer(), + ... "./alice_and_bob.lance", + ... schema=schema, mode="overwrite") + >>> ds.count_rows() + 2 + +:py:meth:`lance.write_dataset` supports writing :py:class:`pyarrow.Table`, :py:class:`pandas.DataFrame`, +:py:class:`pyarrow.dataset.Dataset`, and ``Iterator[pyarrow.RecordBatch]``. + +Adding Rows +----------- + +To insert data into your dataset, you can use either :py:meth:`LanceDataset.insert ` +or :py:meth:`~lance.write_dataset` with ``mode=append``. + +.. testsetup:: + + shutil.rmtree("./insert_example.lance", ignore_errors=True) + +.. doctest:: + + >>> import lance + >>> import pyarrow as pa + + >>> table = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + ... {"name": "Bob", "age": 30}]) + >>> ds = lance.write_dataset(table, "./insert_example.lance") + + >>> new_table = pa.Table.from_pylist([{"name": "Carla", "age": 37}]) + >>> ds.insert(new_table) + >>> ds.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Carla 37 + + >>> new_table2 = pa.Table.from_pylist([{"name": "David", "age": 42}]) + >>> ds = lance.write_dataset(new_table2, ds, mode="append") + >>> ds.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Carla 37 + 3 David 42 + + +Deleting rows +------------- + +Lance supports deleting rows from a dataset using a SQL filter, as described in :ref:`filter-push-down`. +For example, to delete Bob's row from the dataset above, one could use: + +.. doctest:: + + >>> import lance + + >>> dataset = lance.dataset("./alice_and_bob.lance") + >>> dataset.delete("name = 'Bob'") + >>> dataset2 = lance.dataset("./alice_and_bob.lance") + >>> dataset2.to_table().to_pandas() + name age + 0 Alice 20 + + +.. note:: + + :doc:`Lance Format is immutable <../format>`. Each write operation creates a new version of the dataset, + so users must reopen the dataset to see the changes. Likewise, rows are removed by marking + them as deleted in a separate deletion index, rather than rewriting the files. This approach + is faster and avoids invalidating any indices that reference the files, ensuring that subsequent + queries do not return the deleted rows. + + +Updating rows +------------- + +Lance supports updating rows based on SQL expressions with the +:py:meth:`lance.LanceDataset.update` method. For example, if we notice +that Bob's name in our dataset has been sometimes written as ``Blob``, we can fix +that with: + +.. code-block:: python + + import lance + + dataset = lance.dataset("./alice_and_bob.lance") + dataset.update({"name": "'Bob'"}), where="name = 'Blob'") + +The update values are SQL expressions, which is why ``'Bob'`` is wrapped in single +quotes. This means we can use complex expressions that reference existing columns if +we wish. For example, if two years have passed and we wish to update the ages +of Alice and Bob in the same example, we could write: + +.. code-block:: python + + import lance + + dataset = lance.dataset("./alice_and_bob.lance") + dataset.update({"age": "age + 2"}) + +If you are trying to update a set of individual rows with new values then it is often +more efficient to use the merge insert operation described below. + +.. code-block:: python + + import lance + + # Change the ages of both Alice and Bob + new_table = pa.Table.from_pylist([{"name": "Alice", "age": 30}, + {"name": "Bob", "age": 20}]) + + # This works, but is inefficient, see below for a better approach + dataset = lance.dataset("./alice_and_bob.lance") + for idx in range(new_table.num_rows): + name = new_table[0][idx].as_py() + new_age = new_table[1][idx].as_py() + dataset.update({"age": new_age}, where=f"name='{name}'") + +Merge Insert +------------ + +Lance supports a merge insert operation. This can be used to add new data in bulk +while also (potentially) matching against existing data. This operation can be used +for a number of different use cases. + +Bulk Update +^^^^^^^^^^^ + +The :py:meth:`lance.LanceDataset.update` method is useful for updating rows based on +a filter. However, if we want to replace existing rows with new rows then a :py:meth:`lance.LanceDataset.merge_insert` +operation would be more efficient: + +.. testsetup:: bulk_update + + tbl = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + {"name": "Bob", "age": 30}]) + lance.write_dataset(tbl, "./alice_and_bob.lance", mode="overwrite") + +.. doctest:: bulk_update + + >>> import lance + + >>> dataset = lance.dataset("./alice_and_bob.lance") + >>> dataset.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + >>> # Change the ages of both Alice and Bob + >>> new_table = pa.Table.from_pylist([{"name": "Alice", "age": 2}, + ... {"name": "Bob", "age": 3}]) + >>> # This will use `name` as the key for matching rows. Merge insert + >>> # uses a JOIN internally and so you typically want this column to + >>> # be a unique key or id of some kind. + >>> rst = dataset.merge_insert("name") \ + ... .when_matched_update_all() \ + ... .execute(new_table) + >>> dataset.to_table().to_pandas() + name age + 0 Alice 2 + 1 Bob 3 + +Note that, similar to the update operation, rows that are modified will +be removed and inserted back into the table, changing their position to +the end. Also, the relative order of these rows could change because we +are using a hash-join operation internally. + +Insert if not Exists +^^^^^^^^^^^^^^^^^^^^ + +Sometimes we only want to insert data if we haven't already inserted it +before. This can happen, for example, when we have a batch of data but +we don't know which rows we've added previously and we don't want to +create duplicate rows. We can use the merge insert operation to achieve +this: + +.. testsetup:: insert_if_not_exists + + import lance + import pyarrow as pa + + # Create a fresh dataset + tbl = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + {"name": "Bob", "age": 30}]) + lance.write_dataset(tbl, "./alice_and_bob.lance", mode="overwrite") + +.. doctest:: insert_if_not_exists + + >>> # Bob is already in the table, but Carla is new + >>> new_table = pa.Table.from_pylist([{"name": "Bob", "age": 30}, + ... {"name": "Carla", "age": 37}]) + >>> + >>> dataset = lance.dataset("./alice_and_bob.lance") + >>> + >>> # This will insert Carla but leave Bob unchanged + >>> _ = dataset.merge_insert("name") \ + ... .when_not_matched_insert_all() \ + ... .execute(new_table) + >>> # Verify that Carla was added but Bob remains unchanged + >>> dataset.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Carla 37 + +Update or Insert (Upsert) +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Sometimes we want to combine both of the above behaviors. If a row +already exists we want to update it. If the row does not exist we want +to add it. This operation is sometimes called "upsert". We can use +the merge insert operation to do this as well: + +.. testsetup:: upsert + + # Create a fresh dataset + tbl = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + {"name": "Bob", "age": 30}, + {"name": "Carla", "age": 37}]) + lance.write_dataset(tbl, "./alice_and_bob.lance", mode="overwrite") + +.. doctest:: upsert + + >>> import lance + >>> import pyarrow as pa + >>> + >>> # Change Carla's age and insert David + >>> new_table = pa.Table.from_pylist([{"name": "Carla", "age": 27}, + ... {"name": "David", "age": 42}]) + >>> + >>> dataset = lance.dataset("./alice_and_bob.lance") + >>> + >>> # This will update Carla and insert David + >>> _ = dataset.merge_insert("name") \ + ... .when_matched_update_all() \ + ... .when_not_matched_insert_all() \ + ... .execute(new_table) + >>> # Verify the results + >>> dataset.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Carla 27 + 3 David 42 + +Replace a Portion of Data +^^^^^^^^^^^^^^^^^^^^^^^^^ + +A less common, but still useful, behavior can be to replace some region +of existing rows (defined by a filter) with new data. This is similar +to performing both a delete and an insert in a single transaction. For +example: + +.. testsetup:: replace_portion + + import lance + import pyarrow as pa + + # Create a dataset with a mix of ages including some over 40 + tbl = pa.Table.from_pylist([{"name": "Alice", "age": 20}, + {"name": "Bob", "age": 30}, + {"name": "Charlie", "age": 45}, + {"name": "Donna", "age": 50}]) + lance.write_dataset(tbl, "./alice_and_bob.lance", mode="overwrite") + +.. doctest:: replace_portion + + >>> import lance + >>> import pyarrow as pa + >>> + >>> new_table = pa.Table.from_pylist([{"name": "Edgar", "age": 46}, + ... {"name": "Francene", "age": 44}]) + >>> + >>> dataset = lance.dataset("./alice_and_bob.lance") + >>> dataset.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Charlie 45 + 3 Donna 50 + >>> + >>> # This will remove anyone above 40 and insert our new data + >>> _ = dataset.merge_insert("name") \ + ... .when_not_matched_insert_all() \ + ... .when_not_matched_by_source_delete("age >= 40") \ + ... .execute(new_table) + >>> # Verify the results - people over 40 replaced with new data + >>> dataset.to_table().to_pandas() + name age + 0 Alice 20 + 1 Bob 30 + 2 Edgar 46 + 3 Francene 44 + +Reading Lance Dataset +--------------------- + +To open a Lance dataset, use the :py:meth:`lance.dataset` function: + +.. code-block:: python + + import lance + ds = lance.dataset("s3://bucket/path/imagenet.lance") + # Or local path + ds = lance.dataset("./imagenet.lance") + +.. note:: + + Lance supports local file system, AWS ``s3`` and Google Cloud Storage(``gs``) as storage backends + at the moment. Read more in `Object Store Configuration`_. + +The most straightforward approach for reading a Lance dataset is to utilize the :py:meth:`lance.LanceDataset.to_table` +method in order to load the entire dataset into memory. + +.. code-block:: python + + table = ds.to_table() + +Due to Lance being a high-performance columnar format, it enables efficient reading of subsets of the dataset by utilizing +**Column (projection)** push-down and **filter (predicates)** push-downs. + +.. code-block:: python + + table = ds.to_table( + columns=["image", "label"], + filter="label = 2 AND text IS NOT NULL", + limit=1000, + offset=3000) + +Lance understands the cost of reading heavy columns such as ``image``. +Consequently, it employs an optimized query plan to execute the operation efficiently. + +Iterative Read +~~~~~~~~~~~~~~ + +If the dataset is too large to fit in memory, you can read it in batches +using the :py:meth:`lance.LanceDataset.to_batches` method: + +.. code-block:: python + + for batch in ds.to_batches(columns=["image"], filter="label = 10"): + # do something with batch + compute_on_batch(batch) + +Unsurprisingly, :py:meth:`~lance.LanceDataset.to_batches` takes the same parameters +as :py:meth:`~lance.LanceDataset.to_table` function. + + +.. _filter-push-down: + +Filter push-down +~~~~~~~~~~~~~~~~ + +Lance embraces the utilization of standard SQL expressions as predicates for dataset filtering. +By pushing down the SQL predicates directly to the storage system, +the overall I/O load during a scan is significantly reduced. + +Currently, Lance supports a growing list of expressions. + +* ``>``, ``>=``, ``<``, ``<=``, ``=`` +* ``AND``, ``OR``, ``NOT`` +* ``IS NULL``, ``IS NOT NULL`` +* ``IS TRUE``, ``IS NOT TRUE``, ``IS FALSE``, ``IS NOT FALSE`` +* ``IN`` +* ``LIKE``, ``NOT LIKE`` +* ``regexp_match(column, pattern)`` +* ``CAST`` + +For example, the following filter string is acceptable: + +.. code-block:: SQL + + ((label IN [10, 20]) AND (note['email'] IS NOT NULL)) + OR NOT note['created'] + +Nested fields can be accessed using the subscripts. Struct fields can be +subscripted using field names, while list fields can be subscripted using +indices. + +If your column name contains special characters or is a `SQL Keyword `_, +you can use backtick (`````) to escape it. For nested fields, each segment of the +path must be wrapped in backticks. + +.. code-block:: SQL + + `CUBE` = 10 AND `column name with space` IS NOT NULL + AND `nested with space`.`inner with space` < 2 + +.. warning:: + + Field names containing periods (``.``) are not supported. + +Literals for dates, timestamps, and decimals can be written by writing the string +value after the type name. For example + +.. code-block:: SQL + + date_col = date '2021-01-01' + and timestamp_col = timestamp '2021-01-01 00:00:00' + and decimal_col = decimal(8,3) '1.000' + +For timestamp columns, the precision can be specified as a number in the type +parameter. Microsecond precision (6) is the default. + +.. list-table:: + :widths: 30 40 + :header-rows: 1 + + * - SQL + - Time unit + * - ``timestamp(0)`` + - Seconds + * - ``timestamp(3)`` + - Milliseconds + * - ``timestamp(6)`` + - Microseconds + * - ``timestamp(9)`` + - Nanoseconds + +Lance internally stores data in Arrow format. The mapping from SQL types to Arrow +is: + +.. list-table:: + :widths: 30 40 + :header-rows: 1 + + * - SQL type + - Arrow type + * - ``boolean`` + - ``Boolean`` + * - ``tinyint`` / ``tinyint unsigned`` + - ``Int8`` / ``UInt8`` + * - ``smallint`` / ``smallint unsigned`` + - ``Int16`` / ``UInt16`` + * - ``int`` or ``integer`` / ``int unsigned`` or ``integer unsigned`` + - ``Int32`` / ``UInt32`` + * - ``bigint`` / ``bigint unsigned`` + - ``Int64`` / ``UInt64`` + * - ``float`` + - ``Float32`` + * - ``double`` + - ``Float64`` + * - ``decimal(precision, scale)`` + - ``Decimal128`` + * - ``date`` + - ``Date32`` + * - ``timestamp`` + - ``Timestamp`` (1) + * - ``string`` + - ``Utf8`` + * - ``binary`` + - ``Binary`` + +(1) See precision mapping in previous table. + + +Random read +~~~~~~~~~~~ + +One district feature of Lance, as columnar format, is that it allows you to read random samples quickly. + +.. code-block:: python + + # Access the 2nd, 101th and 501th rows + data = ds.take([1, 100, 500], columns=["image", "label"]) + +The ability to achieve fast random access to individual rows plays a crucial role in facilitating various workflows +such as random sampling and shuffling in ML training. +Additionally, it empowers users to construct secondary indices, +enabling swift execution of queries for enhanced performance. + + +Table Maintenance +----------------- + +Some operations over time will cause a Lance dataset to have a poor layout. For +example, many small appends will lead to a large number of small fragments. Or +deleting many rows will lead to slower queries due to the need to filter out +deleted rows. + +To address this, Lance provides methods for optimizing dataset layout. + +Compact data files +~~~~~~~~~~~~~~~~~~ + +Data files can be rewritten so there are fewer files. When passing a +``target_rows_per_fragment`` to :py:meth:`lance.dataset.DatasetOptimizer.compact_files`, +Lance will skip any fragments that are already above that row count, and rewrite +others. Fragments will be merged according to their fragment ids, so the inherent +ordering of the data will be preserved. + +.. note:: + + Compaction creates a new version of the table. It does not delete the old + version of the table and the files referenced by it. + +.. code-block:: python + + import lance + + dataset = lance.dataset("./alice_and_bob.lance") + dataset.optimize.compact_files(target_rows_per_fragment=1024 * 1024) + +During compaction, Lance can also remove deleted rows. Rewritten fragments will +not have deletion files. This can improve scan performance since the soft deleted +rows don't have to be skipped during the scan. + +When files are rewritten, the original row addresses are invalidated. This means the +affected files are no longer part of any ANN index if they were before. Because +of this, it's recommended to rewrite files before re-building indices. + +.. TODO: remove this last comment once move-stable row ids are default. diff --git a/docs/introduction/schema_evolution.rst b/docs/introduction/schema_evolution.rst new file mode 100644 index 00000000000..0fa70cf93e8 --- /dev/null +++ b/docs/introduction/schema_evolution.rst @@ -0,0 +1,268 @@ +Schema Evolution +================ + +Lance supports schema evolution: adding, removing, and altering columns in a +dataset. Most of these operations can be performed *without* rewriting the +data files in the dataset, making them very efficient operations. + +In general, schema changes will conflict with most other concurrent write +operations. For example, if you change the schema of the dataset while someone +else is appending data to it, either your schema change or the append will fail, +depending on the order of the operations. Thus, it's recommended to perform +schema changes when no other writes are happening. + +Renaming columns +~~~~~~~~~~~~~~~~ + +Columns can be renamed using the :py:meth:`lance.LanceDataset.alter_columns` +method. + +.. testsetup:: + + shutil.rmtree("ids", ignore_errors=True) + +.. testcode:: + + table = pa.table({"id": pa.array([1, 2, 3])}) + dataset = lance.write_dataset(table, "ids") + dataset.alter_columns({"path": "id", "name": "new_id"}) + print(dataset.to_table().to_pandas()) + +.. testoutput:: + + new_id + 0 1 + 1 2 + 2 3 + +This works for nested columns as well. To address a nested column, use a dot +(``.``) to separate the levels of nesting. For example: + +.. testsetup:: + + shutil.rmtree("nested_rename", ignore_errors=True) + +.. testcode:: + + data = [ + {"meta": {"id": 1, "name": "Alice"}}, + {"meta": {"id": 2, "name": "Bob"}}, + ] + schema = pa.schema([ + ("meta", pa.struct([ + ("id", pa.int32()), + ("name", pa.string()), + ])) + ]) + dataset = lance.write_dataset(data, "nested_rename") + dataset.alter_columns({"path": "meta.id", "name": "new_id"}) + print(dataset.to_table().to_pandas()) + +.. testoutput:: + + meta + 0 {'new_id': 1, 'name': 'Alice'} + 1 {'new_id': 2, 'name': 'Bob'} + + +Casting column data types +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In addition to changing column names, you can also change the data type of a +column using the :py:meth:`lance.LanceDataset.alter_columns` method. This +requires rewriting that column to new data files, but does not require rewriting +the other columns. + +.. note:: + + If the column has an index, the index will be dropped if the column type is + changed. + +This method can be used to change the vector type of a column. For example, we +can change a float32 embedding column into a float16 column to save disk space +at the cost of lower precision: + +.. testcode:: + + table = pa.table({ + "id": pa.array([1, 2, 3]), + "embedding": pa.FixedShapeTensorArray.from_numpy_ndarray( + np.random.rand(3, 128).astype("float32")) + }) + dataset = lance.write_dataset(table, "embeddings") + dataset.alter_columns({"path": "embedding", + "data_type": pa.list_(pa.float16(), 128)}) + print(dataset.schema) + +.. testoutput:: + + id: int64 + embedding: fixed_size_list[128] + child 0, item: halffloat + + +Adding new columns +~~~~~~~~~~~~~~~~~~~ + +New columns can be added and populated within a single operation using the +:py:meth:`lance.LanceDataset.add_columns` method. There are two ways to specify +how to populate the new columns: first, by providing a SQL expression for each +new column, or second, by providing a function to generate the new column data. + +SQL expressions can either be independent expressions or reference existing +columns. SQL literal values can be used to set a single value for all +existing rows. + +.. testsetup:: + + shutil.rmtree("./names", ignore_errors=True) + +.. testcode:: + + table = pa.table({"name": pa.array(["Alice", "Bob", "Carla"])}) + dataset = lance.write_dataset(table, "names") + dataset.add_columns({ + "hash": "sha256(name)", + "status": "'active'", + }) + print(dataset.to_table().to_pandas()) + +.. testoutput:: + + name hash status + 0 Alice b';\xc5\x10b\x97>> table = pa.table({"id": pa.array([1, 2, 3]), + ... "name": pa.array(["Alice", "Bob", "Carla"])}) + >>> dataset = lance.write_dataset(table, "names", mode="overwrite") + >>> dataset.drop_columns(["name"]) + >>> dataset.schema + id: int64 + + +To actually remove the data from disk, the files must be rewritten to remove the +columns and then the old files must be deleted. This can be done using +:py:meth:`lance.dataset.DatasetOptimizer.compact_files()` followed by +:py:meth:`lance.LanceDataset.cleanup_old_versions()`. \ No newline at end of file diff --git a/docs/object_store.rst b/docs/object_store.rst new file mode 100644 index 00000000000..175cc849d58 --- /dev/null +++ b/docs/object_store.rst @@ -0,0 +1,364 @@ +Object Store Configuration +========================== + +Lance supports object stores such as AWS S3 (and compatible stores), Azure Blob Store, +and Google Cloud Storage. Which object store to use is determined by the URI scheme of +the dataset path. For example, ``s3://bucket/path`` will use S3, ``az://bucket/path`` +will use Azure, and ``gs://bucket/path`` will use GCS. + +.. versionadded:: 0.10.7 + + Passing options directly to storage options. + +These object stores take additional configuration objects. There are two ways to +specify these configurations: by setting environment variables or by passing them +to the ``storage_options`` parameter of :py:meth:`lance.dataset` and +:py:func:`lance.write_dataset`. So for example, to globally set a higher timeout, +you would run in your shell: + +.. code-block:: bash + + export TIMEOUT=60s + +If you only want to set the timeout for a single dataset, you can pass it as a +storage option: + +.. code-block:: python + + import lance + ds = lance.dataset("s3://path", storage_options={"timeout": "60s"}) + + +General Configuration +~~~~~~~~~~~~~~~~~~~~~ + +These options apply to all object stores. + +.. from https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Key + - Description + * - ``allow_http`` + - Allow non-TLS, i.e. non-HTTPS connections. Default, ``False``. + * - ``download_retry_count`` + - Number of times to retry a download. Default, ``3``. This limit is applied when + the HTTP request succeeds but the response is not fully downloaded, typically due + to a violation of ``request_timeout``. + * - ``allow_invalid_certificates`` + - Skip certificate validation on https connections. Default, ``False``. + Warning: This is insecure and should only be used for testing. + * - ``connect_timeout`` + - Timeout for only the connect phase of a Client. Default, ``5s``. + * - ``request_timeout`` + - Timeout for the entire request, from connection until the response body + has finished. Default, ``30s``. + * - ``user_agent`` + - User agent string to use in requests. + * - ``proxy_url`` + - URL of a proxy server to use for requests. Default, ``None``. + * - ``proxy_ca_certificate`` + - PEM-formatted CA certificate for proxy connections + * - ``proxy_excludes`` + - List of hosts that bypass proxy. This is a comma separated list of domains + and IP masks. Any subdomain of the provided domain will be bypassed. For + example, ``example.com, 192.168.1.0/24`` would bypass ``https://api.example.com``, + ``https://www.example.com``, and any IP in the range ``192.168.1.0/24``. + * - ``client_max_retries`` + - Number of times for a s3 client to retry the request. Default, ``10``. + * - ``client_retry_timeout`` + - Timeout for a s3 client to retry the request in seconds. Default, ``180``. + +S3 Configuration +~~~~~~~~~~~~~~~~ + +S3 (and S3-compatible stores) have additional configuration options that configure +authorization and S3-specific features (such as server-side encryption). + +AWS credentials can be set in the environment variables ``AWS_ACCESS_KEY_ID``, +``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN``. Alternatively, they can be +passed as parameters to the ``storage_options`` parameter: + +.. code-block:: python + + import lance + ds = lance.dataset( + "s3://bucket/path", + storage_options={ + "access_key_id": "my-access-key", + "secret_access_key": "my-secret-key", + "session_token": "my-session-token", + } + ) + +If you are using AWS SSO, you can specify the ``AWS_PROFILE`` environment variable. +It cannot be specified in the ``storage_options`` parameter. + +The following keys can be used as both environment variables or keys in the +``storage_options`` parameter: + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Key + - Description + * - ``aws_region`` / ``region`` + - The AWS region the bucket is in. This can be automatically detected when + using AWS S3, but must be specified for S3-compatible stores. + * - ``aws_access_key_id`` / ``access_key_id`` + - The AWS access key ID to use. + * - ``aws_secret_access_key`` / ``secret_access_key`` + - The AWS secret access key to use. + * - ``aws_session_token`` / ``session_token`` + - The AWS session token to use. + * - ``aws_endpoint`` / ``endpoint`` + - The endpoint to use for S3-compatible stores. + * - ``aws_virtual_hosted_style_request`` / ``virtual_hosted_style_request`` + - Whether to use virtual hosted-style requests, where bucket name is part + of the endpoint. Meant to be used with ``aws_endpoint``. Default, ``False``. + * - ``aws_s3_express`` / ``s3_express`` + - Whether to use S3 Express One Zone endpoints. Default, ``False``. See more + details below. + * - ``aws_server_side_encryption`` + - The server-side encryption algorithm to use. Must be one of ``"AES256"``, + ``"aws:kms"``, or ``"aws:kms:dsse"``. Default, ``None``. + * - ``aws_sse_kms_key_id`` + - The KMS key ID to use for server-side encryption. If set, + ``aws_server_side_encryption`` must be ``"aws:kms"`` or ``"aws:kms:dsse"``. + * - ``aws_sse_bucket_key_enabled`` + - Whether to use bucket keys for server-side encryption. + + +S3-compatible stores +^^^^^^^^^^^^^^^^^^^^ + +Lance can also connect to S3-compatible stores, such as MinIO. To do so, you must +specify both region and endpoint: + +.. code-block:: python + + import lance + ds = lance.dataset( + "s3://bucket/path", + storage_options={ + "region": "us-east-1", + "endpoint": "http://minio:9000", + } + ) + +This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables. + +S3 Express +^^^^^^^^^^ + +.. versionadded:: 0.9.7 + +Lance supports `S3 Express One Zone`_ endpoints, but requires additional configuration. Also, +S3 Express endpoints only support connecting from an EC2 instance within the same +region + +.. _S3 Express One Zone: https://aws.amazon.com/s3/storage-classes/express-one-zone/ + +To configure Lance to use an S3 Express endpoint, you must set the storage option +``s3_express``. The bucket name in your table URI should **include the suffix**. + +.. code-block:: python + + import lance + ds = lance.dataset( + "s3://my-bucket--use1-az4--x-s3/path/imagenet.lance", + storage_options={ + "region": "us-east-1", + "s3_express": "true", + } + ) + + +Committing mechanisms for S3 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. deprecated:: + + S3 now supports atomic put-if-not-exists, so this feature is no longer necessary. + It will be removed in a future version. You should migrate tables to use the + new feature by removing the commit locks from all writers at the same time. Note + that it is unsafe to mix writers with and without commit locks on the same dataset. + +Most supported storage systems (e.g. local file system, Google Cloud Storage, +Azure Blob Store) natively support atomic commits, which prevent concurrent +writers from corrupting the dataset. However, S3 does not support this natively. +To work around this, you may provide a locking mechanism that Lance can use to +lock the table while providing a write. To do so, you should implement a +context manager that acquires and releases a lock and then pass that to the +``commit_lock`` parameter of :py:meth:`lance.write_dataset`. + +.. note:: + + In order for the locking mechanism to work, all writers must use the same exact + mechanism. Otherwise, Lance will not be able to detect conflicts. + +On entering, the context manager should acquire the lock on the table. The table +version being committed is passed in as an argument, which may be used if the +locking service wishes to keep track of the current version of the table, but +this is not required. If the table is already locked by another transaction, +it should wait until it is unlocked, since the other transaction may fail. Once +unlocked, it should either lock the table or, if the lock keeps track of the +current version of the table, return a :class:`CommitConflictError` if the +requested version has already been committed. + +To prevent poisoned locks, it's recommended to set a timeout on the locks. That +way, if a process crashes while holding the lock, the lock will be released +eventually. The timeout should be no less than 30 seconds. + +.. code-block:: python + + from contextlib import contextmanager + + @contextmanager + def commit_lock(version: int); + # Acquire the lock + my_lock.acquire() + try: + yield + except: + failed = True + finally: + my_lock.release() + + lance.write_dataset(data, "s3://bucket/path/", commit_lock=commit_lock) + +When the context manager is exited, it will raise an exception if the commit +failed. This might be because of a network error or if the version has already +been written. Either way, the context manager should release the lock. Use a +try/finally block to ensure that the lock is released. + +Concurrent Writer on S3 using DynamoDB +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. warning:: + + This feature is experimental at the moment + +Lance has native support for concurrent writers on S3 using DynamoDB instead of locking. +User may pass in a DynamoDB table name alone with the S3 URI to their dataset to enable this feature. + +.. code-block:: python + + import lance + # s3+ddb:// URL scheme let's lance know that you want to + # use DynamoDB for writing to S3 concurrently + ds = lance.dataset("s3+ddb://my-bucket/mydataset?ddbTableName=mytable") + +The DynamoDB table is expected to have a primary hash key of ``base_uri`` and a range key ``version``. +The key ``base_uri`` should be string type, and the key ``version`` should be number type. + +For details on how this feature works, please see :ref:`external-manifest-store`. + + +Google Cloud Storage Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +GCS credentials are configured by setting the ``GOOGLE_SERVICE_ACCOUNT`` environment +variable to the path of a JSON file containing the service account credentials. +Alternatively, you can pass the path to the JSON file in the ``storage_options`` + +.. code-block:: python + + import lance + ds = lance.dataset( + "gs://my-bucket/my-dataset", + storage_options={ + "service_account": "path/to/service-account.json", + } + ) + +.. note:: + + By default, GCS uses HTTP/1 for communication, as opposed to HTTP/2. This improves + maximum throughput significantly. However, if you wish to use HTTP/2 for some reason, + you can set the environment variable ``HTTP1_ONLY`` to ``false``. + + +The following keys can be used as both environment variables or keys in the +``storage_options`` parameter: + +.. source: https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Key + - Description + * - ``google_service_account`` / ``service_account`` + - Path to the service account JSON file. + * - ``google_service_account_key`` / ``service_account_key`` + - The serialized service account key. + * - ``google_application_credentials`` / ``application_credentials`` + - Path to the application credentials. + + +Azure Blob Storage Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Azure Blob Storage credentials can be configured by setting the ``AZURE_STORAGE_ACCOUNT_NAME`` +and ``AZURE_STORAGE_ACCOUNT_KEY`` environment variables. Alternatively, you can pass +the account name and key in the ``storage_options`` parameter: + +.. code-block:: python + + import lance + ds = lance.dataset( + "az://my-container/my-dataset", + storage_options={ + "account_name": "some-account", + "account_key": "some-key", + } + ) + +These keys can be used as both environment variables or keys in the ``storage_options`` parameter: + +.. source: https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Key + - Description + * - ``azure_storage_account_name`` / ``account_name`` + - The name of the azure storage account. + * - ``azure_storage_account_key`` / ``account_key`` + - The serialized service account key. + * - ``azure_client_id`` / ``client_id`` + - Service principal client id for authorizing requests. + * - ``azure_client_secret`` / ``client_secret`` + - Service principal client secret for authorizing requests. + * - ``azure_tenant_id`` / ``tenant_id`` + - Tenant id used in oauth flows. + * - ``azure_storage_sas_key`` / ``azure_storage_sas_token`` / ``sas_key`` / ``sas_token`` + - Shared access signature. The signature is expected to be percent-encoded, much like they are provided in the azure storage explorer or azure portal. + * - ``azure_storage_token`` / ``bearer_token`` / ``token`` + - Bearer token. + * - ``azure_storage_use_emulator`` / ``object_store_use_emulator`` / ``use_emulator`` + - Use object store with azurite storage emulator. + * - ``azure_endpoint`` / ``endpoint`` + - Override the endpoint used to communicate with blob storage. + * - ``azure_use_fabric_endpoint`` / ``use_fabric_endpoint`` + - Use object store with url scheme account.dfs.fabric.microsoft.com. + * - ``azure_msi_endpoint`` / ``azure_identity_endpoint`` / ``identity_endpoint`` / ``msi_endpoint`` + - Endpoint to request a imds managed identity token. + * - ``azure_object_id`` / ``object_id`` + - Object id for use with managed identity authentication. + * - ``azure_msi_resource_id`` / ``msi_resource_id`` + - Msi resource id for use with managed identity authentication. + * - ``azure_federated_token_file`` / ``federated_token_file`` + - File containing token for Azure AD workload identity federation. + * - ``azure_use_azure_cli`` / ``use_azure_cli`` + - Use azure cli for acquiring access token. + * - ``azure_disable_tagging`` / ``disable_tagging`` + - Disables tagging objects. This can be desirable if not supported by the backing store. \ No newline at end of file diff --git a/docs/performance.rst b/docs/performance.rst index 2684a3d234f..712155e1850 100644 --- a/docs/performance.rst +++ b/docs/performance.rst @@ -3,6 +3,89 @@ Lance Performance Guide This guide provides tips and tricks for optimizing the performance of your Lance applications. +Trace Events +------------ + +Lance uses tracing to log events. If you are running ``pylance`` then these events will be emitted to +as log messages. For Rust connections you can use the ``tracing`` crate to capture these events. + +File Audit +~~~~~~~~~~ + +File audit events are emitted when significant files are created or deleted. + +.. list-table:: + :widths: 20 20 60 + :header-rows: 1 + + * - Event + - Parameter + - Description + + * - ``lance::file_audit`` + - ``mode`` + - The mode of I/O operation (create, delete, delete_unverified) + * - ``lance::file_audit`` + - ``type`` + - The type of file affected (manifest, data file, index file, deletion file) + +I/O Events +~~~~~~~~~~ + +I/O events are emitted when significant I/O operations are performed, particularly +those related to indices. These events are NOT emitted when the index is loaded from +the in-memory cache. Correct cache utilization is important for performance and these +events are intended to help you debug cache usage. + +.. list-table:: + :widths: 20 20 60 + :header-rows: 1 + + * - Event + - Parameter + - Description + + * - ``lance::io_events`` + - ``type`` + - The type of I/O operation (open_scalar_index, open_vector_index, load_vector_part, load_scalar_part) + +Execution Events +~~~~~~~~~~~~~~~~ + +Execution events are emitted when an execution plan is run. These events are useful for +debugging query performance. + +.. list-table:: + :widths: 20 20 60 + :header-rows: 1 + + * - Event + - Parameter + - Description + + * - ``lance::execution`` + - ``type`` + - The type of execution event (plan_run is the only type today) + * - ``lance::execution`` + - ``output_rows`` + - The number of rows in the output of the plan + * - ``lance::execution`` + - ``iops`` + - The number of I/O operations performed by the plan + * - ``lance::execution`` + - ``bytes_read`` + - The number of bytes read by the plan + * - ``lance::execution`` + - ``indices_loaded`` + - The number of indices loaded by the plan + * - ``lance::execution`` + - ``parts_loaded`` + - The number of index partitions loaded by the plan + * - ``lance::execution`` + - ``index_comparisons`` + - The number of comparisons performed inside the various indices + + Threading Model --------------- diff --git a/docs/read_and_write.rst b/docs/read_and_write.rst deleted file mode 100644 index 49d490833ea..00000000000 --- a/docs/read_and_write.rst +++ /dev/null @@ -1,1014 +0,0 @@ -Read and Write Lance Dataset -============================ - -Lance dataset APIs follows the `PyArrow API `_ -conventions. - -Writing Lance Dataset ---------------------- - -Similar to Apache Pyarrow, the simplest approach to create a Lance dataset is -writing a :py:class:`pyarrow.Table` via :py:meth:`lance.write_dataset`. - -.. code-block:: python - - import lance - import pyarrow as pa - - table = pa.Table.from_pylist([{"name": "Alice", "age": 20}, - {"name": "Bob", "age": 30}]) - lance.write_dataset(table, "./alice_and_bob.lance") - -If the memory footprint of the dataset is too large to fit in memory, :py:meth:`lance.write_dataset` -also supports writing a dataset in iterator of :py:class:`pyarrow.RecordBatch` es. - -.. code-block:: python - - import lance - import pyarrow as pa - - def producer(): - yield pa.RecordBatch.from_pylist([{"name": "Alice", "age": 20}]) - yield pa.RecordBatch.from_pylist([{"name": "Blob", "age": 30}]) - - schema = pa.schema([ - pa.field("name", pa.string()), - pa.field("age", pa.int64()), - ]) - - lance.write_dataset(reader, "./alice_and_bob.lance", schema) - -:py:meth:`lance.write_dataset` supports writing :py:class:`pyarrow.Table`, :py:class:`pandas.DataFrame`, -:py:class:`pyarrow.Dataset`, and ``Iterator[pyarrow.RecordBatch]``. Check its doc for more details. - -Deleting rows -~~~~~~~~~~~~~ - -Lance supports deleting rows from a dataset using a SQL filter. For example, to -delete Bob's row from the dataset above, one could use: - -.. code-block:: python - - import lance - - dataset = lance.dataset("./alice_and_bob.lance") - dataset.delete("name = 'Bob'") - -:py:meth:`lance.LanceDataset.delete` supports the same filters as described in -:ref:`filter-push-down`. - -Rows are deleted by marking them as deleted in a separate deletion index. This is -faster than rewriting the files and also avoids invaliding any indices that point -to those files. Any subsequent queries will not return the deleted rows. - -.. warning:: - - Do not read datasets with deleted rows using Lance versions prior to 0.5.0, - as they will return the deleted rows. This is fixed in 0.5.0 and later. - -Updating rows -~~~~~~~~~~~~~ - -Lance supports updating rows based on SQL expressions with the -:py:meth:`lance.LanceDataset.update` method. For example, if we notice -that Bob's name in our dataset has been sometimes written as ``Blob``, we can fix -that with: - -.. code-block:: python - - import lance - - dataset = lance.dataset("./alice_and_bob.lance") - dataset.update({"name": "'Bob'"}), where="name = 'Blob'") - -The update values are SQL expressions, which is why ``'Bob'`` is wrapped in single -quotes. This means we can use complex expressions that reference existing columns if -we wish. For example, if two years have passed and we wish to update the ages -of Alice and Bob in the same example, we could write: - -.. code-block:: python - - import lance - - dataset = lance.dataset("./alice_and_bob.lance") - dataset.update({"age": "age + 2"}) - -If you are trying to update a set of individual rows with new values then it is often -more efficient to use the merge insert operation described below. - -.. code-block:: python - - import lance - - # Change the ages of both Alice and Bob - new_table = pa.Table.from_pylist([{"name": "Alice", "age": 30}, - {"name": "Bob", "age": 20}]) - - # This works, but is inefficient, see below for a better approach - dataset = lance.dataset("./alice_and_bob.lance") - for idx in range(new_table.num_rows): - name = new_table[0][idx].as_py() - new_age = new_table[1][idx].as_py() - dataset.update({"age": new_age}, where=f"name='{name}'") - -Merge Insert -~~~~~~~~~~~~ - -Lance supports a merge insert operation. This can be used to add new data in bulk -while also (potentially) matching against existing data. This operation can be used -for a number of different use cases. - -Bulk Update -^^^^^^^^^^^ - -The :py:meth:`lance.LanceDataset.update` method is useful for updating rows based on -a filter. However, if we want to replace existing rows with new rows then a merge -insert operation would be more efficient: - -.. code-block:: python - - import lance - - # Change the ages of both Alice and Bob - new_table = pa.Table.from_pylist([{"name": "Alice", "age": 30}, - {"name": "Bob", "age": 20}]) - dataset = lance.dataset("./alice_and_bob.lance") - # This will use `name` as the key for matching rows. Merge insert - # uses a JOIN internally and so you typically want this column to - # be a unique key or id of some kind. - dataset.merge_insert("name") \ - .when_matched_update_all() \ - .execute(new_table) - -Note that, similar to the update operation, rows that are modified will -be removed and inserted back into the table, changing their position to -the end. Also, the relative order of these rows could change because we -are using a hash-join operation internally. - -Insert if not Exists -^^^^^^^^^^^^^^^^^^^^ - -Sometimes we only want to insert data if we haven't already inserted it -before. This can happen, for example, when we have a batch of data but -we don't know which rows we've added previously and we don't want to -create duplicate rows. We can use the merge insert operation to achieve -this: - -.. code-block:: python - - import lance - - # Bob is already in the table, but Carla is new - new_table = pa.Table.from_pylist([{"name": "Bob", "age": 30}, - {"name": "Carla", "age": 37}]) - - dataset = lance.dataset("./alice_and_bob.lance") - - # This will insert Carla but leave Bob unchanged - dataset.merge_insert("name") \ - .when_not_matched_insert_all() \ - .execute(new_table) - -Update or Insert (Upsert) -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Sometimes we want to combine both of the above behaviors. If a row -already exists we want to update it. If the row does not exist we want -to add it. This operation is sometimes called "upsert". We can use -the merge insert operation to do this as well: - -.. code-block:: python - - import lance - - # Change Carla's age and insert David - new_table = pa.Table.from_pylist([{"name": "Carla", "age": 27}, - {"name": "David", "age": 42}]) - - dataset = lance.dataset("./alice_and_bob.lance") - - # This will update Carla and insert David - dataset.merge_insert("name") \ - .when_matched_update_all() \ - .when_not_matched_insert_all() \ - .execute(new_table) - -Replace a Portion of Data -^^^^^^^^^^^^^^^^^^^^^^^^^ - -A less common, but still useful, behavior can be to replace some region -of existing rows (defined by a filter) with new data. This is similar -to performing both a delete and an insert in a single transaction. For -example: - -.. code-block:: python - - import lance - - new_table = pa.Table.from_pylist([{"name": "Edgar", "age": 46}, - {"name": "Francene", "age": 44}]) - - dataset = lance.dataset("./alice_and_bob.lance") - - # This will remove anyone above 40 and insert our new data - dataset.merge_insert("name") \ - .when_not_matched_insert_all() \ - .when_not_matched_by_source_delete("age >= 40") \ - .execute(new_table) - - -Evolving the schema -------------------- - -Lance supports schema evolution: adding, removing, and altering columns in a -dataset. Most of these operations can be performed *without* rewriting the -data files in the dataset, making them very efficient operations. - -In general, schema changes will conflict with most other concurrent write -operations. For example, if you change the schema of the dataset while someone -else is appending data to it, either your schema change or the append will fail, -depending on the order of the operations. Thus, it's recommended to perform -schema changes when no other writes are happening. - -Renaming columns -~~~~~~~~~~~~~~~~ - -Columns can be renamed using the :py:meth:`lance.LanceDataset.alter_columns` -method. - -.. testcode:: - - import lance - import pyarrow as pa - table = pa.table({"id": pa.array([1, 2, 3])}) - dataset = lance.write_dataset(table, "ids") - dataset.alter_columns({"path": "id", "name": "new_id"}) - dataset.to_table().to_pandas() - -.. testoutput:: - - new_id - 0 1 - 1 2 - 2 3 - -This works for nested columns as well. To address a nested column, use a dot -(``.``) to separate the levels of nesting. For example: - -.. testcode:: - - data = [ - {"meta": {"id": 1, "name": "Alice"}}, - {"meta": {"id": 2, "name": "Bob"}}, - ] - dataset = lance.write_dataset(data, "nested_rename") - dataset.alter_columns({"path": "meta.id", "name": "new_id"}) - -.. testoutput:: - - meta - 0 {"new_id": 1, "name": "Alice"} - 1 {"new_id": 2, "name": "Bob"} - - -Casting column data types -~~~~~~~~~~~~~~~~~~~~~~~~~ - -In addition to changing column names, you can also change the data type of a -column using the :py:meth:`lance.LanceDataset.alter_columns` method. This -requires rewriting that column to new data files, but does not require rewriting -the other columns. - -.. note:: - - If the column has an index, the index will be dropped if the column type is - changed. - -This method can be used to change the vector type of a column. For example, we -can change a float32 embedding column into a float16 column to save disk space -at the cost of lower precision: - -.. testcode:: - - import lance - import pyarrow as pa - import numpy as np - table = pa.table({ - "id": pa.array([1, 2, 3]), - "embedding": pa.FixedShapeTensorArray.from_numpy_ndarray( - np.random.rand(3, 128).astype("float32")) - }) - dataset = lance.write_dataset(table, "embeddings") - dataset.alter_columns({"path": "embedding", - "type": pa.list_(pa.float16(), 128)}) - dataset.schema() - -.. testoutput:: - - id: int64 - embedding: fixed_size_list - - -Adding new columns -~~~~~~~~~~~~~~~~~~~ - -New columns can be added and populated within a single operation using the -:py:meth:`lance.LanceDataset.add_columns` method. There are two ways to specify -how to populate the new columns: first, by providing a SQL expression for each -new column, or second, by providing a function to generate the new column data. - -SQL expressions can either be independent expressions or reference existing -columns. SQL literal values can be used to set a single value for all -existing rows. - -.. testcode:: - - import lance - import pyarrow as pa - table = pa.table({"name": pa.array(["Alice", "Bob", "Carla"])}) - dataset = lance.write_dataset(table, "names") - dataset.add_columns({ - "hash": "sha256(name)", - "status": "'active'", - }) - dataset.to_table().to_pandas() - -.. testoutput:: - - name hash... status - 0 Alice 3bc51062973c... active - 1 Bob cd9fb1e148cc... active - 2 Carla ad8d83ffd82b... active - -You can also provide a Python function to generate the new column data. This can -be used, for example, to compute a new embedding column. This function should -take a PyArrow RecordBatch and return either a PyArrow RecordBatch or a Pandas -DataFrame. The function will be called once for each batch in the dataset. - -If the function is expensive to compute and can fail, it is recommended to set -a checkpoint file in the UDF. This checkpoint file saves the state of the UDF -after each invocation, so that if the UDF fails, it can be restarted from the -last checkpoint. Note that this file can get quite large, since it needs to store -unsaved results for up to an entire data file. - -.. code-block:: - - import lance - import pyarrow as pa - import numpy as np - - table = pa.table({"id": pa.array([1, 2, 3])}) - dataset = lance.write_dataset(table, "ids") - - @lance.batch_udf(checkpoint_file="embedding_checkpoint.sqlite") - def add_random_vector(batch): - embeddings = np.random.rand(batch.num_rows, 128).astype("float32") - return pd.DataFrame({"embedding": embeddings}) - dataset.add_columns(add_random_vector) - - -Adding new columns using merge -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you have pre-computed one or more new columns, you can add them to an existing -dataset using the :py:meth:`lance.LanceDataset.merge` method. This allows filling in -additional columns without having to rewrite the whole dataset. - - -To use the ``merge`` method, provide a new dataset that includes the columns you -want to add, and a column name to use for joining the new data to the existing -dataset. - -For example, imagine we have a dataset of embeddings and ids: - -.. testcode:: - - import lance - import pyarrow as pa - import numpy as np - table = pa.table({ - "id": pa.array([1, 2, 3]), - "embedding": pa.array([np.array([1, 2, 3]), np.array([4, 5, 6]), - np.array([7, 8, 9])]) - }) - dataset = lance.write_dataset(table, "embeddings") - -Now if we want to add a column of labels we have generated, we can do so by merging a new table: - -.. testcode:: - - new_data = pa.table({ - "id": pa.array([1, 2, 3]), - "label": pa.array(["horse", "rabbit", "cat"]) - }) - dataset.merge(new_data, "id") - dataset.to_table().to_pandas() - -.. testoutput:: - - id embedding label - 0 1 [1, 2, 3] horse - 1 2 [4, 5, 6] rabbit - 2 3 [7, 8, 9] cat - - -Dropping columns -~~~~~~~~~~~~~~~~ - -Finally, you can drop columns from a dataset using the :py:meth:`lance.LanceDataset.drop_columns` -method. This is a metadata-only operation and does not delete the data on disk. This makes -it very quick. - -.. testcode:: - - import lance - import pyarrow as pa - table = pa.table({"id": pa.array([1, 2, 3]), - "name": pa.array(["Alice", "Bob", "Carla"])}) - dataset = lance.write_dataset(table, "names") - dataset.drop_columns(["name"]) - dataset.schema() - -.. testoutput:: - - id: int64 - -To actually remove the data from disk, the files must be rewritten to remove the -columns and then the old files must be deleted. This can be done using -:py:meth:`lance.dataset.DatasetOptimizer.compact_files()` followed by -:py:meth:`lance.LanceDataset.cleanup_old_versions()`. - - -Reading Lance Dataset ---------------------- - -To open a Lance dataset, use the :py:meth:`lance.dataset` function: - -.. code-block:: python - - import lance - ds = lance.dataset("s3://bucket/path/imagenet.lance") - # Or local path - ds = lance.dataset("./imagenet.lance") - -.. note:: - - Lance supports local file system, AWS ``s3`` and Google Cloud Storage(``gs``) as storage backends - at the moment. Read more in `Object Store Configuration`_. - -The most straightforward approach for reading a Lance dataset is to utilize the :py:meth:`lance.LanceDataset.to_table` -method in order to load the entire dataset into memory. - -.. code-block:: python - - table = ds.to_table() - -Due to Lance being a high-performance columnar format, it enables efficient reading of subsets of the dataset by utilizing -**Column (projection)** push-down and **filter (predicates)** push-downs. - -.. code-block:: python - - table = ds.to_table( - columns=["image", "label"], - filter="label = 2 AND text IS NOT NULL", - limit=1000, - offset=3000) - -Lance understands the cost of reading heavy columns such as ``image``. -Consequently, it employs an optimized query plan to execute the operation efficiently. - -Iterative Read -~~~~~~~~~~~~~~ - -If the dataset is too large to fit in memory, you can read it in batches -using the :py:meth:`lance.LanceDataset.to_batches` method: - -.. code-block:: python - - for batch in ds.to_batches(columns=["image"], filter="label = 10"): - # do something with batch - compute_on_batch(batch) - -Unsurprisingly, :py:meth:`~lance.LanceDataset.to_batches` takes the same parameters -as :py:meth:`~lance.LanceDataset.to_table` function. - - -.. _filter-push-down: - -Filter push-down -~~~~~~~~~~~~~~~~ - -Lance embraces the utilization of standard SQL expressions as predicates for dataset filtering. -By pushing down the SQL predicates directly to the storage system, -the overall I/O load during a scan is significantly reduced. - -Currently, Lance supports a growing list of expressions. - -* ``>``, ``>=``, ``<``, ``<=``, ``=`` -* ``AND``, ``OR``, ``NOT`` -* ``IS NULL``, ``IS NOT NULL`` -* ``IS TRUE``, ``IS NOT TRUE``, ``IS FALSE``, ``IS NOT FALSE`` -* ``IN`` -* ``LIKE``, ``NOT LIKE`` -* ``regexp_match(column, pattern)`` -* ``CAST`` - -For example, the following filter string is acceptable: - -.. code-block:: SQL - - ((label IN [10, 20]) AND (note['email'] IS NOT NULL)) - OR NOT note['created'] - -Nested fields can be accessed using the subscripts. Struct fields can be -subscripted using field names, while list fields can be subscripted using -indices. - -If your column name contains special characters or is a `SQL Keyword `_, -you can use backtick (`````) to escape it. For nested fields, each segment of the -path must be wrapped in backticks. - -.. code-block:: SQL - - `CUBE` = 10 AND `column name with space` IS NOT NULL - AND `nested with space`.`inner with space` < 2 - -.. warning:: - - Field names containing periods (``.``) are not supported. - -Literals for dates, timestamps, and decimals can be written by writing the string -value after the type name. For example - -.. code-block:: SQL - - date_col = date '2021-01-01' - and timestamp_col = timestamp '2021-01-01 00:00:00' - and decimal_col = decimal(8,3) '1.000' - -For timestamp columns, the precision can be specified as a number in the type -parameter. Microsecond precision (6) is the default. - -.. list-table:: - :widths: 30 40 - :header-rows: 1 - - * - SQL - - Time unit - * - ``timestamp(0)`` - - Seconds - * - ``timestamp(3)`` - - Milliseconds - * - ``timestamp(6)`` - - Microseconds - * - ``timestamp(9)`` - - Nanoseconds - -Lance internally stores data in Arrow format. The mapping from SQL types to Arrow -is: - -.. list-table:: - :widths: 30 40 - :header-rows: 1 - - * - SQL type - - Arrow type - * - ``boolean`` - - ``Boolean`` - * - ``tinyint`` / ``tinyint unsigned`` - - ``Int8`` / ``UInt8`` - * - ``smallint`` / ``smallint unsigned`` - - ``Int16`` / ``UInt16`` - * - ``int`` or ``integer`` / ``int unsigned`` or ``integer unsigned`` - - ``Int32`` / ``UInt32`` - * - ``bigint`` / ``bigint unsigned`` - - ``Int64`` / ``UInt64`` - * - ``float`` - - ``Float32`` - * - ``double`` - - ``Float64`` - * - ``decimal(precision, scale)`` - - ``Decimal128`` - * - ``date`` - - ``Date32`` - * - ``timestamp`` - - ``Timestamp`` (1) - * - ``string`` - - ``Utf8`` - * - ``binary`` - - ``Binary`` - -(1) See precision mapping in previous table. - - -Random read -~~~~~~~~~~~ - -One district feature of Lance, as columnar format, is that it allows you to read random samples quickly. - -.. code-block:: python - - # Access the 2nd, 101th and 501th rows - data = ds.take([1, 100, 500], columns=["image", "label"]) - -The ability to achieve fast random access to individual rows plays a crucial role in facilitating various workflows -such as random sampling and shuffling in ML training. -Additionally, it empowers users to construct secondary indices, -enabling swift execution of queries for enhanced performance. - - -Table Maintenance ------------------ - -Some operations over time will cause a Lance dataset to have a poor layout. For -example, many small appends will lead to a large number of small fragments. Or -deleting many rows will lead to slower queries due to the need to filter out -deleted rows. - -To address this, Lance provides methods for optimizing dataset layout. - -Compact data files -~~~~~~~~~~~~~~~~~~ - -Data files can be rewritten so there are fewer files. When passing a -``target_rows_per_fragment`` to :py:meth:`lance.dataset.DatasetOptimizer.compact_files`, -Lance will skip any fragments that are already above that row count, and rewrite -others. Fragments will be merged according to their fragment ids, so the inherent -ordering of the data will be preserved. - -.. note:: - - Compaction creates a new version of the table. It does not delete the old - version of the table and the files referenced by it. - -.. code-block:: python - - import lance - - dataset = lance.dataset("./alice_and_bob.lance") - dataset.optimize.compact_files(target_rows_per_fragment=1024 * 1024) - -During compaction, Lance can also remove deleted rows. Rewritten fragments will -not have deletion files. This can improve scan performance since the soft deleted -rows don't have to be skipped during the scan. - -When files are rewritten, the original row addresses are invalidated. This means the -affected files are no longer part of any ANN index if they were before. Because -of this, it's recommended to rewrite files before re-building indices. - -.. TODO: remove this last comment once move-stable row ids are default. - -Object Store Configuration --------------------------- - -Lance supports object stores such as AWS S3 (and compatible stores), Azure Blob Store, -and Google Cloud Storage. Which object store to use is determined by the URI scheme of -the dataset path. For example, ``s3://bucket/path`` will use S3, ``az://bucket/path`` -will use Azure, and ``gs://bucket/path`` will use GCS. - -.. versionadded:: 0.10.7 - - Passing options directly to storage options. - -These object stores take additional configuration objects. There are two ways to -specify these configurations: by setting environment variables or by passing them -to the ``storage_options`` parameter of :py:meth:`lance.dataset` and -:py:func:`lance.write_dataset`. So for example, to globally set a higher timeout, -you would run in your shell: - -.. code-block:: bash - - export TIMEOUT=60s - -If you only want to set the timeout for a single dataset, you can pass it as a -storage option: - -.. code-block:: python - - import lance - ds = lance.dataset("s3://path", storage_options={"timeout": "60s"}) - - -General Configuration -~~~~~~~~~~~~~~~~~~~~~ - -These options apply to all object stores. - -.. from https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html - -.. list-table:: - :widths: 30 70 - :header-rows: 1 - - * - Key - - Description - * - ``allow_http`` - - Allow non-TLS, i.e. non-HTTPS connections. Default, ``False``. - * - ``download_retry_count`` - - Number of times to retry a download. Default, ``3``. This limit is applied when - the HTTP request succeeds but the response is not fully downloaded, typically due - to a violation of ``request_timeout``. - * - ``allow_invalid_certificates`` - - Skip certificate validation on https connections. Default, ``False``. - Warning: This is insecure and should only be used for testing. - * - ``connect_timeout`` - - Timeout for only the connect phase of a Client. Default, ``5s``. - * - ``request_timeout`` - - Timeout for the entire request, from connection until the response body - has finished. Default, ``30s``. - * - ``user_agent`` - - User agent string to use in requests. - * - ``proxy_url`` - - URL of a proxy server to use for requests. Default, ``None``. - * - ``proxy_ca_certificate`` - - PEM-formatted CA certificate for proxy connections - * - ``proxy_excludes`` - - List of hosts that bypass proxy. This is a comma separated list of domains - and IP masks. Any subdomain of the provided domain will be bypassed. For - example, ``example.com, 192.168.1.0/24`` would bypass ``https://api.example.com``, - ``https://www.example.com``, and any IP in the range ``192.168.1.0/24``. - - -S3 Configuration -~~~~~~~~~~~~~~~~ - -S3 (and S3-compatible stores) have additional configuration options that configure -authorization and S3-specific features (such as server-side encryption). - -AWS credentials can be set in the environment variables ``AWS_ACCESS_KEY_ID``, -``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN``. Alternatively, they can be -passed as parameters to the ``storage_options`` parameter: - -.. code-block:: python - - import lance - ds = lance.dataset( - "s3://bucket/path", - storage_options={ - "access_key_id": "my-access-key", - "secret_access_key": "my-secret-key", - "session_token": "my-session-token", - } - ) - -If you are using AWS SSO, you can specify the ``AWS_PROFILE`` environment variable. -It cannot be specified in the ``storage_options`` parameter. - -The following keys can be used as both environment variables or keys in the -``storage_options`` parameter: - -.. list-table:: - :widths: 30 70 - :header-rows: 1 - - * - Key - - Description - * - ``aws_region`` / ``region`` - - The AWS region the bucket is in. This can be automatically detected when - using AWS S3, but must be specified for S3-compatible stores. - * - ``aws_access_key_id`` / ``access_key_id`` - - The AWS access key ID to use. - * - ``aws_secret_access_key`` / ``secret_access_key`` - - The AWS secret access key to use. - * - ``aws_session_token`` / ``session_token`` - - The AWS session token to use. - * - ``aws_endpoint`` / ``endpoint`` - - The endpoint to use for S3-compatible stores. - * - ``aws_virtual_hosted_style_request`` / ``virtual_hosted_style_request`` - - Whether to use virtual hosted-style requests, where bucket name is part - of the endpoint. Meant to be used with ``aws_endpoint``. Default, ``False``. - * - ``aws_s3_express`` / ``s3_express`` - - Whether to use S3 Express One Zone endpoints. Default, ``False``. See more - details below. - * - ``aws_server_side_encryption`` - - The server-side encryption algorithm to use. Must be one of ``"AES256"``, - ``"aws:kms"``, or ``"aws:kms:dsse"``. Default, ``None``. - * - ``aws_sse_kms_key_id`` - - The KMS key ID to use for server-side encryption. If set, - ``aws_server_side_encryption`` must be ``"aws:kms"`` or ``"aws:kms:dsse"``. - * - ``aws_sse_bucket_key_enabled`` - - Whether to use bucket keys for server-side encryption. - - -S3-compatible stores -^^^^^^^^^^^^^^^^^^^^ - -Lance can also connect to S3-compatible stores, such as MinIO. To do so, you must -specify both region and endpoint: - -.. code-block:: python - - import lance - ds = lance.dataset( - "s3://bucket/path", - storage_options={ - "region": "us-east-1", - "endpoint": "http://minio:9000", - } - ) - -This can also be done with the ``AWS_ENDPOINT`` and ``AWS_DEFAULT_REGION`` environment variables. - -S3 Express -^^^^^^^^^^ - -.. versionadded:: 0.9.7 - -Lance supports `S3 Express One Zone`_ endpoints, but requires additional configuration. Also, -S3 Express endpoints only support connecting from an EC2 instance within the same -region. - -.. _S3 Express One Zone: https://aws.amazon.com/s3/storage-classes/express-one-zone/ - -To configure Lance to use an S3 Express endpoint, you must set the storage option -``s3_express``. The bucket name in your table URI should **include the suffix**. - -.. code-block:: python - - import lance - ds = lance.dataset( - "s3://my-bucket--use1-az4--x-s3/path/imagenet.lance", - storage_options={ - "region": "us-east-1", - "s3_express": "true", - } - ) - - -Committing mechanisms for S3 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Most supported storage systems (e.g. local file system, Google Cloud Storage, -Azure Blob Store) natively support atomic commits, which prevent concurrent -writers from corrupting the dataset. However, S3 does not support this natively. -To work around this, you may provide a locking mechanism that Lance can use to -lock the table while providing a write. To do so, you should implement a -context manager that acquires and releases a lock and then pass that to the -``commit_lock`` parameter of :py:meth:`lance.write_dataset`. - -.. note:: - - In order for the locking mechanism to work, all writers must use the same exact - mechanism. Otherwise, Lance will not be able to detect conflicts. - -On entering, the context manager should acquire the lock on the table. The table -version being committed is passed in as an argument, which may be used if the -locking service wishes to keep track of the current version of the table, but -this is not required. If the table is already locked by another transaction, -it should wait until it is unlocked, since the other transaction may fail. Once -unlocked, it should either lock the table or, if the lock keeps track of the -current version of the table, return a :class:`CommitConflictError` if the -requested version has already been committed. - -To prevent poisoned locks, it's recommended to set a timeout on the locks. That -way, if a process crashes while holding the lock, the lock will be released -eventually. The timeout should be no less than 30 seconds. - -.. code-block:: python - - from contextlib import contextmanager - - @contextmanager - def commit_lock(version: int); - # Acquire the lock - my_lock.acquire() - try: - yield - except: - failed = True - finally: - my_lock.release() - - lance.write_dataset(data, "s3://bucket/path/", commit_lock=commit_lock) - -When the context manager is exited, it will raise an exception if the commit -failed. This might be because of a network error or if the version has already -been written. Either way, the context manager should release the lock. Use a -try/finally block to ensure that the lock is released. - -Concurrent Writer on S3 using DynamoDB -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. warning:: - - This feature is experimental at the moment - -Lance has native support for concurrent writers on S3 using DynamoDB instead of locking. -User may pass in a DynamoDB table name alone with the S3 URI to their dataset to enable this feature. - -.. code-block:: python - - import lance - # s3+ddb:// URL scheme let's lance know that you want to - # use DynamoDB for writing to S3 concurrently - ds = lance.dataset("s3+ddb://my-bucket/mydataset?ddbTableName=mytable") - -The DynamoDB table is expected to have a primary hash key of ``base_uri`` and a range key ``version``. -The key ``base_uri`` should be string type, and the key ``version`` should be number type. - -For details on how this feature works, please see :ref:`external-manifest-store`. - - -Google Cloud Storage Configuration -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -GCS credentials are configured by setting the ``GOOGLE_SERVICE_ACCOUNT`` environment -variable to the path of a JSON file containing the service account credentials. -Alternatively, you can pass the path to the JSON file in the ``storage_options`` - -.. code-block:: python - - import lance - ds = lance.dataset( - "gs://my-bucket/my-dataset", - storage_options={ - "service_account": "path/to/service-account.json", - } - ) - -.. note:: - - By default, GCS uses HTTP/1 for communication, as opposed to HTTP/2. This improves - maximum throughput significantly. However, if you wish to use HTTP/2 for some reason, - you can set the environment variable ``HTTP1_ONLY`` to ``false``. - - -The following keys can be used as both environment variables or keys in the -``storage_options`` parameter: - -.. source: https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html - -.. list-table:: - :widths: 30 70 - :header-rows: 1 - - * - Key - - Description - * - ``google_service_account`` / ``service_account`` - - Path to the service account JSON file. - * - ``google_service_account_key`` / ``service_account_key`` - - The serialized service account key. - * - ``google_application_credentials`` / ``application_credentials`` - - Path to the application credentials. - - -Azure Blob Storage Configuration -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Azure Blob Storage credentials can be configured by setting the ``AZURE_STORAGE_ACCOUNT_NAME`` -and ``AZURE_STORAGE_ACCOUNT_KEY`` environment variables. Alternatively, you can pass -the account name and key in the ``storage_options`` parameter: - -.. code-block:: python - - import lance - ds = lance.dataset( - "az://my-container/my-dataset", - storage_options={ - "account_name": "some-account", - "account_key": "some-key", - } - ) - -These keys can be used as both environment variables or keys in the ``storage_options`` parameter: - -.. source: https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html - -.. list-table:: - :widths: 30 70 - :header-rows: 1 - - * - Key - - Description - * - ``azure_storage_account_name`` / ``account_name`` - - The name of the azure storage account. - * - ``azure_storage_account_key`` / ``account_key`` - - The serialized service account key. - * - ``azure_client_id`` / ``client_id`` - - Service principal client id for authorizing requests. - * - ``azure_client_secret`` / ``client_secret`` - - Service principal client secret for authorizing requests. - * - ``azure_tenant_id`` / ``tenant_id`` - - Tenant id used in oauth flows. - * - ``azure_storage_sas_key`` / ``azure_storage_sas_token`` / ``sas_key`` / ``sas_token`` - - Shared access signature. The signature is expected to be percent-encoded, much like they are provided in the azure storage explorer or azure portal. - * - ``azure_storage_token`` / ``bearer_token`` / ``token`` - - Bearer token. - * - ``azure_storage_use_emulator`` / ``object_store_use_emulator`` / ``use_emulator`` - - Use object store with azurite storage emulator. - * - ``azure_endpoint`` / ``endpoint`` - - Override the endpoint used to communicate with blob storage. - * - ``azure_use_fabric_endpoint`` / ``use_fabric_endpoint`` - - Use object store with url scheme account.dfs.fabric.microsoft.com. - * - ``azure_msi_endpoint`` / ``azure_identity_endpoint`` / ``identity_endpoint`` / ``msi_endpoint`` - - Endpoint to request a imds managed identity token. - * - ``azure_object_id`` / ``object_id`` - - Object id for use with managed identity authentication. - * - ``azure_msi_resource_id`` / ``msi_resource_id`` - - Msi resource id for use with managed identity authentication. - * - ``azure_federated_token_file`` / ``federated_token_file`` - - File containing token for Azure AD workload identity federation. - * - ``azure_use_azure_cli`` / ``use_azure_cli`` - - Use azure cli for acquiring access token. - * - ``azure_disable_tagging`` / ``disable_tagging`` - - Disables tagging objects. This can be desirable if not supported by the backing store. diff --git a/docs/requirements.txt b/docs/requirements.txt index 7955db6cd01..22fff75761e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,14 @@ pyarrow -# pin until breathe updates (https://github.com/sphinx-doc/sphinx/issues/11605, https://github.com/breathe-doc/breathe/issues/943) -sphinx==7.1.2 +sphinx>=8 +sphinx-copybutton +sphinx-immaterial breathe cython pandas piccolo-theme -duckdb>=0.8 +duckdb>=1 jupyterlab fastai xmltodict tensorflow +ray[data] diff --git a/docs/tags.rst b/docs/tags.rst new file mode 100644 index 00000000000..a131f1b220f --- /dev/null +++ b/docs/tags.rst @@ -0,0 +1,51 @@ +Manage Tags +=========== + +Lance, much like Git, employs the :py:attr:`LanceDataset.tags ` +property to label specific versions within a dataset's history. + +:py:class:`Tags ` are particularly useful for tracking the evolution of datasets, +especially in machine learning workflows where datasets are frequently updated. +For example, you can :py:meth:`create `, :meth:`update `, +and :meth:`delete ` or :py:meth:`list ` tags. + +.. note:: + + Creating or deleting tags does not generate new dataset versions. + Tags exist as auxiliary metadata stored in a separate directory. + +.. testsetup:: + + shutil.rmtree("./tags.lance", ignore_errors=True) + data = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + lance.write_dataset(data, "./tags.lance") + data = [{"a": 5, "b": 6}, {"a": 7, "b": 8}] + lance.write_dataset(data, "./tags.lance", mode="append") + +.. doctest:: + + >>> import lance + >>> ds = lance.dataset("./tags.lance") + >>> len(ds.versions()) + 2 + >>> ds.tags.list() + {} + >>> ds.tags.create("v1-prod", 1) + >>> ds.tags.list() + {'v1-prod': {'version': 1, ...}} + >>> ds.tags.update("v1-prod", 2) + >>> ds.tags.list() + {'v1-prod': {'version': 2, ...}} + >>> ds.tags.delete("v1-prod") + >>> ds.tags.list() + {} + + + +.. note:: + + Tagged versions are exempted from the :py:meth:`LanceDataset.cleanup_old_versions() ` + process. + + To remove a version that has been tagged, you must first :py:meth:`LanceDataset.tags.delete() ` + the associated tag. \ No newline at end of file diff --git a/docs/tokenizer.rst b/docs/tokenizer.rst new file mode 100644 index 00000000000..f961557cffa --- /dev/null +++ b/docs/tokenizer.rst @@ -0,0 +1,92 @@ +Tokenizers +============================ + +Currently, Lance has built-in support for Jieba and Lindera. However, it doesn't come with its own language models. +If tokenization is needed, you can download language models by yourself. +You can specify the location where the language models are stored by setting the environment variable LANCE_LANGUAGE_MODEL_HOME. +If it's not set, the default value is + +.. code-block:: bash + + ${system data directory}/lance/language_models + +It also supports configuring user dictionaries, +which makes it convenient for users to expand their own dictionaries without retraining the language models. + +Language Models of Jieba +------------------------ + +Downloading the Model +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + python -m lance.download jieba + +The language model is stored by default in `${LANCE_LANGUAGE_MODEL_HOME}/jieba/default`. + +Using the Model +~~~~~~~~~~~~~~~ + +.. code-block:: python + ds.create_scalar_index("text", "INVERTED", base_tokenizer="jieba/default") + +User Dictionaries +~~~~~~~~~~~~~~~~~ +Create a file named config.json in the root directory of the current model. + +.. code-block:: json + + { + "main": "dict.txt", + "users": ["path/to/user/dict.txt"] + } + +- The "main" field is optional. If not filled, the default is "dict.txt". +- "users" is the path of the user dictionary. For the format of the user dictionary, please refer to https://github.com/messense/jieba-rs/blob/main/src/data/dict.txt. + + +Language Models of Lindera +-------------------------- + +Downloading the Model +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + python -m lance.download lindera -l [ipadic|ko-dic|unidic] + +Note that the language models of Lindera need to be compiled. Please install lindera-cli first. For detailed steps, please refer to https://github.com/lindera/lindera/tree/main/lindera-cli. + +The language model is stored by default in ${LANCE_LANGUAGE_MODEL_HOME}/lindera/[ipadic|ko-dic|unidic] + +Using the Model +~~~~~~~~~~~~~~~ + +.. code-block:: python + + ds.create_scalar_index("text", "INVERTED", base_tokenizer="lindera/ipadic") + +User Dictionaries +~~~~~~~~~~~~~~~~~ + +Create a file named config.json in the root directory of the current model. + +.. code-block::json + { + "main": "main", + "users": "path/to/user/dict.bin", + "user_kind": "ipadic|ko-dic|unidic" + } + +- The "main" field is optional. If not filled, the default is the "main" directory. +- "user" is the path of the user dictionary. The user dictionary can be passed as a CSV file or as a binary file compiled by lindera-cli. +- The "user_kind" field can be left blank if the user dictionary is in binary format. If it's in CSV format, you need to specify the type of the language model. + + +Create your own language model +------------------------------ + +Put your language model into `LANCE_LANGUAGE_MODEL_HOME`. + + diff --git a/integration/duckdb_lance/CMakeLists.txt b/integration/duckdb_lance/CMakeLists.txt deleted file mode 100644 index b3a1e7976fb..00000000000 --- a/integration/duckdb_lance/CMakeLists.txt +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2023 Lance Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Still need to use cmake to link to duckdb via `build_loadable_extension` macro. -# - -cmake_minimum_required(VERSION 3.22) - -if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) -endif () - -project(lance_duckdb VERSION 0.3) -set(EXTENSION_NAME lance) - -if (APPLE) - # POLICY CMP0042 - set(CMAKE_MACOSX_RPATH 1) -endif() - -include(FetchContent) - -if(UNIX AND NOT APPLE) - find_package(OpenSSL REQUIRED) -endif() - -FetchContent_Declare( - Corrosion - GIT_REPOSITORY https://github.com/corrosion-rs/corrosion.git - GIT_TAG v0.3.2 # Optionally specify a commit hash, version tag or branch here -) -set(BUILD_UNITTESTS FALSE) # Disable unit test build in duckdb - -FetchContent_MakeAvailable(Corrosion) - -#set(EXTERNAL_EXTENSION_DIRECTORIES ${CMAKE_CURRENT_SOURCE_DIR}) - -corrosion_import_crate(MANIFEST_PATH ${CMAKE_CURRENT_SOURCE_DIR}/Cargo.toml) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/duckdb/src/include) - -set(ALL_SOURCES src/extension.c src/extension.h) - -SET(EXTENSION_STATIC_BUILD 1) -set(PARAMETERS "-warnings") -build_loadable_extension(${EXTENSION_NAME} ${PARAMETERS} ${ALL_SOURCES}) - -set(LIB_NAME ${EXTENSION_NAME}_loadable_extension) - -set_target_properties(${LIB_NAME} PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(${LIB_NAME} - "${CMAKE_CURRENT_BINARY_DIR}/libduckdb_lance.a" - duckdb_static - ${OPENSSL_LIBRARIES} -) - -if (APPLE) - target_link_libraries(${LIB_NAME} - "-framework CoreFoundation" - "-framework Security" - "-framework Accelerate") -endif() diff --git a/integration/duckdb_lance/Cargo.toml b/integration/duckdb_lance/Cargo.toml deleted file mode 100644 index e163d5c0dcd..00000000000 --- a/integration/duckdb_lance/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "duckdb-lance" -version = "0.1.0" -edition = "2021" - -[dependencies] -lance = { path = "../../rust/lance" } -duckdb-ext = { path = "./duckdb-ext" } -lazy_static = "1.4.0" -tokio = { version = "1.23", features = ["rt-multi-thread"] } -arrow-schema = "49.0.0" -arrow-array = "49.0.0" -futures = "0.3" -num-traits = "0.2" - -[dev-dependencies] -libduckdb-sys = { version = "0.8.1", features = ["bundled"] } - -[lib] -name = "duckdb_lance" -crate-type = ["staticlib"] diff --git a/integration/duckdb_lance/Makefile b/integration/duckdb_lance/Makefile deleted file mode 100644 index 7c15c9d0d4f..00000000000 --- a/integration/duckdb_lance/Makefile +++ /dev/null @@ -1,21 +0,0 @@ -# - -BUILD_FLAGS=-DEXTENSION_STATIC_BUILD=1 -DCLANG_TIDY=False - -# Debug build -build: - mkdir -p build/debug && \ - cd build/debug && \ - cmake $(GENERATOR) $(FORCE_COLOR) -DCMAKE_BUILD_TYPE=Debug ${BUILD_FLAGS} ../../duckdb/CMakeLists.txt -DEXTERNAL_EXTENSION_DIRECTORIES=../../duckdb_lance -B. && \ - cmake --build . --config Debug -.PHONY: build - - -release: - mkdir -p build/release && \ - cd build/release && \ - cmake $(GENERATOR) $(FORCE_COLOR) -DCMAKE_BUILD_TYPE=Release ${BUILD_FLAGS} \ - ../../duckdb/CMakeLists.txt -DEXTERNAL_EXTENSION_DIRECTORIES=../../duckdb_lance -B. && \ - cmake --build . --config Release -.PHONY: release - diff --git a/integration/duckdb_lance/README.md b/integration/duckdb_lance/README.md deleted file mode 100644 index 9fef7377e61..00000000000 --- a/integration/duckdb_lance/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# DuckDB Extension - - -## How to build - -```sh - -git submodule update -make build -``` diff --git a/integration/duckdb_lance/duckdb b/integration/duckdb_lance/duckdb deleted file mode 160000 index f7827396d70..00000000000 --- a/integration/duckdb_lance/duckdb +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f7827396d70232a0434c91142809deef6e0b6092 diff --git a/integration/duckdb_lance/duckdb-ext/Cargo.toml b/integration/duckdb_lance/duckdb-ext/Cargo.toml deleted file mode 100644 index c14b8499642..00000000000 --- a/integration/duckdb_lance/duckdb-ext/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "duckdb-ext" -version = "0.1.0" -edition = "2021" - -[dependencies] - -[build-dependencies] -bindgen = "0.64.0" -build_script = "0.2.0" -cc = "1.0.78" diff --git a/integration/duckdb_lance/duckdb-ext/README.md b/integration/duckdb_lance/duckdb-ext/README.md deleted file mode 100644 index 9f15206cb21..00000000000 --- a/integration/duckdb_lance/duckdb-ext/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# DuckDB Rust Extension Toolkit - - -## Credits - -This library was inspired by [DuckDB Extension Framework](https://github.com/Mause/duckdb-extension-framework). diff --git a/integration/duckdb_lance/duckdb-ext/build.rs b/integration/duckdb_lance/duckdb-ext/build.rs deleted file mode 100644 index 6696365900a..00000000000 --- a/integration/duckdb_lance/duckdb-ext/build.rs +++ /dev/null @@ -1,40 +0,0 @@ -use build_script::cargo_rerun_if_changed; -use std::path::PathBuf; -use std::{env, path::Path}; - -fn main() { - let duckdb_root = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()) - .join("duckdb") - .canonicalize() - .expect("duckdb source root"); - - let header = "src/duckdb_ext.h"; - - cargo_rerun_if_changed(header); - - let duckdb_include = duckdb_root.join("src/include"); - let bindings = bindgen::Builder::default() - .header(header) - .clang_arg("-xc++") - .clang_arg("-I") - .clang_arg(duckdb_include.to_string_lossy()) - .derive_debug(true) - .derive_default(true) - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - .generate() - .expect("Unable to generate bindings"); - - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); - - cc::Build::new() - .include(duckdb_include) - .flag_if_supported("-Wno-unused-parameter") - .flag_if_supported("-Wno-redundant-move") - .flag_if_supported("-std=c++17") - .cpp(true) - .file("src/duckdb_ext.cc") - .compile("duckdb_ext"); -} diff --git a/integration/duckdb_lance/duckdb-ext/duckdb b/integration/duckdb_lance/duckdb-ext/duckdb deleted file mode 160000 index f7827396d70..00000000000 --- a/integration/duckdb_lance/duckdb-ext/duckdb +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f7827396d70232a0434c91142809deef6e0b6092 diff --git a/integration/duckdb_lance/duckdb-ext/src/connection.rs b/integration/duckdb_lance/duckdb-ext/src/connection.rs deleted file mode 100644 index ae125990fe6..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/connection.rs +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::ffi::{duckdb_connection, duckdb_register_table_function}; -use crate::table_function::TableFunction; - -/// A connection to a database. This represents a (client) connection that can -/// be used to query the database. -#[derive(Debug)] -pub struct Connection { - ptr: duckdb_connection, -} - -impl From for Connection { - fn from(ptr: duckdb_connection) -> Self { - Self { ptr } - } -} - -impl Connection { - pub fn register_table_function( - &self, - table_function: TableFunction, - ) -> Result<(), Box> { - unsafe { - duckdb_register_table_function(self.ptr, table_function.ptr); - } - Ok(()) - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/data_chunk.rs b/integration/duckdb_lance/duckdb-ext/src/data_chunk.rs deleted file mode 100644 index 32194bc50cc..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/data_chunk.rs +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::vector::{FlatVector, ListVector, StructVector}; -use crate::{ - ffi::{ - duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_size, - duckdb_data_chunk_get_vector, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, - duckdb_data_chunk_get_column_count, - }, - LogicalType, -}; - -/// DataChunk in DuckDB. -pub struct DataChunk { - /// Pointer to the DataChunk in duckdb C API. - ptr: duckdb_data_chunk, - - /// Whether this [DataChunk] own the [DataChunk::ptr]. - owned: bool, -} - -impl DataChunk { - pub fn new(logical_types: &[LogicalType]) -> Self { - let num_columns = logical_types.len(); - let mut c_types = logical_types.iter().map(|t| t.ptr).collect::>(); - let ptr = unsafe { duckdb_create_data_chunk(c_types.as_mut_ptr(), num_columns as u64) }; - DataChunk { ptr, owned: true } - } - - /// Get the vector at the specific column index: `idx`. - /// - pub fn flat_vector(&self, idx: usize) -> FlatVector { - FlatVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) - } - - /// Get a list vector from the column index. - pub fn list_vector(&self, idx: usize) -> ListVector { - ListVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) - } - - /// Get struct vector at the column index: `idx`. - pub fn struct_vector(&self, idx: usize) -> StructVector { - StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) - } - - /// Set the size of the data chunk - pub fn set_len(&self, new_len: usize) { - unsafe { duckdb_data_chunk_set_size(self.ptr, new_len as u64) }; - } - - /// Get the length / the number of rows in this [DataChunk]. - pub fn len(&self) -> usize { - unsafe { duckdb_data_chunk_get_size(self.ptr) as usize } - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn num_columns(&self) -> usize { - unsafe { duckdb_data_chunk_get_column_count(self.ptr) as usize } - } -} - -impl From for DataChunk { - fn from(ptr: duckdb_data_chunk) -> Self { - Self { ptr, owned: false } - } -} - -impl Drop for DataChunk { - fn drop(&mut self) { - if self.owned && !self.ptr.is_null() { - unsafe { duckdb_destroy_data_chunk(&mut self.ptr) } - self.ptr = std::ptr::null_mut(); - } - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/database.rs b/integration/duckdb_lance/duckdb-ext/src/database.rs deleted file mode 100644 index 41a181aa351..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/database.rs +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::ffi::{duckdb_connect, duckdb_connection, duckdb_database, duckdb_state_DuckDBError}; -use crate::{Connection, Error, Result}; - -pub struct Database { - ptr: duckdb_database, -} - -impl From for Database { - fn from(ptr: duckdb_database) -> Self { - Self { ptr } - } -} - -impl Database { - pub fn connect(&self) -> Result { - let mut connection: duckdb_connection = std::ptr::null_mut(); - - let state = unsafe { duckdb_connect(self.ptr, &mut connection) }; - if state == duckdb_state_DuckDBError { - return Err(Error::DuckDB("Connection error".to_string())); - } - - Ok(Connection::from(connection)) - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.cc b/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.cc deleted file mode 100644 index c2efa5c66c2..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2023 Lance Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "duckdb_ext.h" - -#include - -#include "duckdb.hpp" - -namespace { - -auto build_child_list(idx_t n_pairs, const char *const *names, duckdb_logical_type const *types) { - duckdb::child_list_t members; - for (idx_t i = 0; i < n_pairs; i++) { - members.emplace_back(std::string(names[i]), *(duckdb::LogicalType *)types[i]); - } - return members; -} - -} // namespace - -extern "C" { - -duckdb_logical_type duckdb_create_struct_type(idx_t n_pairs, - const char **names, - const duckdb_logical_type *types) { - auto *stype = new duckdb::LogicalType; - *stype = duckdb::LogicalType::STRUCT(build_child_list(n_pairs, names, types)); - return reinterpret_cast(stype); -} - -} \ No newline at end of file diff --git a/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.h b/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.h deleted file mode 100644 index d246e483c8f..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/duckdb_ext.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 Lance Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#define DUCKDB_BUILD_LOADABLE_EXTENSION -#include "duckdb.h" - -extern "C" { - -DUCKDB_EXTENSION_API duckdb_logical_type duckdb_create_struct_type( - idx_t n_pairs, const char** names, const duckdb_logical_type* types); - -}; diff --git a/integration/duckdb_lance/duckdb-ext/src/error.rs b/integration/duckdb_lance/duckdb-ext/src/error.rs deleted file mode 100644 index d5e8f9de2ae..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/error.rs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ffi::CString; - -pub enum Error { - IO(String), - DuckDB(String), -} - -pub type Result = std::result::Result; - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::IO(s) => write!(f, "I/O: {s}"), - Self::DuckDB(s) => write!(f, "I/O: {s}"), - } - } -} - -impl Error { - pub fn c_str(&self) -> CString { - CString::new(self.to_string()).unwrap() - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/function_info.rs b/integration/duckdb_lance/duckdb-ext/src/function_info.rs deleted file mode 100644 index a81c967a833..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/function_info.rs +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::ffi::{duckdb_function_get_init_data, duckdb_function_info, duckdb_function_set_error}; -use crate::Error; - -/// UDF -pub struct FunctionInfo { - ptr: duckdb_function_info, -} - -impl From for FunctionInfo { - fn from(ptr: duckdb_function_info) -> Self { - Self { ptr } - } -} - -impl FunctionInfo { - pub fn init_data(&self) -> *mut T { - unsafe { duckdb_function_get_init_data(self.ptr).cast() } - } - - pub fn set_error(&self, error: Error) { - unsafe { - duckdb_function_set_error(self.ptr, error.c_str().as_ptr()); - } - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/lib.rs b/integration/duckdb_lance/duckdb-ext/src/lib.rs deleted file mode 100644 index 8cc7597a69b..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/lib.rs +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -mod connection; -mod data_chunk; -mod database; -mod error; -mod function_info; -mod logical_type; -pub mod table_function; -mod value; -mod vector; - -pub use connection::Connection; -pub use data_chunk::DataChunk; -pub use database::Database; -pub use error::{Error, Result}; -pub use function_info::FunctionInfo; -pub use logical_type::{LogicalType, LogicalTypeId}; -pub use value::Value; -pub use vector::{FlatVector, Inserter, ListVector, StructVector, Vector}; - -#[allow(clippy::all)] -pub mod ffi { - #![allow(non_upper_case_globals)] - #![allow(non_camel_case_types)] - #![allow(non_snake_case)] - #![allow(unused)] - #![allow(improper_ctypes)] - #![allow(clippy::upper_case_acronyms)] - include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -} diff --git a/integration/duckdb_lance/duckdb-ext/src/logical_type.rs b/integration/duckdb_lance/duckdb-ext/src/logical_type.rs deleted file mode 100644 index 921f273b16f..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/logical_type.rs +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ffi::{c_char, CString}; -use std::fmt::Debug; - -use crate::ffi::*; - -#[repr(u32)] -#[derive(Debug, PartialEq, Eq)] -pub enum LogicalTypeId { - Boolean = DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN, - Tinyint = DUCKDB_TYPE_DUCKDB_TYPE_TINYINT, - Smallint = DUCKDB_TYPE_DUCKDB_TYPE_SMALLINT, - Integer = DUCKDB_TYPE_DUCKDB_TYPE_INTEGER, - Bigint = DUCKDB_TYPE_DUCKDB_TYPE_BIGINT, - UTinyint = DUCKDB_TYPE_DUCKDB_TYPE_UTINYINT, - USmallint = DUCKDB_TYPE_DUCKDB_TYPE_USMALLINT, - UInteger = DUCKDB_TYPE_DUCKDB_TYPE_UINTEGER, - UBigint = DUCKDB_TYPE_DUCKDB_TYPE_UBIGINT, - Float = DUCKDB_TYPE_DUCKDB_TYPE_FLOAT, - Double = DUCKDB_TYPE_DUCKDB_TYPE_DOUBLE, - Timestamp = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP, - Date = DUCKDB_TYPE_DUCKDB_TYPE_DATE, - Time = DUCKDB_TYPE_DUCKDB_TYPE_TIME, - Interval = DUCKDB_TYPE_DUCKDB_TYPE_INTERVAL, - Hugeint = DUCKDB_TYPE_DUCKDB_TYPE_HUGEINT, - Varchar = DUCKDB_TYPE_DUCKDB_TYPE_VARCHAR, - Blob = DUCKDB_TYPE_DUCKDB_TYPE_BLOB, - Decimal = DUCKDB_TYPE_DUCKDB_TYPE_DECIMAL, - TimestampS = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_S, - TimestampMs = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_MS, - TimestampNs = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_NS, - Enum = DUCKDB_TYPE_DUCKDB_TYPE_ENUM, - List = DUCKDB_TYPE_DUCKDB_TYPE_LIST, - Struct = DUCKDB_TYPE_DUCKDB_TYPE_STRUCT, - Map = DUCKDB_TYPE_DUCKDB_TYPE_MAP, - Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID, - Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION, -} - -impl From for LogicalTypeId { - fn from(value: u32) -> Self { - match value { - DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN => Self::Boolean, - DUCKDB_TYPE_DUCKDB_TYPE_TINYINT => Self::Tinyint, - DUCKDB_TYPE_DUCKDB_TYPE_SMALLINT => Self::Smallint, - DUCKDB_TYPE_DUCKDB_TYPE_INTEGER => Self::Integer, - DUCKDB_TYPE_DUCKDB_TYPE_BIGINT => Self::Bigint, - DUCKDB_TYPE_DUCKDB_TYPE_UTINYINT => Self::UTinyint, - DUCKDB_TYPE_DUCKDB_TYPE_USMALLINT => Self::USmallint, - DUCKDB_TYPE_DUCKDB_TYPE_UINTEGER => Self::UInteger, - DUCKDB_TYPE_DUCKDB_TYPE_UBIGINT => Self::UBigint, - DUCKDB_TYPE_DUCKDB_TYPE_FLOAT => Self::Float, - DUCKDB_TYPE_DUCKDB_TYPE_DOUBLE => Self::Double, - DUCKDB_TYPE_DUCKDB_TYPE_VARCHAR => Self::Varchar, - DUCKDB_TYPE_DUCKDB_TYPE_BLOB => Self::Blob, - DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP => Self::Timestamp, - DUCKDB_TYPE_DUCKDB_TYPE_DATE => Self::Date, - DUCKDB_TYPE_DUCKDB_TYPE_TIME => Self::Time, - DUCKDB_TYPE_DUCKDB_TYPE_INTERVAL => Self::Interval, - DUCKDB_TYPE_DUCKDB_TYPE_HUGEINT => Self::Hugeint, - DUCKDB_TYPE_DUCKDB_TYPE_DECIMAL => Self::Decimal, - DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_S => Self::TimestampS, - DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_MS => Self::TimestampMs, - DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_NS => Self::TimestampNs, - DUCKDB_TYPE_DUCKDB_TYPE_ENUM => Self::Enum, - DUCKDB_TYPE_DUCKDB_TYPE_LIST => Self::List, - DUCKDB_TYPE_DUCKDB_TYPE_STRUCT => Self::Struct, - DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map, - DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid, - DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union, - _ => panic!(), - } - } -} - -/// DuckDB Logical Type. -/// -/// https://duckdb.org/docs/sql/data_types/overview -pub struct LogicalType { - pub(crate) ptr: duckdb_logical_type, -} - -impl Debug for LogicalType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - let id = self.id(); - match id { - LogicalTypeId::Struct => { - write!(f, "struct<")?; - for i in 0..self.num_children() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}: {:?}", self.child_name(i), self.child(i))?; - } - write!(f, ">") - } - _ => write!(f, "{:?}", self.id()), - } - } -} - -impl Drop for LogicalType { - fn drop(&mut self) { - if !self.ptr.is_null() { - unsafe { - duckdb_destroy_logical_type(&mut self.ptr); - } - } - - self.ptr = std::ptr::null_mut(); - } -} - -/// Wrap a DuckDB logical type from C API -impl From for LogicalType { - fn from(ptr: duckdb_logical_type) -> Self { - Self { ptr } - } -} - -impl LogicalType { - /// Create a new [LogicalType] from [LogicalTypeId] - pub fn new(id: LogicalTypeId) -> Self { - unsafe { - Self { - ptr: duckdb_create_logical_type(id as u32), - } - } - } - - /// Creates a list type from its child type. - /// - pub fn list_type(child_type: &LogicalType) -> Self { - unsafe { - Self { - ptr: duckdb_create_list_type(child_type.ptr), - } - } - } - - /// Make a `LogicalType` for `struct` - /// - pub fn struct_type(fields: &[(&str, LogicalType)]) -> Self { - let keys: Vec = fields.iter().map(|f| CString::new(f.0).unwrap()).collect(); - let values: Vec = fields.iter().map(|it| it.1.ptr).collect(); - let name_ptrs = keys - .iter() - .map(|it| it.as_ptr()) - .collect::>(); - - unsafe { - Self { - ptr: duckdb_create_struct_type( - fields.len() as idx_t, - name_ptrs.as_slice().as_ptr().cast_mut(), - values.as_slice().as_ptr(), - ), - } - } - } - - /// Logical type ID - pub fn id(&self) -> LogicalTypeId { - let duckdb_type_id = unsafe { duckdb_get_type_id(self.ptr) }; - duckdb_type_id.into() - } - - pub fn num_children(&self) -> usize { - match self.id() { - LogicalTypeId::Struct => unsafe { duckdb_struct_type_child_count(self.ptr) as usize }, - LogicalTypeId::List => 1, - _ => 0, - } - } - - pub fn child_name(&self, idx: usize) -> String { - assert_eq!(self.id(), LogicalTypeId::Struct); - unsafe { - let child_name_ptr = duckdb_struct_type_child_name(self.ptr, idx as u64); - let c_str = CString::from_raw(child_name_ptr); - let name = c_str.to_str().unwrap(); - name.to_string() - } - } - - pub fn child(&self, idx: usize) -> Self { - let c_logical_type = unsafe { duckdb_struct_type_child_type(self.ptr, idx as u64) }; - Self::from(c_logical_type) - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/table_function.rs b/integration/duckdb_lance/duckdb-ext/src/table_function.rs deleted file mode 100644 index cf2b8a0b57c..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/table_function.rs +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ffi::{c_void, CString}; - -use crate::ffi::{ - duckdb_bind_add_result_column, duckdb_bind_get_parameter, duckdb_bind_get_parameter_count, - duckdb_bind_info, duckdb_bind_set_bind_data, duckdb_bind_set_cardinality, - duckdb_bind_set_error, duckdb_create_table_function, duckdb_delete_callback_t, - duckdb_destroy_table_function, duckdb_init_get_bind_data, duckdb_init_info, - duckdb_init_set_error, duckdb_init_set_init_data, duckdb_table_function, - duckdb_table_function_add_parameter, duckdb_table_function_bind_t, - duckdb_table_function_init_t, duckdb_table_function_set_bind, - duckdb_table_function_set_function, duckdb_table_function_set_init, - duckdb_table_function_set_name, duckdb_table_function_supports_projection_pushdown, - duckdb_table_function_t, duckdb_init_get_column_count, duckdb_init_get_column_index, -}; -use crate::{Error, LogicalType, Value}; - -/// DuckDB BindInfo. -pub struct BindInfo { - ptr: duckdb_bind_info, -} - -impl From for BindInfo { - fn from(ptr: duckdb_bind_info) -> Self { - Self { ptr } - } -} - -impl BindInfo { - /// Add a result column to the output of the table function. - /// - /// - `name`: The name of the column - /// - `logical_type`: The [LogicalType] of the new column. - /// - /// # Safety - pub fn add_result_column(&self, name: &str, logical_type: LogicalType) { - let c_string = CString::new(name).unwrap(); - unsafe { - duckdb_bind_add_result_column(self.ptr, c_string.as_ptr(), logical_type.ptr); - } - } - - /// Sets the user-provided bind data in the bind object. This object can be retrieved again during execution. - /// - /// # Arguments - /// * `extra_data`: The bind data object. - /// * `destroy`: The callback that will be called to destroy the bind data (if any) - /// - /// # Safety - /// - pub fn set_bind_data( - &self, - data: *mut c_void, - free_function: Option, - ) { - unsafe { - duckdb_bind_set_bind_data(self.ptr, data, free_function); - } - } - - /// Get the number of regular (non-named) parameters to the function. - pub fn num_parameters(&self) -> u64 { - unsafe { duckdb_bind_get_parameter_count(self.ptr) } - } - - /// Get the parameter at the given index. - /// - /// # Arguments - /// * `index`: The index of the parameter to get - /// - /// returns: The value of the parameter - pub fn parameter(&self, index: usize) -> Value { - unsafe { Value::from(duckdb_bind_get_parameter(self.ptr, index as u64)) } - } - - /// Sets the cardinality estimate for the table function, used for optimization. - /// - /// * `cardinality`: The cardinality estimate - /// * `is_exact`: Whether or not the cardinality estimate is exact, or an approximation - pub fn set_cardinality(&self, cardinality: usize, is_exact: bool) { - unsafe { duckdb_bind_set_cardinality(self.ptr, cardinality as u64, is_exact) } - } - - pub fn set_error(&self, error: Error) { - unsafe { - duckdb_bind_set_error(self.ptr, error.c_str().as_ptr()); - } - } -} - -#[derive(Debug)] -pub struct InitInfo { - ptr: duckdb_init_info, -} - -impl From for InitInfo { - fn from(ptr: duckdb_init_info) -> Self { - Self { ptr } - } -} - -impl InitInfo { - /// # Safety - pub fn set_init_data(&self, data: *mut c_void, freeer: duckdb_delete_callback_t) { - unsafe { - duckdb_init_set_init_data(self.ptr, data, freeer); - } - } - - pub fn bind_data(&self) -> *mut T { - unsafe { duckdb_init_get_bind_data(self.ptr).cast() } - } - - /// Report that an error has occurred while calling init. - /// - /// # Arguments - /// * `error`: The error message - pub fn set_error(&self, error: Error) { - unsafe { duckdb_init_set_error(self.ptr, error.c_str().as_ptr()) } - } - - /// Get the total number of columns to be projected. - pub fn projected_column_ids(&self) -> Vec { - let num_columns = unsafe { duckdb_init_get_column_count(self.ptr) as usize }; - (0..num_columns).map(|col_id| { - unsafe { duckdb_init_get_column_index(self.ptr, col_id as u64) as usize} - }).collect() - } -} - -/// A function that returns a queryable table -#[derive(Debug)] -pub struct TableFunction { - pub(crate) ptr: duckdb_table_function, -} - -impl Drop for TableFunction { - fn drop(&mut self) { - if !self.ptr.is_null() { - unsafe { - duckdb_destroy_table_function(&mut self.ptr); - } - } - self.ptr = std::ptr::null_mut(); - } -} - -impl TableFunction { - /// Creates a new empty table function. - pub fn new(name: &str) -> Self { - let this = Self { - ptr: unsafe { duckdb_create_table_function() }, - }; - this.set_name(name); - this - } - - pub fn set_name(&self, name: &str) -> &Self { - unsafe { - let string = CString::new(name).unwrap(); - duckdb_table_function_set_name(self.ptr, string.as_ptr()); - } - self - } - - /// Adds a parameter to the table function. - /// - pub fn add_parameter(&self, logical_type: &LogicalType) -> &Self { - unsafe { - duckdb_table_function_add_parameter(self.ptr, logical_type.ptr); - } - self - } - - /// Enable project pushdown. - pub fn pushdown(&self, supports: bool) -> &Self { - unsafe { - duckdb_table_function_supports_projection_pushdown(self.ptr, supports); - } - self - } - - /// Sets the main function of the table function - /// - pub fn set_function(&self, func: duckdb_table_function_t) -> &Self { - unsafe { - duckdb_table_function_set_function(self.ptr, func); - } - self - } - - /// Sets the init function of the table function - /// - /// # Arguments - /// * `function`: The init function - pub fn set_init(&self, init_func: duckdb_table_function_init_t) -> &Self { - unsafe { - duckdb_table_function_set_init(self.ptr, init_func); - } - self - } - - /// Sets the bind function of the table function - /// - /// # Arguments - /// * `bind_func`: The bind function - pub fn set_bind(&self, bind_func: duckdb_table_function_bind_t) -> &Self { - unsafe { - duckdb_table_function_set_bind(self.ptr, bind_func); - } - self - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/value.rs b/integration/duckdb_lance/duckdb-ext/src/value.rs deleted file mode 100644 index 04728f869a7..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/value.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::ffi::{duckdb_destroy_value, duckdb_get_varchar, duckdb_value}; -use std::ffi::CString; - -/// The Value object holds a single arbitrary value of any type that can be -/// stored in the database. -#[derive(Debug)] -pub struct Value { - pub(crate) ptr: duckdb_value, -} - -impl From for Value { - fn from(ptr: duckdb_value) -> Self { - Self { ptr } - } -} - -impl Drop for Value { - fn drop(&mut self) { - if !self.ptr.is_null() { - unsafe { - duckdb_destroy_value(&mut self.ptr); - } - } - self.ptr = std::ptr::null_mut(); - } -} - -impl Value { - pub fn to_string(&self) -> String { - let c_string = unsafe { CString::from_raw(duckdb_get_varchar(self.ptr)) }; - c_string.into_string().unwrap() - } -} diff --git a/integration/duckdb_lance/duckdb-ext/src/vector.rs b/integration/duckdb_lance/duckdb-ext/src/vector.rs deleted file mode 100644 index 1f40300e5cb..00000000000 --- a/integration/duckdb_lance/duckdb-ext/src/vector.rs +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::any::Any; -use std::ffi::CString; -use std::slice; - -use crate::ffi::{ - duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, - duckdb_list_vector_reserve, duckdb_list_vector_set_size, duckdb_struct_type_child_count, - duckdb_struct_type_child_name, duckdb_struct_vector_get_child, duckdb_vector, - duckdb_vector_assign_string_element, duckdb_vector_get_column_type, duckdb_vector_get_data, - duckdb_vector_size, -}; -use crate::LogicalType; - -/// Vector trait. -pub trait Vector { - fn as_any(&self) -> &dyn Any; - - fn as_mut_any(&mut self) -> &mut dyn Any; -} - -pub struct FlatVector { - ptr: duckdb_vector, - capacity: usize, -} - -impl From for FlatVector { - fn from(ptr: duckdb_vector) -> Self { - Self { - ptr, - capacity: unsafe { duckdb_vector_size() as usize }, - } - } -} - -impl Vector for FlatVector { - fn as_any(&self) -> &dyn Any { - self - } - - fn as_mut_any(&mut self) -> &mut dyn Any { - self - } -} - -impl FlatVector { - fn with_capacity(ptr: duckdb_vector, capacity: usize) -> Self { - Self { ptr, capacity } - } - - pub fn capacity(&self) -> usize { - self.capacity - } - - /// Returns an unsafe mutable pointer to the vector’s - pub fn as_mut_ptr(&self) -> *mut T { - unsafe { duckdb_vector_get_data(self.ptr).cast() } - } - - pub fn as_slice(&self) -> &[T] { - unsafe { slice::from_raw_parts(self.as_mut_ptr(), self.capacity()) } - } - - pub fn as_mut_slice(&mut self) -> &mut [T] { - unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.capacity()) } - } - - pub fn logical_type(&self) -> LogicalType { - LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) }) - } - - pub fn copy(&mut self, data: &[T]) { - assert!(data.len() <= self.capacity()); - self.as_mut_slice::()[0..data.len()].copy_from_slice(data); - } -} - -pub trait Inserter { - fn insert(&self, index: usize, value: T); -} - -impl Inserter<&str> for FlatVector { - fn insert(&self, index: usize, value: &str) { - let cstr = CString::new(value.as_bytes()).unwrap(); - unsafe { - duckdb_vector_assign_string_element(self.ptr, index as u64, cstr.as_ptr()); - } - } -} - -pub struct ListVector { - /// ListVector does not own the vector pointer. - entries: FlatVector, -} - -impl From for ListVector { - fn from(ptr: duckdb_vector) -> Self { - Self { - entries: FlatVector::from(ptr), - } - } -} - -impl ListVector { - pub fn len(&self) -> usize { - unsafe { duckdb_list_vector_get_size(self.entries.ptr) as usize } - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - // TODO: not ideal interface. Where should we keep capacity. - pub fn child(&self, capacity: usize) -> FlatVector { - self.reserve(capacity); - FlatVector::with_capacity( - unsafe { duckdb_list_vector_get_child(self.entries.ptr) }, - capacity, - ) - } - - /// Set primitive data to the child node. - pub fn set_child(&self, data: &[T]) { - self.child(data.len()).copy(data); - self.set_len(data.len()); - } - - pub fn set_entry(&mut self, idx: usize, offset: usize, length: usize) { - self.entries.as_mut_slice::()[idx].offset = offset as u64; - self.entries.as_mut_slice::()[idx].length = length as u64; - } - - /// Reserve the capacity for its child node. - fn reserve(&self, capacity: usize) { - unsafe { duckdb_list_vector_reserve(self.entries.ptr, capacity as u64); } - } - - pub fn set_len(&self, new_len: usize) { - unsafe { duckdb_list_vector_set_size(self.entries.ptr, new_len as u64); } - } -} - -pub struct StructVector { - /// ListVector does not own the vector pointer. - ptr: duckdb_vector, -} - -impl From for StructVector { - fn from(ptr: duckdb_vector) -> Self { - Self { ptr } - } -} - -impl StructVector { - pub fn child(&self, idx: usize) -> FlatVector { - FlatVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) }) - } - - /// Take the child as [StructVector]. - pub fn struct_vector_child(&self, idx: usize) -> StructVector { - Self::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) }) - } - - pub fn list_vector_child(&self, idx: usize) -> ListVector { - ListVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) }) - } - - /// Get the logical type of this struct vector. - pub fn logical_type(&self) -> LogicalType { - LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) }) - } - - pub fn child_name(&self, idx: usize) -> String { - let logical_type = self.logical_type(); - unsafe { - let child_name_ptr = duckdb_struct_type_child_name(logical_type.ptr, idx as u64); - let c_str = CString::from_raw(child_name_ptr); - let name = c_str.to_str().unwrap(); - // duckdb_free(child_name_ptr.cast()); - name.to_string() - } - } - - pub fn num_children(&self) -> usize { - let logical_type = self.logical_type(); - unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize } - } -} diff --git a/integration/duckdb_lance/src/arrow.rs b/integration/duckdb_lance/src/arrow.rs deleted file mode 100644 index 0c014627465..00000000000 --- a/integration/duckdb_lance/src/arrow.rs +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Arrow / DuckDB conversion. - -use arrow_array::{ - cast::{ - as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, - as_struct_array, - }, - types::*, - Array, ArrowPrimitiveType, BooleanArray, FixedSizeListArray, GenericListArray, OffsetSizeTrait, - PrimitiveArray, RecordBatch, StringArray, StructArray, -}; -use arrow_schema::DataType; -use duckdb_ext::{DataChunk, FlatVector, Inserter, ListVector, StructVector, Vector}; -use duckdb_ext::{LogicalType, LogicalTypeId}; -use lance::arrow::as_fixed_size_list_array; -use num_traits::AsPrimitive; - -use crate::{Error, Result}; - -pub fn to_duckdb_type_id(data_type: &DataType) -> Result { - use LogicalTypeId::*; - - let type_id = match data_type { - DataType::Boolean => Boolean, - DataType::Int8 => Tinyint, - DataType::Int16 => Smallint, - DataType::Int32 => Integer, - DataType::Int64 => Bigint, - DataType::UInt8 => UTinyint, - DataType::UInt16 => USmallint, - DataType::UInt32 => UInteger, - DataType::UInt64 => UBigint, - DataType::Float32 => Float, - DataType::Float64 => Double, - DataType::Timestamp(_, _) => Timestamp, - DataType::Date32 => Time, - DataType::Date64 => Time, - DataType::Time32(_) => Time, - DataType::Time64(_) => Time, - DataType::Duration(_) => Interval, - DataType::Interval(_) => Interval, - DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => Blob, - DataType::Utf8 | DataType::LargeUtf8 => Varchar, - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List, - DataType::Struct(_) => Struct, - DataType::Union(_, _) => Union, - DataType::Dictionary(_, _) => todo!(), - DataType::Decimal128(_, _) => Decimal, - DataType::Decimal256(_, _) => Decimal, - DataType::Map(_, _) => Map, - _ => { - return Err(Error::DuckDB(format!( - "Unsupported arrow type: {data_type}" - ))); - } - }; - Ok(type_id) -} - -pub fn to_duckdb_logical_type(data_type: &DataType) -> Result { - if data_type.is_primitive() - || matches!( - data_type, - DataType::Boolean - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - ) - { - Ok(LogicalType::new(to_duckdb_type_id(data_type)?)) - } else if let DataType::Dictionary(_, value_type) = data_type { - to_duckdb_logical_type(value_type) - } else if let DataType::Struct(fields) = data_type { - let mut shape = vec![]; - for field in fields.iter() { - shape.push(( - field.name().as_str(), - to_duckdb_logical_type(field.data_type())?, - )); - } - Ok(LogicalType::struct_type(shape.as_slice())) - } else if let DataType::List(child) = data_type { - Ok(LogicalType::list_type(&to_duckdb_logical_type( - child.data_type(), - )?)) - } else if let DataType::LargeList(child) = data_type { - Ok(LogicalType::list_type(&to_duckdb_logical_type( - child.data_type(), - )?)) - } else if let DataType::FixedSizeList(child, _) = data_type { - Ok(LogicalType::list_type(&to_duckdb_logical_type( - child.data_type(), - )?)) - } else { - todo!("Unsupported data type: {data_type}, please file an issue at https://github.com/lancedb/lance"); - } -} - -pub fn record_batch_to_duckdb_data_chunk(batch: &RecordBatch, chunk: &mut DataChunk) -> Result<()> { - // Fill the row - assert_eq!(batch.num_columns(), chunk.num_columns()); - for i in 0..batch.num_columns() { - let col = batch.column(i); - match col.data_type() { - dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { - primitive_array_to_vector(col, &mut chunk.flat_vector(i)); - } - DataType::Utf8 => { - string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i)); - } - DataType::List(_) => { - list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i)); - } - DataType::LargeList(_) => { - list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i)); - } - DataType::FixedSizeList(_, _) => { - fixed_size_list_array_to_vector( - as_fixed_size_list_array(col.as_ref()), - &mut chunk.list_vector(i), - ); - } - DataType::Struct(_) => { - let struct_array = as_struct_array(col.as_ref()); - let mut struct_vector = chunk.struct_vector(i); - struct_array_to_vector(struct_array, &mut struct_vector); - } - _ => { - todo!("column {} is not supported yet, please file an issue at https://github.com/lancedb/lance", batch.schema().field(i)); - } - } - } - chunk.set_len(batch.num_rows()); - Ok(()) -} - -fn primitive_array_to_flat_vector( - array: &PrimitiveArray, - out_vector: &mut FlatVector, -) { - // assert!(array.len() <= out_vector.capacity()); - out_vector.copy::(array.values()); -} - -fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { - match array.data_type() { - DataType::Boolean => { - boolean_array_to_vector( - as_boolean_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::UInt8 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::UInt16 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::UInt32 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::UInt64 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Int8 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Int16 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Int32 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Int64 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Float32 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - DataType::Float64 => { - primitive_array_to_flat_vector::( - as_primitive_array(array), - out.as_mut_any().downcast_mut().unwrap(), - ); - } - _ => { - todo!() - } - } -} - -/// Convert Arrow [BooleanArray] to a duckdb vector. -fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) { - assert!(array.len() <= out.capacity()); - - for i in 0..array.len() { - out.as_mut_slice()[i] = array.value(i); - } -} - -fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) { - assert!(array.len() <= out.capacity()); - - // TODO: zero copy assignment - for i in 0..array.len() { - let s = array.value(i); - out.insert(i, s); - } -} - -fn list_array_to_vector>( - array: &GenericListArray, - out: &mut ListVector, -) { - let value_array = array.values(); - let mut child = out.child(value_array.len()); - match value_array.data_type() { - dt if dt.is_primitive() => { - primitive_array_to_vector(value_array.as_ref(), &mut child); - for i in 0..array.len() { - let offset = array.value_offsets()[i]; - let length = array.value_length(i); - out.set_entry(i, offset.as_(), length.as_()); - } - } - _ => { - todo!("Nested list is not supported yet."); - } - } -} - -fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVector) { - let value_array = array.values(); - let mut child = out.child(value_array.len()); - match value_array.data_type() { - dt if dt.is_primitive() => { - primitive_array_to_vector(value_array.as_ref(), &mut child); - for i in 0..array.len() { - let offset = array.value_offset(i); - let length = array.value_length(); - out.set_entry(i, offset as usize, length as usize); - } - out.set_len(value_array.len()); - } - _ => { - todo!("Nested list is not supported yet."); - } - } -} - -fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) { - for i in 0..array.num_columns() { - let column = array.column(i); - match column.data_type() { - dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { - primitive_array_to_vector(column, &mut out.child(i)); - } - DataType::Utf8 => { - string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i)); - } - DataType::List(_) => { - list_array_to_vector( - as_list_array(column.as_ref()), - &mut out.list_vector_child(i), - ); - } - DataType::LargeList(_) => { - list_array_to_vector( - as_large_list_array(column.as_ref()), - &mut out.list_vector_child(i), - ); - } - DataType::FixedSizeList(_, _) => { - fixed_size_list_array_to_vector( - as_fixed_size_list_array(column.as_ref()), - &mut out.list_vector_child(i), - ); - } - DataType::Struct(_) => { - let struct_array = as_struct_array(column.as_ref()); - let mut struct_vector = out.struct_vector_child(i); - struct_array_to_vector(struct_array, &mut struct_vector); - } - _ => { - todo!("Unsupported data type: {}, please file an issue at https://github.com/lancedb/lance", column.data_type()); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow_schema::{Field, Schema}; - - // use libduckdb to link to a duckdb binary. - #[allow(unused_imports)] - use libduckdb_sys; - - #[test] - fn test_record_batch_to_data_chunk() { - let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Boolean, false)])); - - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(BooleanArray::from(vec![true, false, true]))], - ) - .unwrap(); - - let logical_types = schema - .fields - .iter() - .map(|f| to_duckdb_logical_type(f.data_type()).unwrap()) - .collect::>(); - let mut chunk = DataChunk::new(&logical_types); - - record_batch_to_duckdb_data_chunk(&batch, &mut chunk).unwrap(); - assert_eq!(chunk.len(), 3); - let vector = chunk.flat_vector(0); - assert_eq!(LogicalTypeId::Boolean, vector.logical_type().id()); - assert_eq!(vector.as_slice::()[0], true); - assert_eq!(vector.as_slice::()[1], false); - assert_eq!(vector.as_slice::()[2], true); - } -} diff --git a/integration/duckdb_lance/src/error.rs b/integration/duckdb_lance/src/error.rs deleted file mode 100644 index aac3495c112..00000000000 --- a/integration/duckdb_lance/src/error.rs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#[derive(Debug)] -pub enum Error { - DuckDB(String), -} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let (catalog, message) = match self { - Self::DuckDB(s) => ("DuckDB", s.as_str()), - }; - write!(f, "Lance({catalog}): {message}") - } -} - -pub type Result = std::result::Result; - -// TODO: contribute to upstream (duckdb-extension) to have a Error impl. -impl From> for Error { - fn from(value: Box) -> Self { - Self::DuckDB(value.to_string()) - } -} - -impl From for duckdb_ext::Error { - fn from(e: Error) -> Self { - Self::DuckDB(e.to_string()) - } -} - -impl From for Error { - fn from(e: duckdb_ext::Error) -> Self { - Self::DuckDB(e.to_string()) - } -} diff --git a/integration/duckdb_lance/src/extension.c b/integration/duckdb_lance/src/extension.c deleted file mode 100644 index 35e16163699..00000000000 --- a/integration/duckdb_lance/src/extension.c +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2023 Lance Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/// Callbacks for duckdb to load lance (rust) code. - -#include "extension.h" - -const char* lance_version_rust(void); -void lance_init_rust(void* db); - -DUCKDB_EXTENSION_API const char* lance_version() { - return lance_version_rust(); -} - -DUCKDB_EXTENSION_API void lance_init(void* db) { - lance_init_rust(db); -} - diff --git a/integration/duckdb_lance/src/extension.h b/integration/duckdb_lance/src/extension.h deleted file mode 100644 index f58dd56c9e2..00000000000 --- a/integration/duckdb_lance/src/extension.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2023 Lance Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#define DUCKDB_EXTENSION_API - -#include "duckdb.h" diff --git a/integration/duckdb_lance/src/lib.rs b/integration/duckdb_lance/src/lib.rs deleted file mode 100644 index 8af4410ccc8..00000000000 --- a/integration/duckdb_lance/src/lib.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ffi::c_char; - -use duckdb_ext::ffi::{_duckdb_database, duckdb_library_version}; -use duckdb_ext::Database; -use tokio::runtime::Runtime; - -mod arrow; -pub mod error; -mod scan; - -use crate::scan::scan_table_function; -use error::{Error, Result}; - -lazy_static::lazy_static! { - static ref RUNTIME: Runtime = tokio::runtime::Runtime::new() - .expect("Creating Tokio runtime"); -} - -#[no_mangle] -pub extern "C" fn lance_version_rust() -> *const c_char { - unsafe { duckdb_library_version() } -} - -#[no_mangle] -pub unsafe extern "C" fn lance_init_rust(db: *mut _duckdb_database) { - init(db).expect("duckdb lance extension init failed"); -} - -unsafe fn init(db: *mut _duckdb_database) -> Result<()> { - let db = Database::from(db); - let table_function = scan_table_function(); - let connection = db.connect()?; - connection.register_table_function(table_function)?; - Ok(()) -} - -#[cfg(test)] -mod tests {} diff --git a/integration/duckdb_lance/src/scan.rs b/integration/duckdb_lance/src/scan.rs deleted file mode 100644 index 8fe8d407e82..00000000000 --- a/integration/duckdb_lance/src/scan.rs +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2023 Lance Developers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ffi::{c_char, c_void, CStr, CString}; - -use duckdb_ext::ffi::{ - duckdb_bind_info, duckdb_data_chunk, duckdb_free, duckdb_function_info, duckdb_init_info, - duckdb_vector_size, -}; -use duckdb_ext::table_function::{BindInfo, InitInfo, TableFunction}; -use duckdb_ext::{DataChunk, FunctionInfo, LogicalType, LogicalTypeId}; -use futures::StreamExt; -use lance::dataset::scanner::DatasetRecordBatchStream; -use lance::dataset::Dataset; - -use crate::arrow::{record_batch_to_duckdb_data_chunk, to_duckdb_logical_type}; - -#[repr(C)] -struct ScanBindData { - /// Dataset URI - uri: *mut c_char, -} - -impl ScanBindData { - fn new(uri: &str) -> Self { - Self { - uri: CString::new(uri).expect("Bind uri").into_raw(), - } - } -} - -/// Drop the ScanBindData from C. -/// -/// # Safety -unsafe extern "C" fn drop_scan_bind_data_c(v: *mut c_void) { - let actual = v.cast::(); - drop(CString::from_raw((*actual).uri.cast())); - duckdb_free(v); -} - -#[repr(C)] -struct ScanInitData { - stream: *mut DatasetRecordBatchStream, - - done: bool, -} - -impl ScanInitData { - fn new(stream: Box) -> Self { - Self { - stream: Box::into_raw(stream), - done: false, - } - } -} - -#[no_mangle] -unsafe extern "C" fn read_lance(info: duckdb_function_info, output: duckdb_data_chunk) { - let info = FunctionInfo::from(info); - let mut output = DataChunk::from(output); - - let init_data = info.init_data::(); - let batch = match crate::RUNTIME.block_on(async { (*(*init_data).stream).next().await }) { - Some(Ok(b)) => Some(b), - Some(Err(e)) => { - info.set_error(duckdb_ext::Error::DuckDB(e.to_string())); - return; - } - None => None, - }; - - if let Some(b) = batch { - if let Err(e) = record_batch_to_duckdb_data_chunk(&b, &mut output) { - info.set_error(e.into()) - }; - } else { - (*init_data).done = true; - output.set_len(0); - } -} - -#[no_mangle] -unsafe extern "C" fn read_lance_init(info: duckdb_init_info) { - let info = InitInfo::from(info); - let bind_data = info.bind_data::(); - - let uri = CStr::from_ptr((*bind_data).uri); - let dataset = - match crate::RUNTIME.block_on(async { Dataset::open(uri.to_str().unwrap()).await }) { - Ok(d) => Box::new(d), - Err(e) => { - info.set_error(duckdb_ext::Error::DuckDB(e.to_string())); - return; - } - }; - let projected_columns = info.projected_column_ids(); - let columns = projected_columns - .iter() - .map(|proj_id| dataset.schema().fields[*proj_id].name.as_str()) - .collect::>(); - - let stream = match crate::RUNTIME.block_on(async { - dataset - .scan() - .project(columns.as_slice()) - .unwrap() - .batch_size(duckdb_vector_size() as usize) - .try_into_stream() - .await - }) { - Ok(s) => Box::new(s), - Err(e) => { - info.set_error(duckdb_ext::Error::DuckDB(e.to_string())); - return; - } - }; - - let init_data = Box::new(ScanInitData::new(stream)); - info.set_init_data(Box::into_raw(init_data).cast(), Some(duckdb_free)); -} - -#[no_mangle] -unsafe extern "C" fn read_lance_bind_c(bind_info: duckdb_bind_info) { - let bind_info = BindInfo::from(bind_info); - assert!(bind_info.num_parameters() > 0); - - read_lance_bind(&bind_info); -} - -fn read_lance_bind(bind: &BindInfo) { - let uri = bind.parameter(0).to_string(); - let dataset = match crate::RUNTIME.block_on(async { Dataset::open(&uri).await }) { - Ok(d) => d, - Err(e) => { - bind.set_error(duckdb_ext::Error::DuckDB(e.to_string())); - return; - } - }; - - let schema = dataset.schema(); - for field in schema.fields.iter() { - bind.add_result_column( - &field.name, - to_duckdb_logical_type(&field.data_type()).unwrap(), - ); - } - - let bind_data = Box::new(ScanBindData::new(&uri)); - bind.set_bind_data(Box::into_raw(bind_data).cast(), Some(drop_scan_bind_data_c)); -} - -pub fn scan_table_function() -> TableFunction { - let table_function = TableFunction::new("lance_scan"); - let logical_type = LogicalType::new(LogicalTypeId::Varchar); - table_function.add_parameter(&logical_type); - - table_function.set_function(Some(read_lance)); - table_function.set_init(Some(read_lance_init)); - table_function.set_bind(Some(read_lance_bind_c)); - table_function.pushdown(true); - // TODO: add filter push down. - table_function -} diff --git a/java/.gitignore b/java/.gitignore index b3925a2ff8d..43ba4ff2778 100644 --- a/java/.gitignore +++ b/java/.gitignore @@ -1,2 +1,3 @@ *.iml -spark/dependency-reduced-pom.xml \ No newline at end of file +spark/dependency-reduced-pom.xml +.java-version diff --git a/java/.mvn/wrapper/maven-wrapper.properties b/java/.mvn/wrapper/maven-wrapper.properties new file mode 100644 index 00000000000..d58dfb70bab --- /dev/null +++ b/java/.mvn/wrapper/maven-wrapper.properties @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +wrapperVersion=3.3.2 +distributionType=only-script +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.9/apache-maven-3.9.9-bin.zip diff --git a/java/.scalafmt.conf b/java/.scalafmt.conf new file mode 100644 index 00000000000..844652cf00e --- /dev/null +++ b/java/.scalafmt.conf @@ -0,0 +1,28 @@ +version = 3.7.5 +runner.dialect=scala212 +project.git=true + +align.preset = none +align.openParenDefnSite = false +align.openParenCallSite = false +align.stripMargin = true +align.tokens = [] +assumeStandardLibraryStripMargin = true +danglingParentheses.preset = false +docstrings.style = Asterisk +docstrings.wrap = no +importSelectors = singleLine +indent.extendSite = 2 +literals.hexDigits = Upper +maxColumn = 100 +newlines.source = keep +newlines.topLevelStatementBlankLines = [] +optIn.configStyleArguments = false +rewrite.imports.groups = [ + ["com\\.lancedb\\.lance\\..*"], + ["(?!com\\.lancedb\\.lance\\.).*"], + ["javax?\\..*"], + ["scala\\..*"], +] +rewrite.imports.sort = scalastyle +rewrite.rules = [Imports, SortModifiers] diff --git a/java/README.md b/java/README.md new file mode 100644 index 00000000000..68088a54a00 --- /dev/null +++ b/java/README.md @@ -0,0 +1,241 @@ +# Java bindings and SDK for Lance Data Format + +> :warning: **Under heavy development** + +
+

+ +Lance Logo + +Lance is a new columnar data format for data science and machine learning +

+ +Why you should use Lance +1. It is an order of magnitude faster than Parquet for point queries and nested data structures common to DS/ML +2. It comes with a fast vector index that delivers sub-millisecond nearest neighbor search performance +3. It is automatically versioned and supports lineage and time-travel for full reproducibility +4. It is integrated with duckdb/pandas/polars already. Easily convert from/to Parquet in 2 lines of code + +## Quick start + +Introduce the Lance SDK Java Maven dependency(It is recommended to choose the latest version.): + +```shell + + com.lancedb + lance-core + 0.18.0 + +``` + +### Basic I/O + +* create empty dataset + +```java +void createDataset() throws IOException, URISyntaxException { + String datasetPath = tempDir.resolve("write_stream").toString(); + Schema schema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8())), + null); + try (BufferAllocator allocator = new RootAllocator();) { + Dataset.create(allocator, datasetPath, schema, new WriteParams.Builder().build()); + try (Dataset dataset = Dataset.create(allocator, datasetPath, schema, new WriteParams.Builder().build());) { + dataset.version(); + dataset.latestVersion(); + } + } +} +``` + +* create and write a Lance dataset + +```java +void createAndWriteDataset() throws IOException, URISyntaxException { + Path path = ""; // the original source path + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator(); + ArrowFileReader reader = + new ArrowFileReader( + new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel(Files.readAllBytes(path))), allocator); + ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, reader, arrowStream); + try (Dataset dataset = + Dataset.create( + allocator, + arrowStream, + datasetPath, + new WriteParams.Builder() + .withMaxRowsPerFile(10) + .withMaxRowsPerGroup(20) + .withMode(WriteParams.WriteMode.CREATE) + .withStorageOptions(new HashMap<>()) + .build())) { + // access dataset + } + } +} +``` +* read dataset + +```java +void readDataset() { + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator()) { + try (Dataset dataset = Dataset.open(datasetPath, allocator)) { + dataset.countRows(); + dataset.getSchema(); + dataset.version(); + dataset.latestVersion(); + // access more information + } + } +} +``` + +* drop dataset + +```java +void dropDataset() { + String datasetPath = tempDir.resolve("drop_stream").toString(); + Dataset.drop(datasetPath, new HashMap<>()); +} +``` + +### Random Access + +```java +void randomAccess() { + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator()) { + try (Dataset dataset = Dataset.open(datasetPath, allocator)) { + List indices = Arrays.asList(1L, 4L); + List columns = Arrays.asList("id", "name"); + try (ArrowReader reader = dataset.take(indices, columns)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + result.getRowCount(); + + for (int i = 0; i < indices.size(); i++) { + result.getVector("id").getObject(i); + result.getVector("name").getObject(i); + } + } + } + } + } +} +``` + +### Schema evolution + +* add columns + +```java +void addColumns() { + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator()) { + try (Dataset dataset = Dataset.open(datasetPath, allocator)) { + SqlExpressions sqlExpressions = new SqlExpressions.Builder().withExpression("double_id", "id * 2").build(); + dataset.addColumns(sqlExpressions, Optional.empty()); + } + } +} +``` + +* alter columns + +```java +void alterColumns() { + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator()) { + try (Dataset dataset = Dataset.open(datasetPath, allocator)) { + ColumnAlteration nameColumnAlteration = + new ColumnAlteration.Builder("name") + .rename("new_name") + .nullable(true) + .castTo(new ArrowType.Utf8()) + .build(); + + dataset.alterColumns(Collections.singletonList(nameColumnAlteration)); + } + } +} +``` + +* drop columns + +```java +void dropColumns() { + String datasetPath = ""; // specify a path point to a dataset + try (BufferAllocator allocator = new RootAllocator()) { + try (Dataset dataset = Dataset.open(datasetPath, allocator)) { + dataset.dropColumns(Collections.singletonList("name")); + } + } +} +``` + +## Integrations + +This section introduces the ecosystem integration with Lance format. +With the integration, users are able to access lance dataset with other technology or tools. + +### Spark connector + +The [spark](https://github.com/lancedb/lance/tree/main/java/spark) module is a standard maven module. +It is the implementation of spark-lance connector that allows Apache Spark to efficiently access datasets stored in Lance format. +More details please see the [README](https://github.com/lancedb/lance/blob/main/java/spark/README.md) file. + +## Contributing + +From the codebase dimension, the lance project is a multiple-lang project. All Java-related code is located in the `java` directory. +And the whole `java` dir is a standard maven project(named `lance-parent`) can be imported into any IDEs support java project. + +Overall, it contains two Maven sub-modules: + +* lance-core: the core module of Lance Java binding, including `lance-jni`. +* lance-spark: the spark connector module. + +To build the project, you can run the following command: + +```shell +mvn clean package +``` + +if you only want to build rust code(`lance-jni`), you can run the following command: + +```shell +cargo build +``` + +The java module uses `spotless` maven plugin to format the code and check the license header. +And it is applied in the `validate` phase automatically. + +### Environment(IDE) setup + +Firstly, clone the repository into your local machine: + +```shell +git clone https://github.com/lancedb/lance.git +``` + +Then, import the `java` directory into your favorite IDEs, such as IntelliJ IDEA, Eclipse, etc. + +Due to the java module depends on the features provided by rust module. So, you also need to make sure you have installed rust in your local. + +To install rust, please refer to the [official documentation](https://www.rust-lang.org/tools/install). + +And you also need to install the rust plugin for your IDE. + +Then, you can build the whole java module: + +```shell +mvn clean package +``` + +Running these commands, it builds the rust jni binding codes automatically. diff --git a/java/core/lance-jni/Cargo.toml b/java/core/lance-jni/Cargo.toml index 7e49bb9ff7f..a26eee044a3 100644 --- a/java/core/lance-jni/Cargo.toml +++ b/java/core/lance-jni/Cargo.toml @@ -19,9 +19,12 @@ lance-encoding = { workspace = true } lance-linalg = { workspace = true } lance-index = { workspace = true } lance-io.workspace = true +lance-core.workspace = true +lance-file.workspace = true arrow = { workspace = true, features = ["ffi"] } arrow-schema.workspace = true datafusion.workspace = true +object_store.workspace = true tokio.workspace = true jni = "0.21.1" snafu.workspace = true diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index e8c2c94cf39..3a573620625 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -14,30 +14,37 @@ use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; -use crate::traits::FromJString; -use crate::utils::{extract_write_params, get_index_params}; +use crate::traits::{export_vec, import_vec, FromJObjectWithEnv, FromJString}; +use crate::utils::{extract_storage_options, extract_write_params, get_index_params}; use crate::{traits::IntoJava, RT}; use arrow::array::RecordBatchReader; use arrow::datatypes::Schema; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::ArrowArrayStreamReader; use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::ipc::writer::StreamWriter; use arrow::record_batch::RecordBatchIterator; +use arrow_schema::DataType; use jni::objects::{JMap, JString, JValue}; -use jni::sys::jlong; use jni::sys::{jboolean, jint}; +use jni::sys::{jbyteArray, jlong}; use jni::{objects::JObject, JNIEnv}; use lance::dataset::builder::DatasetBuilder; +use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt}; use lance::dataset::transaction::Operation; -use lance::dataset::{Dataset, ReadParams, WriteParams}; -use lance::io::ObjectStoreParams; +use lance::dataset::{ + ColumnAlteration, Dataset, NewColumnTransform, ProjectionRequest, ReadParams, WriteParams, +}; +use lance::io::{ObjectStore, ObjectStoreParams}; use lance::table::format::Fragment; use lance::table::format::Index; +use lance_core::datatypes::Schema as LanceSchema; use lance_index::DatasetIndexExt; use lance_index::{IndexParams, IndexType}; use lance_io::object_store::ObjectStoreRegistry; use std::collections::HashMap; use std::iter::empty; +use std::str::FromStr; use std::sync::Arc; pub const NATIVE_DATASET: &str = "nativeDatasetHandle"; @@ -48,6 +55,23 @@ pub struct BlockingDataset { } impl BlockingDataset { + pub fn drop(uri: &str, storage_options: HashMap) -> Result<()> { + RT.block_on(async move { + let registry = Arc::new(ObjectStoreRegistry::default()); + let object_store_params = ObjectStoreParams { + storage_options: Some(storage_options.clone()), + ..Default::default() + }; + let (object_store, path) = + ObjectStore::from_uri_and_params(registry, uri, &object_store_params) + .await + .map_err(|e| Error::io_error(e.to_string()))?; + object_store + .remove_dir_all(path) + .await + .map_err(|e| Error::io_error(e.to_string())) + }) + } pub fn write( reader: impl RecordBatchReader + Send + 'static, uri: &str, @@ -92,7 +116,6 @@ impl BlockingDataset { read_version: Option, storage_options: HashMap, ) -> Result { - let object_store_registry = Arc::new(ObjectStoreRegistry::default()); let inner = RT.block_on(Dataset::commit( uri, operation, @@ -102,7 +125,7 @@ impl BlockingDataset { ..Default::default() }), None, - object_store_registry, + Default::default(), false, // TODO: support enable_v2_manifest_paths ))?; Ok(Self { inner }) @@ -133,6 +156,11 @@ impl BlockingDataset { Ok(rows) } + pub fn calculate_data_stats(&self) -> Result { + let stats = RT.block_on(Arc::new(self.clone().inner).calculate_data_stats())?; + Ok(stats) + } + pub fn list_indexes(&self) -> Result>> { let indexes = RT.block_on(self.inner.load_indices())?; Ok(indexes) @@ -199,6 +227,20 @@ fn inner_create_with_ffi_schema<'local>( ) } +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_drop<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + path: JString<'local>, + storage_options_obj: JObject<'local>, +) -> JObject<'local> { + let path_str = ok_or_throw!(env, path.extract(&mut env)); + let storage_options = + ok_or_throw!(env, extract_storage_options(&mut env, &storage_options_obj)); + ok_or_throw!(env, BlockingDataset::drop(&path_str, storage_options)); + JObject::null() +} + #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiStream<'local>( mut env: JNIEnv<'local>, @@ -312,7 +354,7 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_commitAppend<'local>( _obj: JObject, path: JString, read_version_obj: JObject, // Optional - fragments_obj: JObject, // List, String is json serialized Fragment + fragments_obj: JObject, // List storage_options_obj: JObject, // Map ) -> JObject<'local> { ok_or_throw!( @@ -331,18 +373,70 @@ pub fn inner_commit_append<'local>( env: &mut JNIEnv<'local>, path: JString, read_version_obj: JObject, // Optional - fragments_obj: JObject, // List, String is json serialized Fragment) + fragment_objs: JObject, // List storage_options_obj: JObject, // Map ) -> Result> { - let json_fragments = env.get_strings(&fragments_obj)?; - let mut fragments: Vec = Vec::new(); - for json_fragment in json_fragments { - let fragment = Fragment::from_json(&json_fragment)?; - fragments.push(fragment); + let fragment_objs = import_vec(env, &fragment_objs)?; + let mut fragments = Vec::with_capacity(fragment_objs.len()); + for f in fragment_objs { + fragments.push(f.extract_object(env)?); } let op = Operation::Append { fragments }; let path_str = path.extract(env)?; let read_version = env.get_u64_opt(&read_version_obj)?; + let storage_options = extract_storage_options(env, &storage_options_obj)?; + let dataset = BlockingDataset::commit(&path_str, op, read_version, storage_options)?; + dataset.into_java(env) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_commitOverwrite<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + path: JString, + arrow_schema_addr: jlong, + read_version_obj: JObject, // Optional + fragments_obj: JObject, // List + storage_options_obj: JObject, // Map +) -> JObject<'local> { + ok_or_throw!( + env, + inner_commit_overwrite( + &mut env, + path, + arrow_schema_addr, + read_version_obj, + fragments_obj, + storage_options_obj + ) + ) +} + +pub fn inner_commit_overwrite<'local>( + env: &mut JNIEnv<'local>, + path: JString, + arrow_schema_addr: jlong, + read_version_obj: JObject, // Optional + fragments_obj: JObject, // List + storage_options_obj: JObject, // Map +) -> Result> { + let fragment_objs = import_vec(env, &fragments_obj)?; + let mut fragments = Vec::with_capacity(fragment_objs.len()); + for f in fragment_objs { + fragments.push(f.extract_object(env)?); + } + let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; + let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; + let arrow_schema = Schema::try_from(&c_schema)?; + let schema = LanceSchema::try_from(&arrow_schema)?; + + let op = Operation::Overwrite { + fragments, + schema, + config_upsert_values: None, + }; + let path_str = path.extract(env)?; + let read_version = env.get_u64_opt(&read_version_obj)?; let jmap = JMap::from_env(env, &storage_options_obj)?; let storage_options: HashMap = env.with_local_frame(16, |env| { let mut map = HashMap::new(); @@ -356,6 +450,7 @@ pub fn inner_commit_append<'local>( } Ok::<_, Error>(map) })?; + let dataset = BlockingDataset::commit(&path_str, op, read_version, storage_options)?; dataset.into_java(env) } @@ -486,14 +581,14 @@ fn inner_open_native<'local>( } #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_getJsonFragments<'a>( +pub extern "system" fn Java_com_lancedb_lance_Dataset_getFragmentsNative<'a>( mut env: JNIEnv<'a>, jdataset: JObject, ) -> JObject<'a> { - ok_or_throw!(env, inner_get_json_fragments(&mut env, jdataset)) + ok_or_throw!(env, inner_get_fragments(&mut env, jdataset)) } -fn inner_get_json_fragments<'local>( +fn inner_get_fragments<'local>( env: &mut JNIEnv<'local>, jdataset: JObject, ) -> Result> { @@ -502,22 +597,37 @@ fn inner_get_json_fragments<'local>( unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; dataset.inner.get_fragments() }; + let fragments = fragments + .iter() + .map(|f| f.metadata().clone()) + .collect::>(); + export_vec(env, &fragments) +} - let array_list_class = env.find_class("java/util/ArrayList")?; - - let array_list = env.new_object(array_list_class, "()V", &[])?; +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_getFragmentNative<'a>( + mut env: JNIEnv<'a>, + jdataset: JObject, + fragment_id: jint, +) -> JObject<'a> { + ok_or_throw!(env, inner_get_fragment(&mut env, jdataset, fragment_id)) +} - for fragment in fragments { - let json_string = serde_json::to_string(fragment.metadata())?; - let jstring = env.new_string(json_string)?; - env.call_method( - &array_list, - "add", - "(Ljava/lang/Object;)Z", - &[(&jstring).into()], - )?; - } - Ok(array_list) +fn inner_get_fragment<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject, + fragment_id: jint, +) -> Result> { + let fragment = { + let dataset = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; + dataset.inner.get_fragment(fragment_id as usize) + }; + let obj = match fragment { + Some(f) => f.metadata().into_java(env)?, + None => JObject::default(), + }; + Ok(obj) } #[no_mangle] @@ -548,6 +658,29 @@ fn inner_import_ffi_schema( Ok(()) } +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeUri<'local>( + mut env: JNIEnv<'local>, + java_dataset: JObject, +) -> JString<'local> { + ok_or_throw_with_return!( + env, + inner_uri(&mut env, java_dataset).map_err(|err| Error::input_error(err.to_string())), + JString::from(JObject::null()) + ) +} + +fn inner_uri<'local>(env: &mut JNIEnv<'local>, java_dataset: JObject) -> Result> { + let uri = { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + dataset_guard.inner.uri().to_string() + }; + + let jstring_uri = env.new_string(uri)?; + Ok(jstring_uri) +} + #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeVersion( mut env: JNIEnv, @@ -580,14 +713,61 @@ fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> Result pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeCountRows( mut env: JNIEnv, java_dataset: JObject, -) -> jint { - ok_or_throw_with_return!(env, inner_count_rows(&mut env, java_dataset), -1) as jint + filter_jobj: JObject, // Optional +) -> jlong { + ok_or_throw_with_return!( + env, + inner_count_rows(&mut env, java_dataset, filter_jobj), + -1 + ) as jlong } -fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> Result { +fn inner_count_rows( + env: &mut JNIEnv, + java_dataset: JObject, + filter_jobj: JObject, +) -> Result { + let filter = env.get_string_opt(&filter_jobj)?; let dataset_guard = unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; - dataset_guard.count_rows(None) + dataset_guard.count_rows(filter) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeGetDataStatistics<'local>( + mut env: JNIEnv<'local>, + java_dataset: JObject, +) -> JObject<'local> { + ok_or_throw!(env, inner_get_data_statistics(&mut env, java_dataset)) +} + +fn inner_get_data_statistics<'local>( + env: &mut JNIEnv<'local>, + java_dataset: JObject, +) -> Result> { + let stats = { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + dataset_guard.calculate_data_stats()? + }; + let data_stats = env.new_object("com/lancedb/lance/ipc/DataStatistics", "()V", &[])?; + + for field in stats.fields { + let id = field.id as jint; + let byte_size = field.bytes_on_disk as jlong; + let filed_jobj = env.new_object( + "com/lancedb/lance/ipc/FieldStatistics", + "(IJ)V", + &[JValue::Int(id), JValue::Long(byte_size)], + )?; + env.call_method( + &data_stats, + "addFiledStatistics", + "(Lcom/lancedb/lance/ipc/FieldStatistics;)V", + &[JValue::Object(&filed_jobj)], + )?; + } + Ok(data_stats) } #[no_mangle] @@ -626,3 +806,317 @@ fn inner_list_indexes<'local>( Ok(array_list) } + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake( + mut env: JNIEnv, + java_dataset: JObject, + indices_obj: JObject, // List + columns_obj: JObject, // List +) -> jbyteArray { + match inner_take(&mut env, java_dataset, indices_obj, columns_obj) { + Ok(byte_array) => byte_array, + Err(e) => { + let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", e)); + std::ptr::null_mut() + } + } +} + +fn inner_take( + env: &mut JNIEnv, + java_dataset: JObject, + indices_obj: JObject, // List + columns_obj: JObject, // List +) -> Result { + let indices: Vec = env.get_longs(&indices_obj)?; + let indices_u64: Vec = indices.iter().map(|&x| x as u64).collect(); + let indices_slice: &[u64] = &indices_u64; + let columns: Vec = env.get_strings(&columns_obj)?; + + let result = { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + let dataset = &dataset_guard.inner; + + let projection = ProjectionRequest::from_columns(columns, dataset.schema()); + + match RT.block_on(dataset.take(indices_slice, projection)) { + Ok(res) => res, + Err(e) => { + return Err(e.into()); + } + } + }; + + let mut buffer = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?; + writer.write(&result)?; + writer.finish()?; + } + + let byte_array = env.byte_array_from_slice(&buffer)?; + Ok(**byte_array) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeDelete( + mut env: JNIEnv, + java_dataset: JObject, + predicate: JString, +) { + ok_or_throw_without_return!(env, inner_delete(&mut env, java_dataset, predicate)) +} + +fn inner_delete(env: &mut JNIEnv, java_dataset: JObject, predicate: JString) -> Result<()> { + let predicate_str = predicate.extract(env)?; + let mut dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + RT.block_on(dataset_guard.inner.delete(&predicate_str))?; + Ok(()) +} + +////////////////////////////// +// Schema evolution Methods // +////////////////////////////// +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeDropColumns( + mut env: JNIEnv, + java_dataset: JObject, + columns_obj: JObject, // List +) { + ok_or_throw_without_return!(env, inner_drop_columns(&mut env, java_dataset, columns_obj)) +} + +fn inner_drop_columns( + env: &mut JNIEnv, + java_dataset: JObject, + columns_obj: JObject, // List +) -> Result<()> { + let columns: Vec = env.get_strings(&columns_obj)?; + let columns_slice: Vec<&str> = columns.iter().map(AsRef::as_ref).collect(); + let mut dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + RT.block_on(dataset_guard.inner.drop_columns(&columns_slice))?; + Ok(()) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAlterColumns( + mut env: JNIEnv, + java_dataset: JObject, + column_alterations_obj: JObject, // List +) { + ok_or_throw_without_return!( + env, + inner_alter_columns(&mut env, java_dataset, column_alterations_obj) + ) +} + +fn create_column_alteration( + env: &mut JNIEnv, + column_alteration_jobj: JObject, // ColumnAlteration +) -> Result { + let path_obj = env + .get_field(&column_alteration_jobj, "path", "Ljava/lang/String;")? + .l()?; + let path_jstring: JString = path_obj.into(); + let path: String = env.get_string(&path_jstring)?.into(); + + let rename_obj = env + .get_field(&column_alteration_jobj, "rename", "Ljava/util/Optional;")? + .l()?; + let rename = if env.call_method(&rename_obj, "isPresent", "()Z", &[])?.z()? { + let jstring: JObject = env + .call_method(rename_obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + let jstring: JString = jstring.into(); + let rename_str: String = env.get_string(&jstring)?.into(); // Intermediate variable + Some(rename_str) + } else { + None + }; + + let nullable_obj = env + .get_field(&column_alteration_jobj, "nullable", "Ljava/util/Optional;")? + .l()?; + let nullable = if env + .call_method(&nullable_obj, "isPresent", "()Z", &[])? + .z()? + { + let nullable_value = env + .call_method(nullable_obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + Some( + env.call_method(nullable_value, "booleanValue", "()Z", &[])? + .z()?, + ) + } else { + None + }; + + let data_type_obj = env + .get_field(&column_alteration_jobj, "dataType", "Ljava/util/Optional;")? + .l()?; + let data_type = if env + .call_method(&data_type_obj, "isPresent", "()Z", &[])? + .z()? + { + let j_data_type: JObject = env + .call_method(data_type_obj, "get", "()Ljava/lang/Object;", &[])? + .l()?; + let jstring: JString = env + .call_method(j_data_type, "toString", "()Ljava/lang/String;", &[])? + .l()? + .into(); + let data_type_str: String = env.get_string(&jstring)?.into(); // Intermediate variable + DataType::from_str(&data_type_str) + .map_err(|e| Error::input_error(e.to_string())) + .ok() + } else { + None + }; + + Ok(ColumnAlteration { + path, + rename, + nullable, + data_type, + }) +} + +fn inner_alter_columns( + env: &mut JNIEnv, + java_dataset: JObject, + column_alterations_obj: JObject, // List +) -> Result<()> { + let list = env.get_list(&column_alterations_obj)?; + let mut iter = list.iter(env)?; + let mut column_alterations = Vec::new(); + + while let Some(elem) = iter.next(env)? { + let alteration = create_column_alteration(env, elem)?; + column_alterations.push(alteration); + } + + let mut dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + + RT.block_on(dataset_guard.inner.alter_columns(&column_alterations))?; + Ok(()) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAddColumnsBySqlExpressions( + mut env: JNIEnv, + java_dataset: JObject, + sql_expressions: JObject, // SqlExpressions + batch_size: JObject, // Optional +) { + ok_or_throw_without_return!( + env, + inner_add_columns_by_sql_expressions(&mut env, java_dataset, sql_expressions, batch_size) + ) +} + +fn inner_add_columns_by_sql_expressions( + env: &mut JNIEnv, + java_dataset: JObject, + sql_expressions: JObject, // SqlExpressions + batch_size: JObject, // Optional +) -> Result<()> { + let sql_expressions_obj = env + .get_field(sql_expressions, "sqlExpressions", "Ljava/util/List;")? + .l()?; + + let sql_expressions_obj_list = env.get_list(&sql_expressions_obj)?; + let mut expressions: Vec<(String, String)> = Vec::new(); + + let mut iterator = sql_expressions_obj_list.iter(env)?; + + while let Some(item) = iterator.next(env)? { + let name = env + .call_method(&item, "getName", "()Ljava/lang/String;", &[])? + .l()?; + let value = env + .call_method(&item, "getExpression", "()Ljava/lang/String;", &[])? + .l()?; + let key_str: String = env.get_string(&JString::from(name))?.into(); + let value_str: String = env.get_string(&JString::from(value))?.into(); + expressions.push((key_str, value_str)); + } + + let rust_transform = NewColumnTransform::SqlExpressions(expressions); + + let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? { + let batch_size_value = env.get_long_opt(&batch_size)?; + match batch_size_value { + Some(value) => Some( + value + .try_into() + .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, + ), + None => None, + } + } else { + None + }; + + let mut dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + + RT.block_on( + dataset_guard + .inner + .add_columns(rust_transform, None, batch_size), + )?; + Ok(()) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAddColumnsByReader( + mut env: JNIEnv, + java_dataset: JObject, + arrow_array_stream_addr: jlong, + batch_size: JObject, // Optional +) { + ok_or_throw_without_return!( + env, + inner_add_columns_by_reader(&mut env, java_dataset, arrow_array_stream_addr, batch_size) + ) +} + +fn inner_add_columns_by_reader( + env: &mut JNIEnv, + java_dataset: JObject, + arrow_array_stream_addr: jlong, + batch_size: JObject, // Optional +) -> Result<()> { + let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; + + let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }?; + + let transform = NewColumnTransform::Reader(Box::new(reader)); + + let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? { + let batch_size_value = env.get_long_opt(&batch_size)?; + match batch_size_value { + Some(value) => Some( + value + .try_into() + .map_err(|_| Error::input_error("Batch size conversion error".to_string()))?, + ), + None => None, + } + } else { + None + }; + + let mut dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + + RT.block_on(dataset_guard.inner.add_columns(transform, None, batch_size))?; + + Ok(()) +} diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 8a3168e161c..fd1db069feb 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -22,7 +22,7 @@ use arrow_schema::SchemaRef; use jni::objects::{JObject, JString}; use jni::sys::{jboolean, jint, JNI_TRUE}; use jni::{sys::jlong, JNIEnv}; -use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; +use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner}; use lance_io::ffi::to_ffi_arrow_array_stream; use lance_linalg::distance::DistanceType; @@ -79,7 +79,9 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_createScanner<'lo offset_obj: JObject, // Optional query_obj: JObject, // Optional with_row_id: jboolean, // boolean + with_row_address: jboolean, // boolean batch_readahead: jint, // int + column_orderings: JObject, // Optional> ) -> JObject<'local> { ok_or_throw!( env, @@ -95,7 +97,9 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_createScanner<'lo offset_obj, query_obj, with_row_id, - batch_readahead + with_row_address, + batch_readahead, + column_orderings ) ) } @@ -113,7 +117,9 @@ fn inner_create_scanner<'local>( offset_obj: JObject, query_obj: JObject, with_row_id: jboolean, + with_row_address: jboolean, batch_readahead: jint, + column_orderings: JObject, ) -> Result> { let fragment_ids_opt = env.get_ints_opt(&fragment_ids_obj)?; let dataset_guard = @@ -166,6 +172,10 @@ fn inner_create_scanner<'local>( scanner.with_row_id(); } + if with_row_address == JNI_TRUE { + scanner.with_row_address(); + } + let query_is_present = env.call_method(&query_obj, "isPresent", "()Z", &[])?.z()?; if query_is_present { @@ -205,6 +215,32 @@ fn inner_create_scanner<'local>( scanner.use_index(use_index); } scanner.batch_readahead(batch_readahead as usize); + + let column_orders_is_present = env + .call_method(&column_orderings, "isPresent", "()Z", &[])? + .z()?; + if column_orders_is_present { + let java_obj = env + .call_method(&column_orderings, "get", "()Ljava/lang/Object;", &[])? + .l()?; + + let list = env.get_list(&java_obj)?; + let mut iter = list.iter(env)?; + let mut results = Vec::with_capacity(list.size(env)? as usize); + while let Some(elem) = iter.next(env)? { + let column_name = env.get_string_from_method(&elem, "getColumnName")?; + let nulls_first = env.get_boolean_from_method(&elem, "isNullFirst")?; + let ascending = env.get_boolean_from_method(&elem, "isAscending")?; + let col_order = ColumnOrdering { + ascending, + nulls_first, + column_name, + }; + results.push(col_order) + } + scanner.order_by(Some(results))?; + } + let scanner = BlockingScanner::create(scanner); scanner.into_java(env) } diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index 36f47ffb566..05454c6111b 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -19,12 +19,13 @@ use jni::{errors::Error as JniError, JNIEnv}; use lance::error::Error as LanceError; use serde_json::Error as JsonError; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum JavaExceptionClass { IllegalArgumentException, IOException, RuntimeException, UnsupportedOperationException, + AlreadyInException, } impl JavaExceptionClass { @@ -34,6 +35,8 @@ impl JavaExceptionClass { Self::IOException => "java/io/IOException", Self::RuntimeException => "java/lang/RuntimeException", Self::UnsupportedOperationException => "java/lang/UnsupportedOperationException", + // Included for display purposes. This is not a real exception. + Self::AlreadyInException => "AlreadyInException", } } } @@ -71,7 +74,18 @@ impl Error { Self::new(message, JavaExceptionClass::UnsupportedOperationException) } + pub fn in_exception() -> Self { + Self { + message: String::default(), + java_class: JavaExceptionClass::AlreadyInException, + } + } + pub fn throw(&self, env: &mut JNIEnv) { + if self.java_class == JavaExceptionClass::AlreadyInException { + // An exception is already in progress, so we don't need to throw another one. + return; + } if let Err(e) = env.throw_new(self.java_class.as_str(), &self.message) { eprintln!("Error when throwing Java exception: {:?}", e.to_string()); panic!("Error when throwing Java exception: {:?}", e); @@ -96,6 +110,7 @@ impl From for Error { | LanceError::InvalidInput { .. } => Self::input_error(err.to_string()), LanceError::IO { .. } => Self::io_error(err.to_string()), LanceError::NotSupported { .. } => Self::unsupported_error(err.to_string()), + LanceError::NotFound { .. } => Self::io_error(err.to_string()), _ => Self::runtime_error(err.to_string()), } } @@ -120,7 +135,12 @@ impl From for Error { impl From for Error { fn from(err: JniError) -> Self { - Self::runtime_error(err.to_string()) + match err { + // If we get this then it means that an exception was already in progress. We can't + // throw another one so we just return an error indicating that. + JniError::JavaException => Self::in_exception(), + _ => Self::runtime_error(err.to_string()), + } } } diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs index dd11a1ee382..f92d3ec8735 100644 --- a/java/core/lance-jni/src/ffi.rs +++ b/java/core/lance-jni/src/ffi.rs @@ -26,6 +26,9 @@ pub trait JNIEnvExt { /// Get integers from Java List object. fn get_integers(&mut self, obj: &JObject) -> Result>; + /// Get longs from Java List object. + fn get_longs(&mut self, obj: &JObject) -> Result>; + /// Get strings from Java List object. fn get_strings(&mut self, obj: &JObject) -> Result>; @@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> { Ok(results) } + fn get_longs(&mut self, obj: &JObject) -> Result> { + let list = self.get_list(obj)?; + let mut iter = list.iter(self)?; + let mut results = Vec::with_capacity(list.size(self)? as usize); + while let Some(elem) = iter.next(self)? { + let long_obj = self.call_method(elem, "longValue", "()J", &[])?; + let long_value = long_obj.j()?; + results.push(long_value); + } + Ok(results) + } + fn get_strings(&mut self, obj: &JObject) -> Result> { let list = self.get_list(obj)?; let mut iter = list.iter(self)?; @@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts( ok_or_throw_without_return!(env, env.get_integers(&list_obj)); } +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs( + mut env: JNIEnv, + _obj: JObject, + list_obj: JObject, // List +) { + ok_or_throw_without_return!(env, env.get_longs(&list_obj)); +} + #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt( mut env: JNIEnv, diff --git a/java/core/lance-jni/src/file_reader.rs b/java/core/lance-jni/src/file_reader.rs new file mode 100644 index 00000000000..cf9d301d961 --- /dev/null +++ b/java/core/lance-jni/src/file_reader.rs @@ -0,0 +1,201 @@ +use std::sync::Arc; + +use crate::{ + error::{Error, Result}, + traits::IntoJava, + RT, +}; +use arrow::{array::RecordBatchReader, ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; +use arrow_schema::SchemaRef; +use jni::{ + objects::{JObject, JString}, + sys::{jint, jlong}, + JNIEnv, +}; +use lance::io::ObjectStore; +use lance_core::cache::FileMetadataCache; +use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; +use lance_file::v2::reader::{FileReader, FileReaderOptions}; +use lance_io::{ + scheduler::{ScanScheduler, SchedulerConfig}, + ReadBatchParams, +}; +use object_store::path::Path; + +pub const NATIVE_READER: &str = "nativeFileReaderHandle"; + +#[derive(Clone, Debug)] +pub struct BlockingFileReader { + pub(crate) inner: Arc, +} + +impl BlockingFileReader { + pub fn create(file_reader: Arc) -> Self { + Self { inner: file_reader } + } + + pub fn open_stream( + &self, + batch_size: u32, + ) -> Result> { + Ok(self.inner.read_stream_projected_blocking( + ReadBatchParams::RangeFull, + batch_size, + None, + FilterExpression::no_filter(), + )?) + } + + pub fn schema(&self) -> Result { + Ok(Arc::new(self.inner.schema().as_ref().into())) + } + + pub fn num_rows(&self) -> u64 { + self.inner.num_rows() + } +} + +impl IntoJava for BlockingFileReader { + fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> Result> { + attach_native_reader(env, self) + } +} + +fn attach_native_reader<'local>( + env: &mut JNIEnv<'local>, + reader: BlockingFileReader, +) -> Result> { + let j_reader = create_java_reader_object(env)?; + unsafe { env.set_rust_field(&j_reader, NATIVE_READER, reader) }?; + Ok(j_reader) +} + +fn create_java_reader_object<'a>(env: &mut JNIEnv<'a>) -> Result> { + let res = env.new_object("com/lancedb/lance/file/LanceFileReader", "()V", &[])?; + Ok(res) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileReader_openNative<'local>( + mut env: JNIEnv<'local>, + _reader_class: JObject, + file_uri: JString, +) -> JObject<'local> { + ok_or_throw!(env, inner_open(&mut env, file_uri,)) +} + +fn inner_open<'local>(env: &mut JNIEnv<'local>, file_uri: JString) -> Result> { + let file_uri_str: String = env.get_string(&file_uri)?.into(); + + let reader = RT.block_on(async move { + let (obj_store, path) = ObjectStore::from_uri(&file_uri_str).await?; + let config = SchedulerConfig::max_bandwidth(&obj_store); + let scan_scheduler = ScanScheduler::new(obj_store, config); + + let file_scheduler = scan_scheduler.open_file(&Path::parse(&path)?).await?; + FileReader::try_open( + file_scheduler, + None, + Arc::::default(), + &FileMetadataCache::no_cache(), + FileReaderOptions::default(), + ) + .await + })?; + + let reader = BlockingFileReader::create(Arc::new(reader)); + + reader.into_java(env) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileReader_closeNative<'local>( + mut env: JNIEnv<'local>, + reader: JObject, +) -> JObject<'local> { + let maybe_err = + unsafe { env.take_rust_field::<_, _, BlockingFileReader>(reader, NATIVE_READER) }; + match maybe_err { + Ok(_) => {} + // We were already closed, do nothing + Err(jni::errors::Error::NullPtr(_)) => {} + Err(err) => Error::from(err).throw(&mut env), + } + JObject::null() +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileReader_numRowsNative( + mut env: JNIEnv<'_>, + reader: JObject, +) -> jlong { + match inner_num_rows(&mut env, reader) { + Ok(num_rows) => num_rows, + Err(e) => { + e.throw(&mut env); + 0 + } + } +} + +// If the reader is closed, the native handle will be null and we will get a JniError::NullPtr +// error when we call get_rust_field. Translate that into a more meaningful error. +fn unwrap_reader(val: std::result::Result) -> Result { + match val { + Ok(val) => Ok(val), + Err(jni::errors::Error::NullPtr(_)) => Err(Error::io_error( + "FileReader has already been closed".to_string(), + )), + err => Ok(err?), + } +} + +fn inner_num_rows(env: &mut JNIEnv<'_>, reader: JObject) -> Result { + let reader = unsafe { env.get_rust_field::<_, _, BlockingFileReader>(reader, NATIVE_READER) }; + let reader = unwrap_reader(reader)?; + Ok(reader.num_rows() as i64) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileReader_populateSchemaNative( + mut env: JNIEnv, + reader: JObject, + schema_addr: jlong, +) { + ok_or_throw_without_return!(env, inner_populate_schema(&mut env, reader, schema_addr)); +} + +fn inner_populate_schema(env: &mut JNIEnv, reader: JObject, schema_addr: jlong) -> Result<()> { + let reader = unsafe { env.get_rust_field::<_, _, BlockingFileReader>(reader, NATIVE_READER) }; + let reader = unwrap_reader(reader)?; + let schema = reader.schema()?; + let ffi_schema = FFI_ArrowSchema::try_from(schema.as_ref())?; + unsafe { std::ptr::write_unaligned(schema_addr as *mut FFI_ArrowSchema, ffi_schema) } + Ok(()) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileReader_readAllNative( + mut env: JNIEnv<'_>, + reader: JObject, + batch_size: jint, + stream_addr: jlong, +) { + if let Err(e) = inner_read_all(&mut env, reader, batch_size, stream_addr) { + e.throw(&mut env); + } +} + +fn inner_read_all( + env: &mut JNIEnv<'_>, + reader: JObject, + batch_size: jint, + stream_addr: jlong, +) -> Result<()> { + let reader = unsafe { env.get_rust_field::<_, _, BlockingFileReader>(reader, NATIVE_READER) }; + let reader = unwrap_reader(reader)?; + let arrow_stream = reader.open_stream(batch_size as u32)?; + let ffi_stream = FFI_ArrowArrayStream::new(arrow_stream); + unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } + Ok(()) +} diff --git a/java/core/lance-jni/src/file_writer.rs b/java/core/lance-jni/src/file_writer.rs new file mode 100644 index 00000000000..98f6218c91a --- /dev/null +++ b/java/core/lance-jni/src/file_writer.rs @@ -0,0 +1,152 @@ +use std::sync::{Arc, Mutex}; + +use crate::{ + error::{Error, Result}, + traits::IntoJava, + RT, +}; +use arrow::{ + array::{RecordBatch, StructArray}, + ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}, +}; +use arrow_schema::DataType; +use jni::{ + objects::{JObject, JString}, + sys::jlong, + JNIEnv, +}; +use lance::io::ObjectStore; +use lance_file::{ + v2::writer::{FileWriter, FileWriterOptions}, + version::LanceFileVersion, +}; + +pub const NATIVE_WRITER: &str = "nativeFileWriterHandle"; + +#[derive(Clone)] +pub struct BlockingFileWriter { + pub(crate) inner: Arc>, +} + +impl BlockingFileWriter { + pub fn create(file_writer: FileWriter) -> Self { + Self { + inner: Arc::new(Mutex::new(file_writer)), + } + } +} + +impl IntoJava for BlockingFileWriter { + fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> Result> { + attach_native_writer(env, self) + } +} + +fn attach_native_writer<'local>( + env: &mut JNIEnv<'local>, + writer: BlockingFileWriter, +) -> Result> { + let j_writer = create_java_writer_object(env)?; + unsafe { env.set_rust_field(&j_writer, NATIVE_WRITER, writer) }?; + Ok(j_writer) +} + +fn create_java_writer_object<'a>(env: &mut JNIEnv<'a>) -> Result> { + let res = env.new_object("com/lancedb/lance/file/LanceFileWriter", "()V", &[])?; + Ok(res) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileWriter_openNative<'local>( + mut env: JNIEnv<'local>, + _writer_class: JObject, + file_uri: JString, +) -> JObject<'local> { + ok_or_throw!(env, inner_open(&mut env, file_uri,)) +} + +fn inner_open<'local>(env: &mut JNIEnv<'local>, file_uri: JString) -> Result> { + let file_uri_str: String = env.get_string(&file_uri)?.into(); + + let writer = RT.block_on(async move { + let (obj_store, path) = ObjectStore::from_uri(&file_uri_str).await?; + let obj_store = Arc::new(obj_store); + let obj_writer = obj_store.create(&path).await?; + + Result::Ok(FileWriter::new_lazy( + obj_writer, + FileWriterOptions { + format_version: Some(LanceFileVersion::V2_1), + ..Default::default() + }, + )) + })?; + + let writer = BlockingFileWriter::create(writer); + + writer.into_java(env) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileWriter_closeNative<'local>( + mut env: JNIEnv<'local>, + writer: JObject, +) -> JObject<'local> { + let maybe_err = + unsafe { env.take_rust_field::<_, _, BlockingFileWriter>(writer, NATIVE_WRITER) }; + let writer = match maybe_err { + Ok(writer) => Some(writer), + // We were already closed, do nothing + Err(jni::errors::Error::NullPtr(_)) => None, + Err(err) => { + Error::from(err).throw(&mut env); + None + } + }; + if let Some(writer) = writer { + match RT.block_on(writer.inner.lock().unwrap().finish()) { + Ok(_) => {} + Err(e) => { + Error::from(e).throw(&mut env); + } + } + } + JObject::null() +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_file_LanceFileWriter_writeNative<'local>( + mut env: JNIEnv<'local>, + writer: JObject, + batch_address: jlong, + schema_address: jlong, +) -> JObject<'local> { + if let Err(e) = inner_write_batch(&mut env, writer, batch_address, schema_address) { + e.throw(&mut env); + return JObject::null(); + } + JObject::null() +} + +fn inner_write_batch( + env: &mut JNIEnv<'_>, + writer: JObject, + batch_address: jlong, + schema_address: jlong, +) -> Result<()> { + let c_array_ptr = batch_address as *mut FFI_ArrowArray; + let c_schema_ptr = schema_address as *mut FFI_ArrowSchema; + + let c_array = unsafe { FFI_ArrowArray::from_raw(c_array_ptr) }; + let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; + + let data_type = DataType::try_from(&c_schema)?; + let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }?; + let record_batch = RecordBatch::from(StructArray::from(array_data)); + + let writer = unsafe { env.get_rust_field::<_, _, BlockingFileWriter>(writer, NATIVE_WRITER) }?; + + let mut writer = writer.inner.lock().unwrap(); + RT.block_on(writer.write_batch(&record_batch))?; + Ok(()) +} diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index 66182b2d442..a0b05dd141b 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -16,20 +16,22 @@ use arrow::array::{RecordBatch, RecordBatchIterator, StructArray}; use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; use arrow_schema::DataType; +use jni::objects::{JIntArray, JValueGen}; use jni::{ objects::{JObject, JString}, sys::{jint, jlong}, JNIEnv, }; +use lance::table::format::{DataFile, DeletionFile, DeletionFileType, Fragment, RowIdMeta}; use std::iter::once; use lance::dataset::fragment::FileFragment; use lance_datafusion::utils::StreamingWriteSource; use crate::error::{Error, Result}; +use crate::traits::{export_vec, import_vec, FromJObjectWithEnv, IntoJava, JLance}; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - ffi::JNIEnvExt, traits::FromJString, utils::extract_write_params, RT, @@ -39,7 +41,7 @@ use crate::{ // Read Methods // ////////////////// #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_DatasetFragment_countRowsNative( +pub extern "system" fn Java_com_lancedb_lance_Fragment_countRowsNative( mut env: JNIEnv, _jfragment: JObject, jdataset: JObject, @@ -63,7 +65,7 @@ fn inner_count_rows_native( "Fragment not found: {fragment_id}" ))); }; - let res = RT.block_on(fragment.count_rows())?; + let res = RT.block_on(fragment.count_rows(None))?; Ok(res) } @@ -77,13 +79,12 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local dataset_uri: JString, arrow_array_addr: jlong, arrow_schema_addr: jlong, - fragment_id: JObject, // Optional max_rows_per_file: JObject, // Optional max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional storage_options_obj: JObject, // Map -) -> JString<'local> { +) -> JObject<'local> { ok_or_throw_with_return!( env, inner_create_with_ffi_array( @@ -91,14 +92,13 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local dataset_uri, arrow_array_addr, arrow_schema_addr, - fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode, storage_options_obj ), - JString::default() + JObject::default() ) } @@ -108,13 +108,12 @@ fn inner_create_with_ffi_array<'local>( dataset_uri: JString, arrow_array_addr: jlong, arrow_schema_addr: jlong, - fragment_id: JObject, // Optional max_rows_per_file: JObject, // Optional max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional storage_options_obj: JObject, // Map -) -> Result> { +) -> Result> { let c_array_ptr = arrow_array_addr as *mut FFI_ArrowArray; let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; @@ -131,7 +130,6 @@ fn inner_create_with_ffi_array<'local>( create_fragment( env, dataset_uri, - fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, @@ -147,27 +145,25 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>( _obj: JObject, dataset_uri: JString, arrow_array_stream_addr: jlong, - fragment_id: JObject, // Optional max_rows_per_file: JObject, // Optional max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional storage_options_obj: JObject, // Map -) -> JString<'a> { +) -> JObject<'a> { ok_or_throw_with_return!( env, inner_create_with_ffi_stream( &mut env, dataset_uri, arrow_array_stream_addr, - fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode, storage_options_obj ), - JString::default() + JObject::null() ) } @@ -176,20 +172,18 @@ fn inner_create_with_ffi_stream<'local>( env: &mut JNIEnv<'local>, dataset_uri: JString, arrow_array_stream_addr: jlong, - fragment_id: JObject, // Optional max_rows_per_file: JObject, // Optional max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional storage_options_obj: JObject, // Map -) -> Result> { +) -> Result> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }?; create_fragment( env, dataset_uri, - fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, @@ -203,18 +197,15 @@ fn inner_create_with_ffi_stream<'local>( fn create_fragment<'a>( env: &mut JNIEnv<'a>, dataset_uri: JString, - fragment_id: JObject, // Optional max_rows_per_file: JObject, // Optional max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional storage_options_obj: JObject, // Map source: impl StreamingWriteSource, -) -> Result> { +) -> Result> { let path_str = dataset_uri.extract(env)?; - let fragment_id_opts = env.get_int_opt(&fragment_id)?; - let write_params = extract_write_params( env, &max_rows_per_file, @@ -223,13 +214,251 @@ fn create_fragment<'a>( &mode, &storage_options_obj, )?; - let fragment = RT.block_on(FileFragment::create( + let fragments = RT.block_on(FileFragment::create_fragments( &path_str, - fragment_id_opts.unwrap_or(0) as usize, source, Some(write_params), ))?; - let json_string = serde_json::to_string(&fragment)?; - let res = env.new_string(json_string)?; - Ok(res) + export_vec(env, &fragments) +} + +const DATA_FILE_CLASS: &str = "com/lancedb/lance/fragment/DataFile"; +const DATA_FILE_CONSTRUCTOR_SIG: &str = "(Ljava/lang/String;[I[III)V"; +const DELETE_FILE_CLASS: &str = "com/lancedb/lance/fragment/DeletionFile"; +const DELETE_FILE_CONSTRUCTOR_SIG: &str = + "(JJLjava/lang/Long;Lcom/lancedb/lance/fragment/DeletionFileType;)V"; +const DELETE_FILE_TYPE_CLASS: &str = "com/lancedb/lance/fragment/DeletionFileType"; +const FRAGMENT_METADATA_CLASS: &str = "com/lancedb/lance/FragmentMetadata"; +const FRAGMENT_METADATA_CONSTRUCTOR_SIG: &str ="(ILjava/util/List;Ljava/lang/Long;Lcom/lancedb/lance/fragment/DeletionFile;Lcom/lancedb/lance/fragment/RowIdMeta;)V"; +const ROW_ID_META_CLASS: &str = "com/lancedb/lance/fragment/RowIdMeta"; +const ROW_ID_META_CONSTRUCTOR_SIG: &str = "(Ljava/lang/String;)V"; + +impl IntoJava for &DataFile { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let path = env.new_string(self.path.clone())?.into(); + let fields = JLance(self.fields.clone()).into_java(env)?; + let column_indices = JLance(self.column_indices.clone()).into_java(env)?; + Ok(env.new_object( + DATA_FILE_CLASS, + DATA_FILE_CONSTRUCTOR_SIG, + &[ + JValueGen::Object(&path), + JValueGen::Object(&fields), + JValueGen::Object(&column_indices), + JValueGen::Int(self.file_major_version as i32), + JValueGen::Int(self.file_minor_version as i32), + ], + )?) + } +} + +impl IntoJava for &DeletionFileType { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let name = match self { + lance::table::format::DeletionFileType::Array => "ARRAY", + lance::table::format::DeletionFileType::Bitmap => "BITMAP", + }; + env.get_static_field( + DELETE_FILE_TYPE_CLASS, + name, + format!("L{};", DELETE_FILE_TYPE_CLASS), + )? + .l() + .map_err(|e| { + Error::runtime_error(format!("failed to get {}: {}", DELETE_FILE_TYPE_CLASS, e)) + }) + } +} + +impl IntoJava for &DeletionFile { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let num_deleted_rows = match self.num_deleted_rows { + Some(f) => JLance(f).into_java(env)?, + None => JObject::null(), + }; + let file_type = self.file_type.into_java(env)?; + Ok(env.new_object( + DELETE_FILE_CLASS, + DELETE_FILE_CONSTRUCTOR_SIG, + &[ + JValueGen::Long(self.id as i64), + JValueGen::Long(self.read_version as i64), + JValueGen::Object(&num_deleted_rows), + JValueGen::Object(&file_type), + ], + )?) + } +} + +impl IntoJava for &RowIdMeta { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let json_str = serde_json::to_string(self)?; + let json = env.new_string(json_str)?.into(); + Ok(env.new_object( + ROW_ID_META_CLASS, + ROW_ID_META_CONSTRUCTOR_SIG, + &[JValueGen::Object(&json)], + )?) + } +} + +impl IntoJava for &Fragment { + fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> Result> { + let files = self.files.clone(); + let files = export_vec::(env, &files)?; + let deletion_file = match &self.deletion_file { + Some(f) => f.into_java(env)?, + None => JObject::null(), + }; + let physical_rows = &JLance(self.physical_rows).into_java(env)?; + let row_id_meta = match &self.row_id_meta { + Some(m) => m.into_java(env)?, + None => JObject::null(), + }; + + env.new_object( + FRAGMENT_METADATA_CLASS, + FRAGMENT_METADATA_CONSTRUCTOR_SIG, + &[ + JValueGen::Int(self.id as i32), + JValueGen::Object(&files), + JValueGen::Object(physical_rows), + JValueGen::Object(&deletion_file), + JValueGen::Object(&row_id_meta), + ], + ) + .map_err(|e| { + Error::runtime_error(format!("failed to get {}: {}", FRAGMENT_METADATA_CLASS, e)) + }) + } +} + +impl FromJObjectWithEnv for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result { + let metadata = env + .call_method(self, "getMetadata", "()Ljava/lang/String;", &[])? + .l()?; + let s: String = env.get_string(&JString::from(metadata))?.into(); + let meta: RowIdMeta = serde_json::from_str(&s)?; + Ok(meta) + } +} + +impl FromJObjectWithEnv for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result { + let id = env.call_method(self, "getId", "()I", &[])?.i()? as u64; + let file_objs = env + .call_method(self, "getFiles", "()Ljava/util/List;", &[])? + .l()?; + let physical_rows = env.call_method(self, "getPhysicalRows", "()J", &[])?.j()? as usize; + let file_objs = import_vec(env, &file_objs)?; + let mut files = Vec::with_capacity(file_objs.len()); + for f in file_objs { + files.push(f.extract_object(env)?); + } + let deletion_file = env + .call_method( + self, + "getDeletionFile", + format!("()L{};", DELETE_FILE_CLASS), + &[], + )? + .l()?; + let deletion_file = if deletion_file.is_null() { + None + } else { + Some(deletion_file.extract_object(env)?) + }; + + let row_id_meta = env + .call_method( + self, + "getRowIdMeta", + format!("()L{};", ROW_ID_META_CLASS), + &[], + )? + .l()?; + let row_id_meta = if row_id_meta.is_null() { + None + } else { + Some(row_id_meta.extract_object(env)?) + }; + Ok(Fragment { + id, + files, + deletion_file, + physical_rows: Some(physical_rows), + row_id_meta, + }) + } +} + +impl FromJObjectWithEnv for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result { + let id = env.call_method(self, "getId", "()J", &[])?.j()? as u64; + let read_version = env.call_method(self, "getReadVersion", "()J", &[])?.j()? as u64; + let num_deleted_rows: Option = env + .call_method(self, "getNumDeletedRows", "()Ljava/lang/Long;", &[])? + .l()? + .extract_object(env)?; + let num_deleted_rows = num_deleted_rows.map(|r| r as usize); + let file_type: DeletionFileType = env + .call_method( + self, + "getFileType", + format!("()L{};", DELETE_FILE_TYPE_CLASS), + &[], + )? + .l()? + .extract_object(env)?; + Ok(DeletionFile { + read_version, + id, + num_deleted_rows, + file_type, + }) + } +} + +impl FromJObjectWithEnv for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result { + let s = env + .call_method(self, "toString", "()Ljava.lang.String;", &[])? + .l()?; + let s: String = env.get_string(&JString::from(s))?.into(); + let t = if s == "ARRAY" { + DeletionFileType::Array + } else { + DeletionFileType::Bitmap + }; + Ok(t) + } +} + +impl FromJObjectWithEnv for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result { + let path = env + .call_method(self, "getPath", "()Ljava/lang/String;", &[])? + .l()?; + let path: String = env.get_string(&JString::from(path))?.into(); + let fields = env.call_method(self, "getFields", "()[I", &[])?.l()?; + let fields = JIntArray::from(fields).extract_object(env)?; + let column_indices = env + .call_method(self, "getColumnIndices", "()[I", &[])? + .l()?; + let column_indices = JIntArray::from(column_indices).extract_object(env)?; + let file_major_version = env + .call_method(self, "getFileMajorVersion", "()I", &[])? + .i()? as u32; + let file_minor_version = env + .call_method(self, "getFileMinorVersion", "()I", &[])? + .i()? as u32; + Ok(DataFile { + path, + fields, + column_indices, + file_major_version, + file_minor_version, + }) + } } diff --git a/java/core/lance-jni/src/lib.rs b/java/core/lance-jni/src/lib.rs index 84b7ba64972..437c3d4c00d 100644 --- a/java/core/lance-jni/src/lib.rs +++ b/java/core/lance-jni/src/lib.rs @@ -54,6 +54,8 @@ mod blocking_dataset; mod blocking_scanner; pub mod error; pub mod ffi; +mod file_reader; +mod file_writer; mod fragment; pub mod traits; pub mod utils; diff --git a/java/core/lance-jni/src/traits.rs b/java/core/lance-jni/src/traits.rs index d91b449b1c9..d4e1f80f193 100644 --- a/java/core/lance-jni/src/traits.rs +++ b/java/core/lance-jni/src/traits.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use jni::objects::{JMap, JObject, JString, JValue}; +use jni::objects::{JIntArray, JMap, JObject, JString, JValue, JValueGen}; use jni::JNIEnv; use crate::error::Result; @@ -21,6 +21,10 @@ pub trait FromJObject { fn extract(&self) -> Result; } +pub trait FromJObjectWithEnv { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result; +} + /// Convert a Rust type into a Java Object. pub trait IntoJava { fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result>; @@ -124,3 +128,78 @@ impl JMapExt for JMap<'_, '_, '_> { get_map_value(env, self, key) } } + +pub fn export_vec<'a, 'b, T>(env: &mut JNIEnv<'a>, vec: &'b [T]) -> Result> +where + &'b T: IntoJava, +{ + let array_list_class = env.find_class("java/util/ArrayList")?; + let array_list = env.new_object(array_list_class, "()V", &[])?; + for e in vec { + let obj = &e.into_java(env)?; + env.call_method( + &array_list, + "add", + "(Ljava/lang/Object;)Z", + &[JValueGen::Object(obj)], + )?; + } + Ok(array_list) +} + +pub fn import_vec<'local>(env: &mut JNIEnv<'local>, obj: &JObject) -> Result>> { + let size = env.call_method(obj, "size", "()I", &[])?.i()?; + let mut ret = Vec::with_capacity(size as usize); + for i in 0..size { + let elem = env.call_method(obj, "get", "(I)Ljava/lang/Object;", &[JValueGen::Int(i)])?; + ret.push(elem.l()?); + } + Ok(ret) +} + +pub struct JLance(pub T); + +impl IntoJava for JLance> { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let arr = env.new_int_array(self.0.len() as i32)?; + env.set_int_array_region(&arr, 0, &self.0)?; + Ok(arr.into()) + } +} + +impl IntoJava for JLance { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + Ok(env.new_object("java/lang/Long", "(J)V", &[JValueGen::Long(self.0 as i64)])?) + } +} + +impl IntoJava for JLance> { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { + let obj = match self.0 { + Some(v) => env.new_object("java/lang/Long", "(J)V", &[JValueGen::Long(v as i64)])?, + None => JObject::null(), + }; + Ok(obj) + } +} + +impl FromJObjectWithEnv> for JObject<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result> { + let ret = if self.is_null() { + None + } else { + let v = env.call_method(self, "longValue", "()J", &[])?.j()?; + Some(v) + }; + Ok(ret) + } +} + +impl FromJObjectWithEnv> for JIntArray<'_> { + fn extract_object(&self, env: &mut JNIEnv<'_>) -> Result> { + let len = env.get_array_length(self)?; + let mut ret: Vec = vec![0; len as usize]; + env.get_int_array_region(self, 0, ret.as_mut_slice())?; + Ok(ret) + } +} diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 6b15d4d58b2..73da3355f80 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -34,6 +34,26 @@ use crate::ffi::JNIEnvExt; use lance_index::vector::Query; use std::collections::HashMap; +pub fn extract_storage_options( + env: &mut JNIEnv, + storage_options_obj: &JObject, +) -> Result> { + let jmap = JMap::from_env(env, storage_options_obj)?; + let storage_options: HashMap = env.with_local_frame(16, |env| { + let mut map = HashMap::new(); + let mut iter = jmap.iter(env)?; + while let Some((key, value)) = iter.next(env)? { + let key_jstring = JString::from(key); + let value_jstring = JString::from(value); + let key_string: String = env.get_string(&key_jstring)?.into(); + let value_string: String = env.get_string(&value_jstring)?.into(); + map.insert(key_string, value_string); + } + Ok::<_, Error>(map) + })?; + Ok(storage_options) +} + pub fn extract_write_params( env: &mut JNIEnv, max_rows_per_file: &JObject, @@ -56,21 +76,10 @@ pub fn extract_write_params( if let Some(mode_val) = env.get_string_opt(mode)? { write_params.mode = WriteMode::try_from(mode_val.as_str())?; } - // Java code always sets the data storage version to Legacy for now - write_params.data_storage_version = Some(LanceFileVersion::Legacy); - let jmap = JMap::from_env(env, storage_options_obj)?; - let storage_options: HashMap = env.with_local_frame(16, |env| { - let mut map = HashMap::new(); - let mut iter = jmap.iter(env)?; - while let Some((key, value)) = iter.next(env)? { - let key_jstring = JString::from(key); - let value_jstring = JString::from(value); - let key_string: String = env.get_string(&key_jstring)?.into(); - let value_string: String = env.get_string(&value_jstring)?.into(); - map.insert(key_string, value_string); - } - Ok::<_, Error>(map) - })?; + // Java code always sets the data storage version to stable for now + write_params.data_storage_version = Some(LanceFileVersion::Stable); + let storage_options: HashMap = + extract_storage_options(env, storage_options_obj)?; write_params.store_params = Some(ObjectStoreParams { storage_options: Some(storage_options), @@ -109,6 +118,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> column, key, k, + lower_bound: None, + upper_bound: None, nprobes, ef, refine_factor, @@ -178,7 +189,6 @@ pub fn get_index_params( env.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionBatches")?; let shuffle_partition_concurrency = env.get_int_as_usize_from_method(&ivf_params_obj, "getShufflePartitionConcurrency")?; - let use_residual = env.get_boolean_from_method(&ivf_params_obj, "useResidual")?; let ivf_params = IvfBuildParams { num_partitions, @@ -186,7 +196,6 @@ pub fn get_index_params( sample_rate, shuffle_partition_batches, shuffle_partition_concurrency, - use_residual, ..Default::default() }; stages.push(StageParams::Ivf(ivf_params)); diff --git a/java/core/pom.xml b/java/core/pom.xml index 5c51d872c70..6bd45059faf 100644 --- a/java/core/pom.xml +++ b/java/core/pom.xml @@ -8,7 +8,7 @@ com.lancedb lance-parent - 0.20.0 + 0.26.2 ../pom.xml @@ -40,10 +40,6 @@ org.apache.commons commons-lang3 - - org.json - json - org.questdb jar-jni diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index bdd61afbe13..364e9542871 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -1,35 +1,47 @@ /* - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ - package com.lancedb.lance; import com.lancedb.lance.index.IndexParams; import com.lancedb.lance.index.IndexType; +import com.lancedb.lance.ipc.DataStatistics; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; -import java.io.Closeable; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; +import com.lancedb.lance.schema.ColumnAlteration; +import com.lancedb.lance.schema.SqlExpressions; + import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.pojo.Schema; +import java.io.ByteArrayInputStream; +import java.io.Closeable; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + /** * Class representing a Lance dataset, interfacing with the native lance library. This class * provides functionality to open and manage datasets with native code. The native library is loaded @@ -59,8 +71,8 @@ private Dataset() {} * @param params write params * @return Dataset */ - public static Dataset create(BufferAllocator allocator, String path, Schema schema, - WriteParams params) { + public static Dataset create( + BufferAllocator allocator, String path, Schema schema, WriteParams params) { Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(path); Preconditions.checkNotNull(schema); @@ -68,8 +80,13 @@ public static Dataset create(BufferAllocator allocator, String path, Schema sche try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator)) { Data.exportSchema(allocator, schema, null, arrowSchema); Dataset dataset = - createWithFfiSchema(arrowSchema.memoryAddress(), path, params.getMaxRowsPerFile(), - params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode(), + createWithFfiSchema( + arrowSchema.memoryAddress(), + path, + params.getMaxRowsPerFile(), + params.getMaxRowsPerGroup(), + params.getMaxBytesPerFile(), + params.getMode(), params.getStorageOptions()); dataset.allocator = allocator; return dataset; @@ -85,26 +102,42 @@ public static Dataset create(BufferAllocator allocator, String path, Schema sche * @param params write parameters * @return Dataset */ - public static Dataset create(BufferAllocator allocator, ArrowArrayStream stream, String path, - WriteParams params) { + public static Dataset create( + BufferAllocator allocator, ArrowArrayStream stream, String path, WriteParams params) { Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(stream); Preconditions.checkNotNull(path); Preconditions.checkNotNull(params); - Dataset dataset = createWithFfiStream(stream.memoryAddress(), path, params.getMaxRowsPerFile(), - params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode(), - params.getStorageOptions()); + Dataset dataset = + createWithFfiStream( + stream.memoryAddress(), + path, + params.getMaxRowsPerFile(), + params.getMaxRowsPerGroup(), + params.getMaxBytesPerFile(), + params.getMode(), + params.getStorageOptions()); dataset.allocator = allocator; return dataset; } - private static native Dataset createWithFfiSchema(long arrowSchemaMemoryAddress, String path, - Optional maxRowsPerFile, Optional maxRowsPerGroup, - Optional maxBytesPerFile, Optional mode, Map storageOptions); + private static native Dataset createWithFfiSchema( + long arrowSchemaMemoryAddress, + String path, + Optional maxRowsPerFile, + Optional maxRowsPerGroup, + Optional maxBytesPerFile, + Optional mode, + Map storageOptions); - private static native Dataset createWithFfiStream(long arrowStreamMemoryAddress, String path, - Optional maxRowsPerFile, Optional maxRowsPerGroup, - Optional maxBytesPerFile, Optional mode, Map storageOptions); + private static native Dataset createWithFfiStream( + long arrowStreamMemoryAddress, + String path, + Optional maxRowsPerFile, + Optional maxRowsPerGroup, + Optional maxBytesPerFile, + Optional mode, + Map storageOptions); /** * Open a dataset from the specified path. @@ -157,20 +190,30 @@ public static Dataset open(BufferAllocator allocator, String path, ReadOptions o * @param options the open options * @return Dataset */ - private static Dataset open(BufferAllocator allocator, boolean selfManagedAllocator, String path, - ReadOptions options) { + private static Dataset open( + BufferAllocator allocator, boolean selfManagedAllocator, String path, ReadOptions options) { Preconditions.checkNotNull(path); Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(options); - Dataset dataset = openNative(path, options.getVersion(), options.getBlockSize(), - options.getIndexCacheSize(), options.getMetadataCacheSize(), options.getStorageOptions()); + Dataset dataset = + openNative( + path, + options.getVersion(), + options.getBlockSize(), + options.getIndexCacheSize(), + options.getMetadataCacheSize(), + options.getStorageOptions()); dataset.allocator = allocator; dataset.selfManagedAllocator = selfManagedAllocator; return dataset; } - private static native Dataset openNative(String path, Optional version, - Optional blockSize, int indexCacheSize, int metadataCacheSize, + private static native Dataset openNative( + String path, + Optional version, + Optional blockSize, + int indexCacheSize, + int metadataCacheSize, Map storageOptions); /** @@ -180,16 +223,23 @@ private static native Dataset openNative(String path, Optional version, * @param path The file path of the dataset to open. * @param operation The operation to apply to the dataset. * @param readVersion The version of the dataset that was used as the base for the changes. This - * is not needed for overwrite or restore operations. + * is not needed for overwrite or restore operations. * @return A new instance of {@link Dataset} linked to the opened dataset. */ - public static Dataset commit(BufferAllocator allocator, String path, FragmentOperation operation, - Optional readVersion) { + public static Dataset commit( + BufferAllocator allocator, + String path, + FragmentOperation operation, + Optional readVersion) { return commit(allocator, path, operation, readVersion, new HashMap<>()); } - public static Dataset commit(BufferAllocator allocator, String path, FragmentOperation operation, - Optional readVersion, Map storageOptions) { + public static Dataset commit( + BufferAllocator allocator, + String path, + FragmentOperation operation, + Optional readVersion, + Map storageOptions) { Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(path); Preconditions.checkNotNull(operation); @@ -199,8 +249,88 @@ public static Dataset commit(BufferAllocator allocator, String path, FragmentOpe return dataset; } - public static native Dataset commitAppend(String path, Optional readVersion, - List fragmentsMetadata, Map storageOptions); + public static native Dataset commitAppend( + String path, + Optional readVersion, + List fragmentsMetadata, + Map storageOptions); + + public static native Dataset commitOverwrite( + String path, + long arrowSchemaMemoryAddress, + Optional readVersion, + List fragmentsMetadata, + Map storageOptions); + + /** + * Drop a Dataset. + * + * @param path The file path of the dataset + * @param storageOptions Storage options + */ + public static native void drop(String path, Map storageOptions); + + /** + * Add columns to the dataset. + * + * @param sqlExpressions The SQL expressions to add columns + * @param batchSize The number of rows to read at a time from the source dataset when applying the + * transform. + */ + public void addColumns(SqlExpressions sqlExpressions, Optional batchSize) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + nativeAddColumnsBySqlExpressions(sqlExpressions, batchSize); + } + } + + private native void nativeAddColumnsBySqlExpressions( + SqlExpressions sqlExpressions, Optional batchSize); + + /** + * Add columns to the dataset. + * + * @param stream The Arrow Array Stream generated by arrow reader to add columns. + * @param batchSize The number of rows to read at a time from the source dataset when applying the + * transform. + */ + public void addColumns(ArrowArrayStream stream, Optional batchSize) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + nativeAddColumnsByReader(stream.memoryAddress(), batchSize); + } + } + + private native void nativeAddColumnsByReader( + long arrowStreamMemoryAddress, Optional batchSize); + + /** + * Drop columns from the dataset. + * + * @param columns The columns to drop + */ + public void dropColumns(List columns) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + nativeDropColumns(columns); + } + } + + private native void nativeDropColumns(List columns); + + /** + * Alter columns in the dataset. + * + * @param columnAlterations The list of columns need to be altered. + */ + public void alterColumns(List columnAlterations) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + nativeAlterColumns(columnAlterations); + } + } + + private native void nativeAlterColumns(List columnAlterations); /** * Create a new Dataset Scanner. @@ -235,6 +365,60 @@ public LanceScanner newScan(ScanOptions options) { } } + /** + * Select rows of data by index. + * + * @param indices the indices to take + * @param columns the columns to take + * @return an ArrowReader + */ + public ArrowReader take(List indices, List columns) throws IOException { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + byte[] arrowData = nativeTake(indices, columns); + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData); + ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream); + return new ArrowStreamReader(readChannel, allocator) { + @Override + public void close() throws IOException { + super.close(); + readChannel.close(); + byteArrayInputStream.close(); + } + }; + } + } + + private native byte[] nativeTake(List indices, List columns); + + /** + * Delete rows of data by predicate. + * + * @param predicate the predicate to delete + */ + public void delete(String predicate) { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + nativeDelete(predicate); + } + } + + private native void nativeDelete(String predicate); + + /** + * Gets the URI of the dataset. + * + * @return the URI of the dataset + */ + public String uri() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeUri(); + } + } + + private native String nativeUri(); + /** * Gets the currently checked out version of the dataset. * @@ -249,9 +433,7 @@ public long version() { private native long nativeVersion(); - /** - * @return the latest version of the dataset. - */ + /** @return the latest version of the dataset. */ public long latestVersion() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); @@ -270,50 +452,92 @@ public long latestVersion() { * @param params index params * @param replace whether to replace the existing index */ - public void createIndex(List columns, IndexType indexType, Optional name, - IndexParams params, boolean replace) { + public void createIndex( + List columns, + IndexType indexType, + Optional name, + IndexParams params, + boolean replace) { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); nativeCreateIndex(columns, indexType.getValue(), name, params, replace); } } - private native void nativeCreateIndex(List columns, int indexTypeCode, - Optional name, IndexParams params, boolean replace); + private native void nativeCreateIndex( + List columns, + int indexTypeCode, + Optional name, + IndexParams params, + boolean replace); + + /** + * Count the number of rows in the dataset. + * + * @return num of rows + */ + public long countRows() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeCountRows(Optional.empty()); + } + } /** * Count the number of rows in the dataset. * + * @param filter the filter expr to count row * @return num of rows */ - public int countRows() { + public long countRows(String filter) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + Preconditions.checkArgument( + null != filter && !filter.isEmpty(), "filter cannot be null or empty"); + return nativeCountRows(Optional.of(filter)); + } + } + + private native long nativeCountRows(Optional filter); + + /** + * Calculate the size of the dataset. + * + * @return the size of the dataset + */ + public long calculateDataSize() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); - return nativeCountRows(); + return nativeGetDataStatistics().getDataSize(); } } - private native int nativeCountRows(); + /** + * Calculate the statistics of the dataset. + * + * @return the statistics of the dataset + */ + private native DataStatistics nativeGetDataStatistics(); /** * Get all fragments in this dataset. * - * @return A list of {@link DatasetFragment}. + * @return A list of {@link Fragment}. */ - public List getFragments() { + public List getFragments() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); // Set a pointer in Fragment to dataset, to make it is easier to issue IOs // later. // // We do not need to close Fragments. - return this.getJsonFragments().stream() - .map(jsonFragment -> new DatasetFragment(this, FragmentMetadata.fromJson(jsonFragment))) + return this.getFragmentsNative().stream() + .map(metadata -> new Fragment(this, metadata)) .collect(Collectors.toList()); } } - private native List getJsonFragments(); + private native List getFragmentsNative(); /** * Gets the schema of the dataset. @@ -332,9 +556,7 @@ public Schema getSchema() { private native void importFfiSchema(long arrowSchemaMemoryAddress); - /** - * @return all the created indexes names - */ + /** @return all the created indexes names */ public List listIndexes() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); @@ -378,4 +600,11 @@ public boolean closed() { return nativeDatasetHandle == 0; } } + + public Fragment getFragment(int fragmentId) { + FragmentMetadata metadata = getFragmentNative(fragmentId); + return new Fragment(this, metadata); + } + + private native FragmentMetadata getFragmentNative(int fragmentId); } diff --git a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java deleted file mode 100644 index 1aa2a1a307f..00000000000 --- a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.lancedb.lance; - -import com.lancedb.lance.ipc.LanceScanner; -import com.lancedb.lance.ipc.ScanOptions; -import java.util.Arrays; -import org.apache.arrow.util.Preconditions; - -/** - * Dataset format. - * Matching to Lance Rust FileFragment. - */ -public class DatasetFragment { - /** Pointer to the {@link Dataset} instance in Java. */ - private final Dataset dataset; - private final FragmentMetadata metadata; - - /** Private constructor, calling from JNI. */ - DatasetFragment(Dataset dataset, FragmentMetadata metadata) { - Preconditions.checkNotNull(dataset); - Preconditions.checkNotNull(metadata); - this.dataset = dataset; - this.metadata = metadata; - } - - /** - * Create a new Dataset Scanner. - * - * @return a dataset scanner - */ - public LanceScanner newScan() { - return LanceScanner.create(dataset, new ScanOptions.Builder() - .fragmentIds(Arrays.asList(metadata.getId())).build(), dataset.allocator); - } - - /** - * Create a new Dataset Scanner. - * - * @param batchSize scan batch size - * @return a dataset scanner - */ - public LanceScanner newScan(long batchSize) { - return LanceScanner.create(dataset, - new ScanOptions.Builder() - .fragmentIds(Arrays.asList(metadata.getId())).batchSize(batchSize).build(), - dataset.allocator); - } - - /** - * Create a new Dataset Scanner. - * - * @param options the scan options - * @return a dataset scanner - */ - public LanceScanner newScan(ScanOptions options) { - Preconditions.checkNotNull(options); - return LanceScanner.create(dataset, - new ScanOptions.Builder(options).fragmentIds(Arrays.asList(metadata.getId())).build(), - dataset.allocator); - } - - private native int countRowsNative(Dataset dataset, long fragmentId); - - public int getId() { - return metadata.getId(); - } - - /** - * @return row counts in this Fragment - */ - public int countRows() { - return countRowsNative(dataset, metadata.getId()); - } - - public String toString() { - return String.format("Fragment(%s)", metadata.getJsonMetadata()); - } -} diff --git a/java/core/src/main/java/com/lancedb/lance/Fragment.java b/java/core/src/main/java/com/lancedb/lance/Fragment.java index db994a6e4a4..c588d5c8b92 100644 --- a/java/core/src/main/java/com/lancedb/lance/Fragment.java +++ b/java/core/src/main/java/com/lancedb/lance/Fragment.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import java.util.Map; -import java.util.Optional; +import com.lancedb.lance.ipc.LanceScanner; +import com.lancedb.lance.ipc.ScanOptions; + import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.ArrowSchema; @@ -24,76 +24,167 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + /** Fragment operations. */ public class Fragment { static { JniLoader.ensureLoaded(); } + /** Pointer to the {@link Dataset} instance in Java. */ + private final Dataset dataset; + + private final FragmentMetadata fragment; + + public Fragment(Dataset dataset, int fragmentId) { + Preconditions.checkNotNull(dataset); + this.dataset = dataset; + this.fragment = dataset.getFragment(fragmentId).fragment; + } + + public Fragment(Dataset dataset, FragmentMetadata fragment) { + Preconditions.checkNotNull(dataset); + Preconditions.checkNotNull(fragment); + this.dataset = dataset; + this.fragment = fragment; + } + + /** + * Create a new Dataset Scanner. + * + * @return a dataset scanner + */ + public LanceScanner newScan() { + return LanceScanner.create( + dataset, + new ScanOptions.Builder().fragmentIds(Arrays.asList(fragment.getId())).build(), + dataset.allocator); + } + + /** + * Create a new Dataset Scanner. + * + * @param batchSize scan batch size + * @return a dataset scanner + */ + public LanceScanner newScan(long batchSize) { + return LanceScanner.create( + dataset, + new ScanOptions.Builder() + .fragmentIds(Arrays.asList(fragment.getId())) + .batchSize(batchSize) + .build(), + dataset.allocator); + } + + /** + * Create a new Dataset Scanner. + * + * @param options the scan options + * @return a dataset scanner + */ + public LanceScanner newScan(ScanOptions options) { + Preconditions.checkNotNull(options); + return LanceScanner.create( + dataset, + new ScanOptions.Builder(options).fragmentIds(Arrays.asList(fragment.getId())).build(), + dataset.allocator); + } + + private native int countRowsNative(Dataset dataset, long fragmentId); + + public int getId() { + return fragment.getId(); + } + + /** @return row counts in this Fragment */ + public int countRows() { + return countRowsNative(dataset, fragment.getId()); + } + /** * Create a fragment from the given data. * * @param datasetUri the dataset uri * @param allocator the buffer allocator * @param root the vector schema root - * @param fragmentId the fragment id * @param params the write params * @return the fragment metadata */ - public static FragmentMetadata create(String datasetUri, BufferAllocator allocator, - VectorSchemaRoot root, Optional fragmentId, WriteParams params) { + public static List create( + String datasetUri, BufferAllocator allocator, VectorSchemaRoot root, WriteParams params) { Preconditions.checkNotNull(datasetUri); Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(root); - Preconditions.checkNotNull(fragmentId); Preconditions.checkNotNull(params); try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); - ArrowArray arrowArray = ArrowArray.allocateNew(allocator)) { + ArrowArray arrowArray = ArrowArray.allocateNew(allocator)) { Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); - return FragmentMetadata.fromJson(createWithFfiArray(datasetUri, arrowArray.memoryAddress(), - arrowSchema.memoryAddress(), fragmentId, params.getMaxRowsPerFile(), - params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode(), - params.getStorageOptions())); + return createWithFfiArray( + datasetUri, + arrowArray.memoryAddress(), + arrowSchema.memoryAddress(), + params.getMaxRowsPerFile(), + params.getMaxRowsPerGroup(), + params.getMaxBytesPerFile(), + params.getMode(), + params.getStorageOptions()); } } /** * Create a fragment from the given arrow stream. - * @param datasetUri the dataset uri - * @param stream the arrow stream - * @param fragmentId the fragment id - * @param params the write params - * @return the fragment metadata + * + * @param datasetUri the dataset uri + * @param stream the arrow stream + * @param params the write params + * @return the fragment metadata */ - public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream, - Optional fragmentId, WriteParams params) { + public static List create( + String datasetUri, ArrowArrayStream stream, WriteParams params) { Preconditions.checkNotNull(datasetUri); Preconditions.checkNotNull(stream); - Preconditions.checkNotNull(fragmentId); Preconditions.checkNotNull(params); - return FragmentMetadata.fromJson(createWithFfiStream(datasetUri, - stream.memoryAddress(), fragmentId, - params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), - params.getMaxBytesPerFile(), params.getMode(), params.getStorageOptions())); + return createWithFfiStream( + datasetUri, + stream.memoryAddress(), + params.getMaxRowsPerFile(), + params.getMaxRowsPerGroup(), + params.getMaxBytesPerFile(), + params.getMode(), + params.getStorageOptions()); } /** * Create a fragment from the given arrow array and schema. * - * @return the json serialized fragment metadata + * @return the fragment metadata */ - private static native String createWithFfiArray(String datasetUri, - long arrowArrayMemoryAddress, long arrowSchemaMemoryAddress, Optional fragmentId, - Optional maxRowsPerFile, Optional maxRowsPerGroup, - Optional maxBytesPerFile, Optional mode, Map storageOptions); + private static native List createWithFfiArray( + String datasetUri, + long arrowArrayMemoryAddress, + long arrowSchemaMemoryAddress, + Optional maxRowsPerFile, + Optional maxRowsPerGroup, + Optional maxBytesPerFile, + Optional mode, + Map storageOptions); /** * Create a fragment from the given arrow stream. * - * @return the json serialized fragment metadata + * @return the fragment metadata */ - private static native String createWithFfiStream(String datasetUri, long arrowStreamMemoryAddress, - Optional fragmentId, Optional maxRowsPerFile, - Optional maxRowsPerGroup, Optional maxBytesPerFile, - Optional mode, Map storageOptions); + private static native List createWithFfiStream( + String datasetUri, + long arrowStreamMemoryAddress, + Optional maxRowsPerFile, + Optional maxRowsPerGroup, + Optional maxBytesPerFile, + Optional mode, + Map storageOptions); } diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java index c2b5d665a2f..47d198f9b08 100644 --- a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java +++ b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java @@ -11,42 +11,72 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import java.io.Serializable; -import org.apache.arrow.util.Preconditions; -import org.json.JSONObject; +import com.lancedb.lance.fragment.DataFile; +import com.lancedb.lance.fragment.DeletionFile; +import com.lancedb.lance.fragment.RowIdMeta; + import org.apache.commons.lang3.builder.ToStringBuilder; -/** - * Metadata of a Fragment in the dataset. - * Matching to lance Fragment. - */ +import java.io.Serializable; +import java.util.List; + +/** Metadata of a Fragment in the dataset. Matching to lance Fragment. */ public class FragmentMetadata implements Serializable { private static final long serialVersionUID = -5886811251944130460L; - private static final String ID_KEY = "id"; - private static final String PHYSICAL_ROWS_KEY = "physical_rows"; - private final String jsonMetadata; - private final int id; - private final long physicalRows; + private int id; + private List files; + private long physicalRows; + private DeletionFile deletionFile; + private RowIdMeta rowIdMeta; - private FragmentMetadata(String jsonMetadata, int id, long physicalRows) { - this.jsonMetadata = jsonMetadata; + public FragmentMetadata( + int id, + List files, + Long physicalRows, + DeletionFile deletionFile, + RowIdMeta rowIdMeta) { this.id = id; + this.files = files; this.physicalRows = physicalRows; + this.deletionFile = deletionFile; + this.rowIdMeta = rowIdMeta; } public int getId() { return id; } + public List getFiles() { + return files; + } + public long getPhysicalRows() { return physicalRows; } - public String getJsonMetadata() { - return jsonMetadata; + public DeletionFile getDeletionFile() { + return deletionFile; + } + + public long getNumDeletions() { + if (deletionFile == null) { + return 0; + } + Long deleted = deletionFile.getNumDeletedRows(); + if (deleted == null) { + return 0; + } + return deleted; + } + + public long getNumRows() { + return getPhysicalRows() - getNumDeletions(); + } + + public RowIdMeta getRowIdMeta() { + return rowIdMeta; } @Override @@ -54,25 +84,9 @@ public String toString() { return new ToStringBuilder(this) .append("id", id) .append("physicalRows", physicalRows) - .append("jsonMetadata", jsonMetadata) + .append("files", files) + .append("deletionFile", deletionFile) + .append("rowIdMeta", rowIdMeta) .toString(); } - - /** - * Creates the fragment metadata from json serialized string. - * - * @param jsonMetadata json metadata - * @return created fragment metadata - */ - public static FragmentMetadata fromJson(String jsonMetadata) { - Preconditions.checkNotNull(jsonMetadata); - JSONObject metadata = new JSONObject(jsonMetadata); - if (!metadata.has(ID_KEY) || !metadata.has(PHYSICAL_ROWS_KEY)) { - throw new IllegalArgumentException( - String.format("Fragment metadata must have {} and {} but is {}", - ID_KEY, PHYSICAL_ROWS_KEY, jsonMetadata)); - } - return new FragmentMetadata(jsonMetadata, metadata.getInt(ID_KEY), - metadata.getLong(PHYSICAL_ROWS_KEY)); - } } diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java b/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java index 08e51c7fc54..72a5c35178c 100644 --- a/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java +++ b/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java @@ -11,15 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; + import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; /** Fragment related operations. */ public abstract class FragmentOperation { @@ -29,8 +31,11 @@ protected static void validateFragments(List fragments) { } } - public abstract Dataset commit(BufferAllocator allocator, String path, - Optional readVersion, Map storageOptions); + public abstract Dataset commit( + BufferAllocator allocator, + String path, + Optional readVersion, + Map storageOptions); /** Fragment append operation. */ public static class Append extends FragmentOperation { @@ -42,14 +47,43 @@ public Append(List fragments) { } @Override - public Dataset commit(BufferAllocator allocator, String path, Optional readVersion, - Map storageOptions) { + public Dataset commit( + BufferAllocator allocator, + String path, + Optional readVersion, + Map storageOptions) { + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(readVersion); + return Dataset.commitAppend(path, readVersion, fragments, storageOptions); + } + } + + /** Fragment overwrite operation. */ + public static class Overwrite extends FragmentOperation { + private final List fragments; + private final Schema schema; + + public Overwrite(List fragments, Schema schema) { + validateFragments(fragments); + this.fragments = fragments; + this.schema = schema; + } + + @Override + public Dataset commit( + BufferAllocator allocator, + String path, + Optional readVersion, + Map storageOptions) { Preconditions.checkNotNull(allocator); Preconditions.checkNotNull(path); Preconditions.checkNotNull(readVersion); - return Dataset.commitAppend(path, readVersion, - fragments.stream().map(FragmentMetadata::getJsonMetadata).collect(Collectors.toList()), - storageOptions); + try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator)) { + Data.exportSchema(allocator, schema, null, arrowSchema); + return Dataset.commitOverwrite( + path, arrowSchema.memoryAddress(), readVersion, fragments, storageOptions); + } } } } diff --git a/java/core/src/main/java/com/lancedb/lance/JniLoader.java b/java/core/src/main/java/com/lancedb/lance/JniLoader.java index 4ce07b33954..85d67392990 100644 --- a/java/core/src/main/java/com/lancedb/lance/JniLoader.java +++ b/java/core/src/main/java/com/lancedb/lance/JniLoader.java @@ -11,24 +11,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; import io.questdb.jar.jni.JarJniLoader; -/** - * Utility class to load the native library. - */ +/** Utility class to load the native library. */ public class JniLoader { static { JarJniLoader.loadLib(Dataset.class, "/nativelib", "lance_jni"); } - /** - * Ensures the native library is loaded. - * This method will trigger the static initializer - */ + /** Ensures the native library is loaded. This method will trigger the static initializer */ public static void ensureLoaded() {} private JniLoader() {} -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/LockManager.java b/java/core/src/main/java/com/lancedb/lance/LockManager.java index f65a5001885..361b06d1c03 100644 --- a/java/core/src/main/java/com/lancedb/lance/LockManager.java +++ b/java/core/src/main/java/com/lancedb/lance/LockManager.java @@ -11,25 +11,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; import java.util.concurrent.locks.ReentrantReadWriteLock; -/** - * The LockManager class provides a way to manage read and write locks. - */ +/** The LockManager class provides a way to manage read and write locks. */ public class LockManager { private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); /** - * Represents a read lock for the LockManager. - * This lock allows multiple threads to read concurrently, but prevents write access. + * Represents a read lock for the LockManager. This lock allows multiple threads to read + * concurrently, but prevents write access. */ public class ReadLock implements AutoCloseable { - /** - * Acquires a read lock on the lock manager. - */ + /** Acquires a read lock on the lock manager. */ public ReadLock() { lock.readLock().lock(); } @@ -40,13 +35,9 @@ public void close() { } } - /** - * Represents a write lock that can be acquired and released. - */ + /** Represents a write lock that can be acquired and released. */ public class WriteLock implements AutoCloseable { - /** - * Constructs a new WriteLock and acquires the write lock. - */ + /** Constructs a new WriteLock and acquires the write lock. */ public WriteLock() { lock.writeLock().lock(); } @@ -74,4 +65,4 @@ public ReadLock acquireReadLock() { public WriteLock acquireWriteLock() { return new WriteLock(); } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/ReadOptions.java b/java/core/src/main/java/com/lancedb/lance/ReadOptions.java index e41ce17afc5..984ccc1ccc7 100644 --- a/java/core/src/main/java/com/lancedb/lance/ReadOptions.java +++ b/java/core/src/main/java/com/lancedb/lance/ReadOptions.java @@ -1,25 +1,25 @@ /* - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ - package com.lancedb.lance; import org.apache.commons.lang3.builder.ToStringBuilder; -import java.util.Optional; -import java.util.Map; + import java.util.HashMap; +import java.util.Map; +import java.util.Optional; -/** - * Read options for reading from a dataset. - */ +/** Read options for reading from a dataset. */ public class ReadOptions { private final Optional version; @@ -58,9 +58,12 @@ public Map getStorageOptions() { @Override public String toString() { - return new ToStringBuilder(this).append("version", version.orElse(null)) - .append("blockSize", blockSize.orElse(null)).append("indexCacheSize", indexCacheSize) - .append("metadataCacheSize", metadataCacheSize).append("storageOptions", storageOptions) + return new ToStringBuilder(this) + .append("version", version.orElse(null)) + .append("blockSize", blockSize.orElse(null)) + .append("indexCacheSize", indexCacheSize) + .append("metadataCacheSize", metadataCacheSize) + .append("storageOptions", storageOptions) .toString(); } diff --git a/java/core/src/main/java/com/lancedb/lance/Utils.java b/java/core/src/main/java/com/lancedb/lance/Utils.java index d9153fe5d0b..70d01e8e0b0 100644 --- a/java/core/src/main/java/com/lancedb/lance/Utils.java +++ b/java/core/src/main/java/com/lancedb/lance/Utils.java @@ -11,22 +11,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.types.pojo.Schema; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + /** Utility. */ public class Utils { /** * Convert schema to ArrowSchema for JNI processing. + * * @param schema schema * @param allocator buffer allocator * @return ArrowSchema @@ -39,7 +40,8 @@ public static ArrowSchema toFfi(Schema schema, BufferAllocator allocator) { /** * Convert optional array to optional list for JNI processing. - * @param optionalArray Optional array + * + * @param optionalArray Optional array * @return Optional list */ public static Optional> convert(Optional optionalArray) { diff --git a/java/core/src/main/java/com/lancedb/lance/WriteParams.java b/java/core/src/main/java/com/lancedb/lance/WriteParams.java index 1b20dd733bd..524bf07eb8f 100644 --- a/java/core/src/main/java/com/lancedb/lance/WriteParams.java +++ b/java/core/src/main/java/com/lancedb/lance/WriteParams.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; import org.apache.commons.lang3.builder.ToStringBuilder; @@ -20,14 +19,10 @@ import java.util.Map; import java.util.Optional; -/** - * Write Params for Write Operations of Lance. - */ +/** Write Params for Write Operations of Lance. */ public class WriteParams { - /** - * Write Mode. - */ + /** Write Mode. */ public enum WriteMode { CREATE, APPEND, @@ -40,8 +35,11 @@ public enum WriteMode { private final Optional mode; private Map storageOptions = new HashMap<>(); - private WriteParams(Optional maxRowsPerFile, Optional maxRowsPerGroup, - Optional maxBytesPerFile, Optional mode, + private WriteParams( + Optional maxRowsPerFile, + Optional maxRowsPerGroup, + Optional maxBytesPerFile, + Optional mode, Map storageOptions) { this.maxRowsPerFile = maxRowsPerFile; this.maxRowsPerGroup = maxRowsPerGroup; @@ -64,6 +62,7 @@ public Optional getMaxBytesPerFile() { /** * Get Mode with name. + * * @return mode */ public Optional getMode() { @@ -84,9 +83,7 @@ public String toString() { .toString(); } - /** - * A builder of WriteParams. - */ + /** A builder of WriteParams. */ public static class Builder { private Optional maxRowsPerFile = Optional.empty(); private Optional maxRowsPerGroup = Optional.empty(); @@ -120,8 +117,8 @@ public Builder withStorageOptions(Map storageOptions) { } public WriteParams build() { - return new WriteParams(maxRowsPerFile, maxRowsPerGroup, maxBytesPerFile, mode, - storageOptions); + return new WriteParams( + maxRowsPerFile, maxRowsPerGroup, maxBytesPerFile, mode, storageOptions); } } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/file/LanceFileReader.java b/java/core/src/main/java/com/lancedb/lance/file/LanceFileReader.java new file mode 100644 index 00000000000..ca4d56e884b --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/file/LanceFileReader.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.file; + +import com.lancedb.lance.JniLoader; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; + +public class LanceFileReader implements AutoCloseable { + + static { + JniLoader.ensureLoaded(); + } + + private long nativeFileReaderHandle; + + private BufferAllocator allocator; + private Schema schema; + + private static native LanceFileReader openNative(String fileUri) throws IOException; + + private native void closeNative(long nativeLanceFileReaderHandle) throws IOException; + + private native long numRowsNative() throws IOException; + + private native void populateSchemaNative(long arrowSchemaMemoryAddress); + + private native void readAllNative(int batchSize, long streamMemoryAddress) throws IOException; + + private LanceFileReader() {} + + /** + * Open a LanceFileReader from a file URI + * + * @param path the URI to the Lance file + * @param allocator the Arrow BufferAllocator to use for the reader + * @return a new LanceFileReader + */ + public static LanceFileReader open(String path, BufferAllocator allocator) throws IOException { + LanceFileReader reader = openNative(path); + reader.allocator = allocator; + reader.schema = reader.load_schema(); + return reader; + } + + /** + * Close the LanceFileReader + * + *

This method must be called to release resources when the reader is no longer needed. + */ + @Override + public void close() throws Exception { + closeNative(nativeFileReaderHandle); + } + + /** + * Get the number of rows in the Lance file + * + * @return the number of rows in the Lance file + */ + public long numRows() throws IOException { + long numRows = numRowsNative(); + return numRows; + } + + /** + * Get the schema of the Lance file + * + * @return the schema of the Lance file + */ + public Schema schema() { + return schema; + } + + private Schema load_schema() throws IOException { + try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { + populateSchemaNative(ffiArrowSchema.memoryAddress()); + return Data.importSchema(allocator, ffiArrowSchema, null); + } + } + + /** + * Read all rows from the Lance file + * + * @param batchSize the maximum number of rows to read in a single batch + * @return an ArrowReader for the Lance file + */ + public ArrowReader readAll(int batchSize) throws IOException { + try (ArrowArrayStream ffiArrowArrayStream = ArrowArrayStream.allocateNew(allocator)) { + readAllNative(batchSize, ffiArrowArrayStream.memoryAddress()); + return Data.importArrayStream(allocator, ffiArrowArrayStream); + } + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/file/LanceFileWriter.java b/java/core/src/main/java/com/lancedb/lance/file/LanceFileWriter.java new file mode 100644 index 00000000000..a8d469aef21 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/file/LanceFileWriter.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.file; + +import com.lancedb.lance.JniLoader; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; + +import java.io.IOException; + +public class LanceFileWriter implements AutoCloseable { + + static { + JniLoader.ensureLoaded(); + } + + private long nativeFileWriterHandle; + private BufferAllocator allocator; + private DictionaryProvider dictionaryProvider; + + private static native LanceFileWriter openNative(String fileUri) throws IOException; + + private native void closeNative(long nativeLanceFileReaderHandle) throws IOException; + + private native void writeNative(long batchMemoryAddress, long schemaMemoryAddress) + throws IOException; + + private LanceFileWriter() {} + + /** + * Open a LanceFileWriter to write to a given file URI + * + * @param path the URI of the file to write to + * @param allocator the BufferAllocator to use for the writer + * @param dictionaryProvider the DictionaryProvider to use for the writer + * @return a new LanceFileWriter + */ + public static LanceFileWriter open( + String path, BufferAllocator allocator, DictionaryProvider dictionaryProvider) + throws IOException { + LanceFileWriter writer = openNative(path); + writer.allocator = allocator; + writer.dictionaryProvider = dictionaryProvider; + return writer; + } + + /** + * Write a batch of data + * + * @param batch the batch of data to write + * @throws IOException if the batch cannot be written + */ + public void write(VectorSchemaRoot batch) throws IOException { + try (ArrowArray ffiArrowArray = ArrowArray.allocateNew(allocator); + ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { + Data.exportVectorSchemaRoot( + allocator, batch, dictionaryProvider, ffiArrowArray, ffiArrowSchema); + writeNative(ffiArrowArray.memoryAddress(), ffiArrowSchema.memoryAddress()); + } + } + + /** + * Close the LanceFileWriter + * + *

This method must be called to release resources when the writer is no longer needed. + * + *

This method will also flush all remaining data and write the footer to the file. + * + * @throws Exception if the writer cannot be closed + */ + @Override + public void close() throws Exception { + closeNative(nativeFileWriterHandle); + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/fragment/DataFile.java b/java/core/src/main/java/com/lancedb/lance/fragment/DataFile.java new file mode 100644 index 00000000000..1a6a07c341e --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/fragment/DataFile.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.fragment; + +import org.apache.commons.lang3.builder.ToStringBuilder; + +import java.io.Serializable; + +public class DataFile implements Serializable { + private static final long serialVersionUID = -2827710928026343591L; + private final String path; + private final int[] fields; + private final int[] columnIndices; + private final int fileMajorVersion; + private final int fileMinorVersion; + + public DataFile( + String path, int[] fields, int[] columnIndices, int fileMajorVersion, int fileMinorVersion) { + this.path = path; + this.fields = fields; + this.columnIndices = columnIndices; + this.fileMajorVersion = fileMajorVersion; + this.fileMinorVersion = fileMinorVersion; + } + + public String getPath() { + return path; + } + + public int[] getFields() { + return fields; + } + + public int[] getColumnIndices() { + return columnIndices; + } + + public int getFileMajorVersion() { + return fileMajorVersion; + } + + public int getFileMinorVersion() { + return fileMinorVersion; + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append("path", path) + .append("fields", fields) + .append("columnIndices", columnIndices) + .append("fileMajorVersion", fileMajorVersion) + .append("fileMinorVersion", fileMinorVersion) + .toString(); + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFile.java b/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFile.java new file mode 100644 index 00000000000..157f251269b --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFile.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.fragment; + +import org.apache.commons.lang3.builder.ToStringBuilder; + +import java.io.Serializable; + +public class DeletionFile implements Serializable { + private static final long serialVersionUID = 3786348766842875859L; + + private final long id; + private final long readVersion; + private final Long numDeletedRows; + private final DeletionFileType fileType; + + public DeletionFile(long id, long readVersion, Long numDeletedRows, DeletionFileType fileType) { + this.id = id; + this.readVersion = readVersion; + this.numDeletedRows = numDeletedRows; + this.fileType = fileType; + } + + public long getId() { + return id; + } + + public long getReadVersion() { + return readVersion; + } + + public Long getNumDeletedRows() { + return numDeletedRows; + } + + public DeletionFileType getFileType() { + return fileType; + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append("id", id) + .append("readVersion", readVersion) + .append("numDeletedRows", numDeletedRows) + .append("fileType", fileType) + .toString(); + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFileType.java b/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFileType.java new file mode 100644 index 00000000000..552f3899105 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/fragment/DeletionFileType.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.fragment; + +public enum DeletionFileType { + ARRAY, + BITMAP +} diff --git a/java/core/src/main/java/com/lancedb/lance/fragment/RowIdMeta.java b/java/core/src/main/java/com/lancedb/lance/fragment/RowIdMeta.java new file mode 100644 index 00000000000..8d0d453f3d5 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/fragment/RowIdMeta.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.fragment; + +import org.apache.commons.lang3.builder.ToStringBuilder; + +import java.io.Serializable; + +public class RowIdMeta implements Serializable { + private static final long serialVersionUID = -6532828695072614148L; + + private final String metadata; + + public RowIdMeta(String metadata) { + this.metadata = metadata; + } + + public String getMetadata() { + return metadata; + } + + @Override + public String toString() { + return new ToStringBuilder(this).append("metadata", metadata).toString(); + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/DistanceType.java b/java/core/src/main/java/com/lancedb/lance/index/DistanceType.java index 724accb1ab8..61f2020e419 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/DistanceType.java +++ b/java/core/src/main/java/com/lancedb/lance/index/DistanceType.java @@ -11,21 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package com.lancedb.lance.index; public enum DistanceType { @@ -33,4 +18,4 @@ public enum DistanceType { Cosine, Dot, Hamming; -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/IndexParams.java b/java/core/src/main/java/com/lancedb/lance/index/IndexParams.java index 902386c80ed..c24d4340c45 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/IndexParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/IndexParams.java @@ -11,16 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index; import com.lancedb.lance.index.vector.VectorIndexParams; + import org.apache.commons.lang3.builder.ToStringBuilder; + import java.util.Optional; -/** - * Parameters for creating an index. - */ +/** Parameters for creating an index. */ public class IndexParams { private final DistanceType distanceType; private final Optional vectorIndexParams; @@ -37,8 +36,7 @@ public static class Builder { public Builder() {} /** - * Set the distance type for calculating the distance between vectors. - * Default to L2. + * Set the distance type for calculating the distance between vectors. Default to L2. * * @param distanceType distance type * @return this builder @@ -50,6 +48,7 @@ public Builder setDistanceType(DistanceType distanceType) { /** * Vector index parameters for creating a vector index. + * * @param vectorIndexParams vector index parameters * @return this builder */ @@ -74,8 +73,8 @@ public Optional getVectorIndexParams() { @Override public String toString() { return new ToStringBuilder(this) - .append("distanceType", distanceType) - .append("vectorIndexParams", vectorIndexParams.orElse(null)) - .toString(); + .append("distanceType", distanceType) + .append("vectorIndexParams", vectorIndexParams.orElse(null)) + .toString(); } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/IndexType.java b/java/core/src/main/java/com/lancedb/lance/index/IndexType.java index d2499e23d26..17e3a706cdf 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/IndexType.java +++ b/java/core/src/main/java/com/lancedb/lance/index/IndexType.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index; public enum IndexType { @@ -30,11 +29,10 @@ public enum IndexType { private final int value; IndexType(int value) { - this.value = value; + this.value = value; } public int getValue() { - return value; + return value; } } - diff --git a/java/core/src/main/java/com/lancedb/lance/index/vector/HnswBuildParams.java b/java/core/src/main/java/com/lancedb/lance/index/vector/HnswBuildParams.java index dc09fe12503..829214c4b5f 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/vector/HnswBuildParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/vector/HnswBuildParams.java @@ -11,15 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index.vector; import org.apache.commons.lang3.builder.ToStringBuilder; + import java.util.Optional; /** - * Parameters for building an HNSW index in each IVF partition. - * This speeds up the search in a large dataset. + * Parameters for building an HNSW index in each IVF partition. This speeds up the search in a large + * dataset. */ public class HnswBuildParams { private final short maxLevel; @@ -41,11 +41,10 @@ public static class Builder { private Optional prefetchDistance = Optional.of(2); /** - * Create a new builder for HNSW index parameters. - * Each IVF partition will be built with an HNSW index. + * Create a new builder for HNSW index parameters. Each IVF partition will be built with an HNSW + * index. */ - public Builder() { - } + public Builder() {} /** * @param maxLevel the maximum number of levels in the graph @@ -113,4 +112,4 @@ public String toString() { .append("prefetchDistance", prefetchDistance.orElse(null)) .toString(); } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/vector/IvfBuildParams.java b/java/core/src/main/java/com/lancedb/lance/index/vector/IvfBuildParams.java index 00317bc85d5..85dc9fbacba 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/vector/IvfBuildParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/vector/IvfBuildParams.java @@ -11,19 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index.vector; import org.apache.commons.lang3.builder.ToStringBuilder; /** - * Parameters for building an IVF index. - * Train IVF centroids for the given vector column. - * This will run k-means clustering on the given vector column to train the IVF centroids. - * This is the first step in several vector indices. - * The centroids will be used to partition the vectors into different clusters. - * IVF centroids are trained from a sample of the data (determined by the sample_rate). - * While this sample is not huge it might still be quite large. + * Parameters for building an IVF index. Train IVF centroids for the given vector column. This will + * run k-means clustering on the given vector column to train the IVF centroids. This is the first + * step in several vector indices. The centroids will be used to partition the vectors into + * different clusters. IVF centroids are trained from a sample of the data (determined by the + * sample_rate). While this sample is not huge it might still be quite large. */ public class IvfBuildParams { private final int numPartitions; @@ -51,19 +48,16 @@ public static class Builder { private boolean useResidual = true; /** - * Parameters for building an IVF index. - * Train IVF centroids for the given vector column. - * This will run k-means clustering on the given vector column to train the IVF centroids. - * This is the first step in several vector indices. - * The centroids will be used to partition the vectors into different clusters. - * IVF centroids are trained from a sample of the data (determined by the sample_rate). - * While this sample is not huge it might still be quite large. + * Parameters for building an IVF index. Train IVF centroids for the given vector column. This + * will run k-means clustering on the given vector column to train the IVF centroids. This is + * the first step in several vector indices. The centroids will be used to partition the vectors + * into different clusters. IVF centroids are trained from a sample of the data (determined by + * the sample_rate). While this sample is not huge it might still be quite large. */ public Builder() {} /** - * @param numPartitions set the number of partitions of IVF (Inverted File Index) - * Default to 32 + * @param numPartitions set the number of partitions of IVF (Inverted File Index) Default to 32 * @return Builder */ public Builder setNumPartitions(int numPartitions) { @@ -81,10 +75,9 @@ public Builder setMaxIters(int maxIters) { } /** - * Set the sample rate for training IVF centroids - * IVF centroids are trained from a sample of the data (determined by the sample_rate). - * While this sample is not huge it might still be quite large. - * Default to 256. + * Set the sample rate for training IVF centroids IVF centroids are trained from a sample of the + * data (determined by the sample_rate). While this sample is not huge it might still be quite + * large. Default to 256. * * @param sampleRate set the sample rate for training IVF centroids * @return Builder @@ -95,12 +88,10 @@ public Builder setSampleRate(int sampleRate) { } /** - * Sets the number of batches, using the row group size of the dataset, - * to include in each shuffle partition. Default value is 10240. - * Assuming the row group size is 1024, - * each shuffle partition will hold 10240 * 1024 = 10,485,760 rows. - * By making this value smaller, this shuffle will consume less memory - * but will take longer to complete, and vice versa. + * Sets the number of batches, using the row group size of the dataset, to include in each + * shuffle partition. Default value is 10240. Assuming the row group size is 1024, each shuffle + * partition will hold 10240 * 1024 = 10,485,760 rows. By making this value smaller, this + * shuffle will consume less memory but will take longer to complete, and vice versa. * * @param shufflePartitionBatches the number of batches to include in shuffle * @return Builder @@ -111,9 +102,9 @@ public Builder setShufflePartitionBatches(int shufflePartitionBatches) { } /** - * Set the number of shuffle partitions to process concurrently. Default value is 2. - * By making this value smaller, this shuffle will consume less memory - * but will take longer to complete, and vice versa. + * Set the number of shuffle partitions to process concurrently. Default value is 2. By making + * this value smaller, this shuffle will consume less memory but will take longer to complete, + * and vice versa. * * @param shufflePartitionConcurrency the number of shuffle partitions to process concurrently * @return Builder @@ -125,6 +116,7 @@ public Builder setShufflePartitionConcurrency(int shufflePartitionConcurrency) { /** * Set whether to use residual for k-means clustering. Default value is true. + * * @param useResidual whether to use residual for k-means clustering * @return Builder */ @@ -165,12 +157,12 @@ public boolean useResidual() { @Override public String toString() { return new ToStringBuilder(this) - .append("numPartitions", numPartitions) - .append("maxIters", maxIters) - .append("sampleRate", sampleRate) - .append("shufflePartitionBatches", shufflePartitionBatches) - .append("shufflePartitionConcurrency", shufflePartitionConcurrency) - .append("useResidual", useResidual) - .toString(); + .append("numPartitions", numPartitions) + .append("maxIters", maxIters) + .append("sampleRate", sampleRate) + .append("shufflePartitionBatches", shufflePartitionBatches) + .append("shufflePartitionConcurrency", shufflePartitionConcurrency) + .append("useResidual", useResidual) + .toString(); } } diff --git a/java/core/src/main/java/com/lancedb/lance/index/vector/PQBuildParams.java b/java/core/src/main/java/com/lancedb/lance/index/vector/PQBuildParams.java index f90efaed814..7060faf23d1 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/vector/PQBuildParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/vector/PQBuildParams.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index.vector; import org.apache.commons.lang3.builder.ToStringBuilder; @@ -19,12 +18,10 @@ /** * Train a PQ model for a given column. * - * This will run k-means clustering on each subvector to determine the centroids - * that will be used to quantize the subvectors. - * This step runs against a randomly chosen sample of the data. - * The sample size is typically quite small - * and PQ training is relatively fast regardless of dataset scale. - * As a result, accelerators are not needed here. + *

This will run k-means clustering on each subvector to determine the centroids that will be + * used to quantize the subvectors. This step runs against a randomly chosen sample of the data. The + * sample size is typically quite small and PQ training is relatively fast regardless of dataset + * scale. As a result, accelerators are not needed here. */ public class PQBuildParams { private final int numSubVectors; @@ -48,15 +45,12 @@ public static class Builder { private int kmeansRedos = 1; private int sampleRate = 256; - /** - * Create a new builder for training a PQ model. - */ - public Builder() { - } + /** Create a new builder for training a PQ model. */ + public Builder() {} /** - * The number of subvectors to divide the source vectors into. - * This must be a divisor of the vector dimension. + * The number of subvectors to divide the source vectors into. This must be a divisor of the + * vector dimension. * * @param numSubVectors the number of subvectors * @return Builder @@ -130,11 +124,11 @@ public int getSampleRate() { @Override public String toString() { return new ToStringBuilder(this) - .append("numSubVectors", numSubVectors) - .append("numBits", numBits) - .append("maxIters", maxIters) - .append("kmeansRedos", kmeansRedos) - .append("sampleRate", sampleRate) - .toString(); + .append("numSubVectors", numSubVectors) + .append("numBits", numBits) + .append("maxIters", maxIters) + .append("kmeansRedos", kmeansRedos) + .append("sampleRate", sampleRate) + .toString(); } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/vector/SQBuildParams.java b/java/core/src/main/java/com/lancedb/lance/index/vector/SQBuildParams.java index fba6ddaafef..fb419f2eeae 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/vector/SQBuildParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/vector/SQBuildParams.java @@ -11,14 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index.vector; import org.apache.commons.lang3.builder.ToStringBuilder; -/** - * Parameters for using SQ quantizer. - */ +/** Parameters for using SQ quantizer. */ public class SQBuildParams { private final short numBits; private final int sampleRate; @@ -32,8 +29,7 @@ public static class Builder { private short numBits = 8; private int sampleRate = 256; - public Builder() { - } + public Builder() {} /** * @param numBits number of bits of scaling range. @@ -70,8 +66,8 @@ public int getSampleRate() { @Override public String toString() { return new ToStringBuilder(this) - .append("numBits", numBits) - .append("sampleRate", sampleRate) - .toString(); + .append("numBits", numBits) + .append("sampleRate", sampleRate) + .toString(); } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/index/vector/VectorIndexParams.java b/java/core/src/main/java/com/lancedb/lance/index/vector/VectorIndexParams.java index 7ff0a35f74e..e9104c2731c 100644 --- a/java/core/src/main/java/com/lancedb/lance/index/vector/VectorIndexParams.java +++ b/java/core/src/main/java/com/lancedb/lance/index/vector/VectorIndexParams.java @@ -11,16 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.index.vector; import com.lancedb.lance.index.DistanceType; -import java.util.Optional; + import org.apache.commons.lang3.builder.ToStringBuilder; -/** - * Parameters for creating a vector index. - */ +import java.util.Optional; + +/** Parameters for creating a vector index. */ public class VectorIndexParams { private final DistanceType distanceType; private final IvfBuildParams ivfParams; @@ -66,27 +65,29 @@ public static VectorIndexParams ivfFlat(int numPartitions, DistanceType distance * Create a new IVF index with PQ quantizer. * * @param numPartitions the number of partitions of IVF (Inverted File Index) - * @param numBits maps the float vectors to integer vectors, each integer is of num_bits. - * Now only 8 bits are supported + * @param numBits maps the float vectors to integer vectors, each integer is of num_bits. Now only + * 8 bits are supported * @param numSubVectors the number of sub-vectors for PQ (Product Quantization) * @param distanceType the distance type for calculating the distance between vectors * @param maxIterations K-means max iterations. This will run k-means clustering on each subvector - * to determine the centroids that will be used to quantize the subvectors. + * to determine the centroids that will be used to quantize the subvectors. * @return the VectorIndexParams */ - public static VectorIndexParams ivfPq(int numPartitions, int numBits, int numSubVectors, - DistanceType distanceType, int maxIterations) { + public static VectorIndexParams ivfPq( + int numPartitions, + int numBits, + int numSubVectors, + DistanceType distanceType, + int maxIterations) { IvfBuildParams ivfParams = new IvfBuildParams.Builder().setNumPartitions(numPartitions).build(); - PQBuildParams pqParams = new PQBuildParams.Builder() - .setNumBits(numBits) - .setNumSubVectors(numSubVectors) - .setMaxIters(maxIterations) - .build(); - - return new Builder(ivfParams) - .setDistanceType(distanceType) - .setPqParams(pqParams) - .build(); + PQBuildParams pqParams = + new PQBuildParams.Builder() + .setNumBits(numBits) + .setNumSubVectors(numSubVectors) + .setMaxIters(maxIterations) + .build(); + + return new Builder(ivfParams).setDistanceType(distanceType).setPqParams(pqParams).build(); } /** @@ -97,18 +98,14 @@ public static VectorIndexParams ivfPq(int numPartitions, int numBits, int numSub * @param pq the PQ build parameters * @return the VectorIndexParams */ - public static VectorIndexParams withIvfPqParams(DistanceType distanceType, - IvfBuildParams ivf, - PQBuildParams pq) { - return new Builder(ivf) - .setDistanceType(distanceType) - .setPqParams(pq) - .build(); + public static VectorIndexParams withIvfPqParams( + DistanceType distanceType, IvfBuildParams ivf, PQBuildParams pq) { + return new Builder(ivf).setDistanceType(distanceType).setPqParams(pq).build(); } /** - * Create a new IVF HNSW index with PQ quantizer. - * The dataset is partitioned into IVF partitions, and each partition builds an HNSW graph. + * Create a new IVF HNSW index with PQ quantizer. The dataset is partitioned into IVF partitions, + * and each partition builds an HNSW graph. * * @param distanceType the distance type for calculating the distance between vectors * @param ivf the IVF build parameters @@ -116,10 +113,8 @@ public static VectorIndexParams withIvfPqParams(DistanceType distanceType, * @param pq the PQ build parameters * @return the VectorIndexParams */ - public static VectorIndexParams withIvfHnswPqParams(DistanceType distanceType, - IvfBuildParams ivf, - HnswBuildParams hnsw, - PQBuildParams pq) { + public static VectorIndexParams withIvfHnswPqParams( + DistanceType distanceType, IvfBuildParams ivf, HnswBuildParams hnsw, PQBuildParams pq) { return new Builder(ivf) .setDistanceType(distanceType) .setHnswParams(hnsw) @@ -128,8 +123,8 @@ public static VectorIndexParams withIvfHnswPqParams(DistanceType distanceType, } /** - * Create a new IVF HNSW index with SQ quantizer. - * The dataset is partitioned into IVF partitions, and each partition builds an HNSW graph. + * Create a new IVF HNSW index with SQ quantizer. The dataset is partitioned into IVF partitions, + * and each partition builds an HNSW graph. * * @param distanceType the distance type for calculating the distance between vectors * @param ivf the IVF build parameters @@ -137,10 +132,8 @@ public static VectorIndexParams withIvfHnswPqParams(DistanceType distanceType, * @param sq the SQ build parameters * @return the VectorIndexParams */ - public static VectorIndexParams withIvfHnswSqParams(DistanceType distanceType, - IvfBuildParams ivf, - HnswBuildParams hnsw, - SQBuildParams sq) { + public static VectorIndexParams withIvfHnswSqParams( + DistanceType distanceType, IvfBuildParams ivf, HnswBuildParams hnsw, SQBuildParams sq) { return new Builder(ivf) .setDistanceType(distanceType) .setHnswParams(hnsw) @@ -183,8 +176,8 @@ public Builder setPqParams(PQBuildParams pqParams) { } /** - * @param hnswParams the HNSW build parameters for building the HNSW graph - * for each IVF partition + * @param hnswParams the HNSW build parameters for building the HNSW graph for each IVF + * partition * @return Builder */ public Builder setHnswParams(HnswBuildParams hnswParams) { @@ -229,11 +222,11 @@ public Optional getSqParams() { @Override public String toString() { return new ToStringBuilder(this) - .append("distanceType", distanceType) - .append("ivfParams", ivfParams) - .append("pqParams", pqParams.orElse(null)) - .append("hnswParams", hnswParams.orElse(null)) - .append("sqParams", sqParams.orElse(null)) - .toString(); + .append("distanceType", distanceType) + .append("ivfParams", ivfParams) + .append("pqParams", pqParams.orElse(null)) + .append("hnswParams", hnswParams.orElse(null)) + .append("sqParams", sqParams.orElse(null)) + .toString(); } } diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/ColumnOrdering.java b/java/core/src/main/java/com/lancedb/lance/ipc/ColumnOrdering.java new file mode 100644 index 00000000000..5c2432fab58 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/ipc/ColumnOrdering.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.ipc; + +import org.apache.arrow.util.Preconditions; + +import java.io.Serializable; + +public class ColumnOrdering implements Serializable { + private static final long serialVersionUID = 1L; + private final String columnName; + private final boolean nullFirst; + private final boolean ascending; + + private ColumnOrdering(Builder builder) { + this.columnName = Preconditions.checkNotNull(builder.columnName, "Columns must be set"); + Preconditions.checkArgument(!builder.columnName.isEmpty(), "Column must not be empty"); + this.nullFirst = builder.nullFirst; + this.ascending = builder.ascending; + } + + public String getColumnName() { + return columnName; + } + + public boolean isNullFirst() { + return nullFirst; + } + + public boolean isAscending() { + return ascending; + } + + @Override + public String toString() { + return "ColumnOrdering{" + + "columnName='" + + columnName + + '\'' + + ", nullFirst=" + + nullFirst + + ", ascending=" + + ascending + + '}'; + } + + public static class Builder { + private String columnName; + private boolean nullFirst = true; + private boolean ascending = true; + + public void setColumnName(String columnName) { + this.columnName = columnName; + } + + public void setNullFirst(boolean nullFirst) { + this.nullFirst = nullFirst; + } + + public void setAscending(boolean ascending) { + this.ascending = ascending; + } + + public ColumnOrdering build() { + return new ColumnOrdering(this); + } + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java new file mode 100644 index 00000000000..fad3086f9f3 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/ipc/DataStatistics.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.ipc; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +public class DataStatistics implements Serializable { + private final List fields; + + public DataStatistics() { + this.fields = new ArrayList<>(); + } + + // used for rust to add field statistics + public void addFiledStatistics(FieldStatistics fieldStatistics) { + fields.add(fieldStatistics); + } + + public List getFields() { + return fields; + } + + // get total data size of the whole dataset in bytes + public long getDataSize() { + return fields.stream().mapToLong(FieldStatistics::getDataSize).sum(); + } + + @Override + public String toString() { + return "DataStatistics{" + "fields=" + fields + '}'; + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java new file mode 100644 index 00000000000..34b83cd2d1b --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/ipc/FieldStatistics.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.ipc; + +import java.io.Serializable; + +public class FieldStatistics implements Serializable { + private final int id; + // The size of the data in bytes + private final long dataSize; + + public FieldStatistics(int id, long dataSize) { + this.id = id; + this.dataSize = dataSize; + } + + public int getId() { + return id; + } + + public long getDataSize() { + return dataSize; + } + + @Override + public String toString() { + return "FieldStatistics{" + "id=" + id + ", dataSize=" + dataSize + '}'; + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java index 8c8c699ba41..f87097b91b5 100644 --- a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java +++ b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java @@ -11,15 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.ipc; import com.lancedb.lance.Dataset; import com.lancedb.lance.LockManager; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Optional; + import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; @@ -29,6 +25,11 @@ import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.Schema; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Optional; + /** Scanner over a Fragment. */ public class LanceScanner implements org.apache.arrow.dataset.scanner.Scanner { Dataset dataset; @@ -51,30 +52,50 @@ private LanceScanner() {} * @param allocator allocator * @return a Scanner */ - public static LanceScanner create(Dataset dataset, ScanOptions options, - BufferAllocator allocator) { + public static LanceScanner create( + Dataset dataset, ScanOptions options, BufferAllocator allocator) { Preconditions.checkNotNull(dataset); Preconditions.checkNotNull(options); Preconditions.checkNotNull(allocator); - LanceScanner scanner = createScanner(dataset, options.getFragmentIds(), options.getColumns(), - options.getSubstraitFilter(), options.getFilter(), options.getBatchSize(), - options.getLimit(), options.getOffset(), options.getNearest(), - options.isWithRowId(), options.getBatchReadahead()); + LanceScanner scanner = + createScanner( + dataset, + options.getFragmentIds(), + options.getColumns(), + options.getSubstraitFilter(), + options.getFilter(), + options.getBatchSize(), + options.getLimit(), + options.getOffset(), + options.getNearest(), + options.isWithRowId(), + options.isWithRowAddress(), + options.getBatchReadahead(), + options.getColumnOrderings()); scanner.allocator = allocator; scanner.dataset = dataset; scanner.options = options; return scanner; } - static native LanceScanner createScanner(Dataset dataset, Optional> fragmentIds, - Optional> columns, Optional substraitFilter, - Optional filter, Optional batchSize, Optional limit, - Optional offset, Optional query, boolean withRowId, int batchReadahead - ); + static native LanceScanner createScanner( + Dataset dataset, + Optional> fragmentIds, + Optional> columns, + Optional substraitFilter, + Optional filter, + Optional batchSize, + Optional limit, + Optional offset, + Optional query, + boolean withRowId, + boolean withRowAddress, + int batchReadahead, + Optional> columnOrderings); /** - * Closes this scanner and releases any system resources associated with it. If - * the scanner is already closed, then invoking this method has no effect. + * Closes this scanner and releases any system resources associated with it. If the scanner is + * already closed, then invoking this method has no effect. */ @Override public void close() throws Exception { @@ -87,8 +108,7 @@ public void close() throws Exception { } /** - * Native method to release the Lance scanner resources associated with the - * given handle. + * Native method to release the Lance scanner resources associated with the given handle. * * @param handle The native handle to the scanner resource. */ diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/Query.java b/java/core/src/main/java/com/lancedb/lance/ipc/Query.java index b24582485b5..8ea81f1a80e 100644 --- a/java/core/src/main/java/com/lancedb/lance/ipc/Query.java +++ b/java/core/src/main/java/com/lancedb/lance/ipc/Query.java @@ -11,12 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.ipc; +import com.lancedb.lance.index.DistanceType; + import org.apache.arrow.util.Preconditions; import org.apache.commons.lang3.builder.ToStringBuilder; -import com.lancedb.lance.index.DistanceType; + import java.util.Optional; public class Query { @@ -145,8 +146,8 @@ public Builder setNprobes(int nprobes) { } /** - * Sets the number of candidates to reserve while searching. - * This is an optional parameter for HNSW related index types. + * Sets the number of candidates to reserve while searching. This is an optional parameter for + * HNSW related index types. * * @param ef The number of candidates to reserve. * @return The Builder instance for method chaining. @@ -193,11 +194,10 @@ public Builder setUseIndex(boolean useIndex) { * Builds the Query object. * * @return A new immutable Query instance. - * @throws IllegalStateException if any required fields are not set or have - * invalid values. + * @throws IllegalStateException if any required fields are not set or have invalid values. */ public Query build() { return new Query(this); } } -} \ No newline at end of file +} diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/ScanOptions.java b/java/core/src/main/java/com/lancedb/lance/ipc/ScanOptions.java index ae1d222f8c3..e16936806cc 100644 --- a/java/core/src/main/java/com/lancedb/lance/ipc/ScanOptions.java +++ b/java/core/src/main/java/com/lancedb/lance/ipc/ScanOptions.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.ipc; import org.apache.arrow.util.Preconditions; @@ -21,9 +20,7 @@ import java.util.List; import java.util.Optional; -/** - * Lance scan options. - */ +/** Lance scan options. */ public class ScanOptions { private final Optional> fragmentIds; private final Optional batchSize; @@ -34,33 +31,42 @@ public class ScanOptions { private final Optional offset; private final Optional nearest; private final boolean withRowId; + private final boolean withRowAddress; private final int batchReadahead; + private final Optional> columnOrderings; /** * Constructor for LanceScanOptions. * - * @param fragmentIds the id of the fragments to scan - * @param batchSize Maximum row number of each returned ArrowRecordBatch. - * Optional, use Optional.empty() if unspecified. - * @param columns (Optional) Projected columns. Optional.empty() for - * scanning all columns. - * Otherwise, only columns present in the List will be - * scanned. - * @param filter (Optional) Filter expression. Optional.empty() for no - * filter. + * @param fragmentIds the id of the fragments to scan + * @param batchSize Maximum row number of each returned ArrowRecordBatch. Optional, use + * Optional.empty() if unspecified. + * @param columns (Optional) Projected columns. Optional.empty() for scanning all columns. + * Otherwise, only columns present in the List will be scanned. + * @param filter (Optional) Filter expression. Optional.empty() for no filter. * @param substraitFilter (Optional) Substrait filter expression. - * @param limit (Optional) Maximum number of rows to return. - * @param offset (Optional) Number of rows to skip before returning - * results. - * @param withRowId Whether to include the row ID in the results. - * @param nearest (Optional) Nearest neighbor query. - * @param batchReadahead Number of batches to read ahead. + * @param limit (Optional) Maximum number of rows to return. + * @param offset (Optional) Number of rows to skip before returning results. + * @param withRowId Whether to include the row ID in the results. + * @param withRowAddress Whether to include the row address in the results. + * @param nearest (Optional) Nearest neighbor query. + * @param batchReadahead Number of batches to read ahead. */ - public ScanOptions(Optional> fragmentIds, Optional batchSize, - Optional> columns, Optional filter, - Optional substraitFilter, Optional limit, - Optional offset, Optional nearest, boolean withRowId, int batchReadahead) { - Preconditions.checkArgument(!(filter.isPresent() && substraitFilter.isPresent()), + public ScanOptions( + Optional> fragmentIds, + Optional batchSize, + Optional> columns, + Optional filter, + Optional substraitFilter, + Optional limit, + Optional offset, + Optional nearest, + boolean withRowId, + boolean withRowAddress, + int batchReadahead, + Optional> columnOrderings) { + Preconditions.checkArgument( + !(filter.isPresent() && substraitFilter.isPresent()), "cannot set both substrait filter and string filter"); this.fragmentIds = fragmentIds; this.batchSize = batchSize; @@ -71,7 +77,9 @@ public ScanOptions(Optional> fragmentIds, Optional batchSize this.offset = offset; this.nearest = nearest; this.withRowId = withRowId; + this.withRowAddress = withRowAddress; this.batchReadahead = batchReadahead; + this.columnOrderings = columnOrderings; } /** @@ -113,8 +121,7 @@ public Optional getFilter() { /** * Get the substrait filter. * - * @return Optional containing the substrait filter if specified, otherwise - * empty. + * @return Optional containing the substrait filter if specified, otherwise empty. */ public Optional getSubstraitFilter() { return substraitFilter; @@ -141,8 +148,7 @@ public Optional getOffset() { /** * Get the nearest neighbor query. * - * @return Optional containing the nearest neighbor query if specified, - * otherwise empty. + * @return Optional containing the nearest neighbor query if specified, otherwise empty. */ public Optional getNearest() { return nearest; @@ -157,6 +163,15 @@ public boolean isWithRowId() { return withRowId; } + /** + * Get whether to include the row address. + * + * @return true if row address should be included, false otherwise. + */ + public boolean isWithRowAddress() { + return withRowAddress; + } + /** * Get the batch readahead. * @@ -166,6 +181,10 @@ public int getBatchReadahead() { return batchReadahead; } + public Optional> getColumnOrderings() { + return columnOrderings; + } + @Override public String toString() { return new ToStringBuilder(this) @@ -173,19 +192,20 @@ public String toString() { .append("batchSize", batchSize.orElse(null)) .append("columns", columns.orElse(null)) .append("filter", filter.orElse(null)) - .append("substraitFilter", substraitFilter - .map(buf -> "ByteBuffer[" + buf.remaining() + " bytes]").orElse(null)) + .append( + "substraitFilter", + substraitFilter.map(buf -> "ByteBuffer[" + buf.remaining() + " bytes]").orElse(null)) .append("limit", limit.orElse(null)) .append("offset", offset.orElse(null)) .append("nearest", nearest.orElse(null)) .append("withRowId", withRowId) + .append("WithRowAddress", withRowAddress) .append("batchReadahead", batchReadahead) + .append("columnOrdering", columnOrderings) .toString(); } - /** - * Builder for constructing LanceScanOptions. - */ + /** Builder for constructing LanceScanOptions. */ public static class Builder { private Optional> fragmentIds = Optional.empty(); private Optional batchSize = Optional.empty(); @@ -196,10 +216,11 @@ public static class Builder { private Optional offset = Optional.empty(); private Optional nearest = Optional.empty(); private boolean withRowId = false; + private boolean withRowAddress = false; private int batchReadahead = 16; + private Optional> columnOrderings = Optional.empty(); - public Builder() { - } + public Builder() {} /** * Create a builder from another scan options. @@ -216,7 +237,9 @@ public Builder(ScanOptions options) { this.offset = options.getOffset(); this.nearest = options.getNearest(); this.withRowId = options.isWithRowId(); + this.withRowAddress = options.isWithRowAddress(); this.batchReadahead = options.getBatchReadahead(); + this.columnOrderings = options.getColumnOrderings(); } /** @@ -318,6 +341,17 @@ public Builder withRowId(boolean withRowId) { return this; } + /** + * Set whether to include the row addr. + * + * @param withRowAddress true to include row ID, false otherwise. + * @return Builder instance for method chaining. + */ + public Builder withRowAddress(boolean withRowAddress) { + this.withRowAddress = withRowAddress; + return this; + } + /** * Set the batch readahead. * @@ -329,14 +363,30 @@ public Builder batchReadahead(int batchReadahead) { return this; } + public Builder setColumnOrderings(List columnOrderings) { + this.columnOrderings = Optional.of(columnOrderings); + return this; + } + /** * Build the LanceScanOptions instance. * * @return LanceScanOptions instance with the specified parameters. */ public ScanOptions build() { - return new ScanOptions(fragmentIds, batchSize, columns, filter, substraitFilter, - limit, offset, nearest, withRowId, batchReadahead); + return new ScanOptions( + fragmentIds, + batchSize, + columns, + filter, + substraitFilter, + limit, + offset, + nearest, + withRowId, + withRowAddress, + batchReadahead, + columnOrderings); } } } diff --git a/java/core/src/main/java/com/lancedb/lance/schema/ColumnAlteration.java b/java/core/src/main/java/com/lancedb/lance/schema/ColumnAlteration.java new file mode 100644 index 00000000000..4d58a9412b7 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/schema/ColumnAlteration.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.schema; + +import org.apache.arrow.vector.types.pojo.ArrowType; + +import java.util.Optional; + +/** Column alteration used to alter dataset columns. */ +public class ColumnAlteration { + + private String path; + private Optional rename; + private Optional nullable; + private Optional dataType; + + private ColumnAlteration(String path) { + this.path = path; + this.rename = Optional.empty(); + this.nullable = Optional.empty(); + this.dataType = Optional.empty(); + } + + public String getPath() { + return path; + } + + public Optional getRename() { + return rename; + } + + public Optional getNullable() { + return nullable; + } + + public Optional getDataType() { + return dataType; + } + + public static class Builder { + private final ColumnAlteration columnAlteration; + + public Builder(String path) { + this.columnAlteration = new ColumnAlteration(path); + } + + public Builder rename(String rename) { + this.columnAlteration.rename = Optional.of(rename); + return this; + } + + public Builder nullable(boolean nullable) { + this.columnAlteration.nullable = Optional.of(nullable); + return this; + } + + public Builder castTo(ArrowType dataType) { + this.columnAlteration.dataType = Optional.of(dataType); + return this; + } + + public ColumnAlteration build() { + return columnAlteration; + } + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/schema/SqlExpressions.java b/java/core/src/main/java/com/lancedb/lance/schema/SqlExpressions.java new file mode 100644 index 00000000000..e05ce58aa1e --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/schema/SqlExpressions.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.schema; + +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a list of SQL expressions. Each expression has a name and an expression string. Name: + * is used to refer to the new column name. Expression: SQL expression strings. These strings can + * reference existing columns in the dataset. The expression would be calculated as the value of new + * column. + */ +public class SqlExpressions { + + private final List sqlExpressions; + + private SqlExpressions(List sqlExpressions) { + this.sqlExpressions = sqlExpressions; + } + + public List getSqlExpressions() { + return sqlExpressions; + } + + public static class SqlExpression { + + private String name; + private String expression; + + public SqlExpression() {} + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getExpression() { + return expression; + } + + public void setExpression(String expression) { + this.expression = expression; + } + } + + public static class Builder { + + private final SqlExpressions sqlExpressions; + + public Builder() { + this.sqlExpressions = new SqlExpressions(new ArrayList<>()); + } + + public Builder withExpression(String name, String expr) { + SqlExpression expression = new SqlExpression(); + expression.setName(name); + expression.setExpression(expr); + this.sqlExpressions.getSqlExpressions().add(expression); + return this; + } + + public SqlExpressions build() { + return this.sqlExpressions; + } + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java index cb79624cf66..be92bf8f08a 100644 --- a/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java +++ b/java/core/src/main/java/com/lancedb/lance/test/JniTestHelper.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.test; import com.lancedb.lance.JniLoader; @@ -22,9 +21,8 @@ import java.util.Optional; /** - * Used by the JNI test to test the JNI FFI functionality. - * Note that if ffi parsing errors out, the whole JVM will crash - * or all tests will show as UnsatisfiedLinkError. + * Used by the JNI test to test the JNI FFI functionality. Note that if ffi parsing errors out, the + * whole JVM will crash or all tests will show as UnsatisfiedLinkError. */ public class JniTestHelper { static { @@ -38,6 +36,13 @@ public class JniTestHelper { */ public static native void parseInts(List intsList); + /** + * JNI parse longs test. + * + * @param longsList the given list of longs + */ + public static native void parseLongs(List longsList); + /** * JNI parse ints opts test. * diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index 5a24f9005c0..f31d2b8c4f4 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -1,32 +1,53 @@ /* - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package com.lancedb.lance; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertEquals; +import com.lancedb.lance.ipc.LanceScanner; +import com.lancedb.lance.schema.ColumnAlteration; +import com.lancedb.lance.schema.SqlExpressions; -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.file.Path; +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.channels.ClosedChannelException; +import java.nio.file.Path; +import java.util.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.*; + public class DatasetTest { - @TempDir - static Path tempDir; // Temporary directory for the tests + @TempDir static Path tempDir; // Temporary directory for the tests private static Dataset dataset; @BeforeAll @@ -75,9 +96,11 @@ void testCreateDirNotExist() throws IOException, URISyntaxException { @Test void testOpenInvalidPath() { String validPath = tempDir.resolve("Invalid_dataset").toString(); - assertThrows(RuntimeException.class, () -> { - dataset = Dataset.open(validPath, new RootAllocator(Long.MAX_VALUE)); - }); + assertThrows( + RuntimeException.class, + () -> { + dataset = Dataset.open(validPath, new RootAllocator(Long.MAX_VALUE)); + }); } @Test @@ -131,13 +154,27 @@ void testDatasetVersion() { } } + @Test + void testDatasetUri() { + String datasetPath = tempDir.resolve("dataset_uri").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + try (Dataset dataset = testDataset.createEmptyDataset()) { + assertEquals(datasetPath, dataset.uri()); + } + } + } + @Test void testOpenNonExist() throws IOException, URISyntaxException { String datasetPath = tempDir.resolve("non_exist").toString(); try (BufferAllocator allocator = new RootAllocator()) { - assertThrows(IllegalArgumentException.class, () -> { - Dataset.open(datasetPath, allocator); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + Dataset.open(datasetPath, allocator); + }); } } @@ -148,9 +185,11 @@ void testCreateExist() throws IOException, URISyntaxException { TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - assertThrows(IllegalArgumentException.class, () -> { - testDataset.createEmptyDataset(); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + testDataset.createEmptyDataset(); + }); } } @@ -164,9 +203,11 @@ void testCommitConflict() { try (Dataset dataset = testDataset.createEmptyDataset()) { assertEquals(1, dataset.version()); assertEquals(1, dataset.latestVersion()); - assertThrows(IllegalArgumentException.class, () -> { - testDataset.write(0, 5); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + testDataset.write(0, 5); + }); } } } @@ -183,4 +224,346 @@ void testGetSchemaWithClosedDataset() { assertThrows(RuntimeException.class, dataset::getSchema); } } + + @Test + void testDropColumns() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + assertEquals(testDataset.getSchema(), dataset.getSchema()); + dataset.dropColumns(Collections.singletonList("name")); + + Schema changedSchema = + new Schema( + Collections.singletonList(Field.nullable("id", new ArrowType.Int(32, true))), null); + + assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size()); + assertEquals( + changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()), + dataset.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList())); + } + } + + @Test + void testAlterColumns() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + assertEquals(testDataset.getSchema(), dataset.getSchema()); + + ColumnAlteration nameColumnAlteration = + new ColumnAlteration.Builder("name") + .rename("new_name") + .nullable(true) + .castTo(new ArrowType.Utf8()) + .build(); + + dataset.alterColumns(Collections.singletonList(nameColumnAlteration)); + + Schema changedSchema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.notNullable("new_name", new ArrowType.Utf8())), + null); + + assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size()); + assertEquals( + changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()), + dataset.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList())); + + nameColumnAlteration = + new ColumnAlteration.Builder("new_name") + .rename("new_name_2") + .castTo(new ArrowType.LargeUtf8()) + .build(); + + dataset.alterColumns(Collections.singletonList(nameColumnAlteration)); + changedSchema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.notNullable("new_name_2", new ArrowType.LargeUtf8())), + null); + + assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size()); + assertEquals( + changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()), + dataset.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList())); + + nameColumnAlteration = new ColumnAlteration.Builder("new_name_2").build(); + dataset.alterColumns(Collections.singletonList(nameColumnAlteration)); + assertNotNull(dataset.getSchema().findField("new_name_2")); + } + } + + @Test + void testAddColumnBySqlExpressions() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + SqlExpressions sqlExpressions = + new SqlExpressions.Builder().withExpression("double_id", "id * 2").build(); + dataset.addColumns(sqlExpressions, Optional.empty()); + + Schema changedSchema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()), + Field.nullable("double_id", new ArrowType.Int(32, true))), + null); + + assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size()); + assertEquals( + changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()), + dataset.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList())); + + sqlExpressions = new SqlExpressions.Builder().withExpression("triple_id", "id * 3").build(); + dataset.addColumns(sqlExpressions, Optional.empty()); + changedSchema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()), + Field.nullable("double_id", new ArrowType.Int(32, true)), + Field.nullable("triple_id", new ArrowType.Int(32, true))), + null); + assertEquals(changedSchema.getFields().size(), dataset.getSchema().getFields().size()); + assertEquals( + changedSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()), + dataset.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList())); + } + } + + @Test + void testAddColumnsByStream() throws IOException { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + + try (Dataset initialDataset = testDataset.createEmptyDataset()) { + try (Dataset datasetV1 = testDataset.write(1, 3)) { + assertEquals(3, datasetV1.countRows()); + } + } + + dataset = Dataset.open(datasetPath, allocator); + + Schema newColumnSchema = + new Schema( + Collections.singletonList(Field.nullable("age", new ArrowType.Int(32, true))), null); + + try (VectorSchemaRoot vector = VectorSchemaRoot.create(newColumnSchema, allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + + IntVector ageVector = (IntVector) vector.getVector("age"); + ageVector.allocateNew(3); + ageVector.set(0, 25); + ageVector.set(1, 30); + ageVector.set(2, 35); + vector.setRowCount(3); + + class SimpleVectorReader extends ArrowReader { + private boolean batchLoaded = false; + + protected SimpleVectorReader(BufferAllocator allocator) { + super(allocator); + } + + @Override + public boolean loadNextBatch() { + if (!batchLoaded) { + batchLoaded = true; + return true; + } + return false; + } + + @Override + public VectorSchemaRoot getVectorSchemaRoot() { + return vector; + } + + @Override + public long bytesRead() { + return vector.getFieldVectors().stream().mapToLong(FieldVector::getBufferSize).sum(); + } + + @Override + protected void closeReadSource() {} + + @Override + protected Schema readSchema() { + return newColumnSchema; + } + } + + try (ArrowReader reader = new SimpleVectorReader(allocator)) { + Data.exportArrayStream(allocator, reader, stream); + + dataset.addColumns(stream, Optional.of(3L)); + + Schema expectedSchema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()), + Field.nullable("age", new ArrowType.Int(32, true))), + null); + Schema actualSchema = dataset.getSchema(); + assertEquals(expectedSchema.getFields(), actualSchema.getFields()); + + try (LanceScanner scanner = dataset.newScan()) { + try (ArrowReader resultReader = scanner.scanBatches()) { + assertTrue(resultReader.loadNextBatch()); + VectorSchemaRoot root = resultReader.getVectorSchemaRoot(); + assertEquals(3, root.getRowCount()); + + IntVector idVector = (IntVector) root.getVector("id"); + IntVector ageVectorResult = (IntVector) root.getVector("age"); + for (int i = 0; i < 3; i++) { + assertEquals(i, idVector.get(i)); + assertEquals(25 + i * 5, ageVectorResult.get(i)); + } + } + } + } + } + } catch (Exception e) { + fail("Exception occurred during test: " + e.getMessage(), e); + } + } + + @Test + void testDropPath() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + Dataset.drop(datasetPath, new HashMap<>()); + } + } + + @Test + void testTake() throws IOException, ClosedChannelException { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + try (Dataset dataset2 = testDataset.write(1, 5)) { + List indices = Arrays.asList(1L, 4L); + List columns = Arrays.asList("id", "name"); + try (ArrowReader reader = dataset2.take(indices, columns)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + assertNotNull(result); + assertEquals(indices.size(), result.getRowCount()); + + for (int i = 0; i < indices.size(); i++) { + assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i)); + assertNotNull(result.getVector("name").getObject(i)); + } + } + } + } + } + } + + @Test + void testCountRows() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + try (Dataset dataset2 = testDataset.write(1, 5)) { + assertEquals(5, dataset2.countRows()); + // get id = 3 and 4 + assertEquals(2, dataset2.countRows("id > 2")); + + assertThrows(IllegalArgumentException.class, () -> dataset2.countRows(null)); + assertThrows(IllegalArgumentException.class, () -> dataset2.countRows("")); + } + } + } + + @Test + void testCalculateDataSize() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + try (Dataset dataset2 = testDataset.write(1, 5)) { + assertEquals(100, dataset2.calculateDataSize()); + } + } + } + + @Test + void testDeleteRows() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + + try (Dataset dataset2 = testDataset.write(1, 5)) { + // Initially there are 5 rows + assertEquals(5, dataset2.countRows()); + + // Delete rows where id > 2 (should delete id=3, id=4) + dataset2.delete("id > 2"); + + // Now verify we have 3 rows left (id=0, id=1, id=2) + assertEquals(3, dataset2.countRows()); + + // Verify the rows that remain + assertEquals(0, dataset2.countRows("id > 2")); + assertEquals(3, dataset2.countRows("id <= 2")); + + // Delete another row + dataset2.delete("id = 1"); + + // Now verify we have 2 rows left (id=0, id=2) + assertEquals(2, dataset2.countRows()); + assertEquals(1, dataset2.countRows("id = 0")); + assertEquals(1, dataset2.countRows("id = 2")); + assertEquals(0, dataset2.countRows("id = 1")); + } + } + } } diff --git a/java/core/src/test/java/com/lancedb/lance/FileReaderWriterTest.java b/java/core/src/test/java/com/lancedb/lance/FileReaderWriterTest.java new file mode 100644 index 00000000000..4b19565de82 --- /dev/null +++ b/java/core/src/test/java/com/lancedb/lance/FileReaderWriterTest.java @@ -0,0 +1,166 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance; + +import com.lancedb.lance.file.LanceFileReader; +import com.lancedb.lance.file.LanceFileWriter; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class FileReaderWriterTest { + + @TempDir private static Path tempDir; + + private VectorSchemaRoot createBatch(BufferAllocator allocator) throws IOException { + Schema schema = + new Schema( + Arrays.asList( + Field.nullable("x", new ArrowType.Int(64, true)), + Field.nullable("y", new ArrowType.Utf8())), + null); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + BigIntVector iVector = (BigIntVector) root.getVector("x"); + VarCharVector sVector = (VarCharVector) root.getVector("y"); + + for (int i = 0; i < 100; i++) { + iVector.setSafe(i, i); + sVector.setSafe(i, new Text("s-" + i)); + } + root.setRowCount(100); + + return root; + } + + void createSimpleFile(String filePath) throws Exception { + BufferAllocator allocator = new RootAllocator(); + try (LanceFileWriter writer = LanceFileWriter.open(filePath, allocator, null)) { + try (VectorSchemaRoot batch = createBatch(allocator)) { + writer.write(batch); + } + } + } + + @Test + void testBasicRead() throws Exception { + BufferAllocator allocator = new RootAllocator(); + String filePath = tempDir.resolve("basic_read.lance").toString(); + createSimpleFile(filePath); + LanceFileReader reader = LanceFileReader.open(filePath, allocator); + + Schema expectedSchema = + new Schema( + Arrays.asList( + Field.nullable("x", new ArrowType.Int(64, true)), + Field.nullable("y", new ArrowType.Utf8())), + null); + + assertEquals(100, reader.numRows()); + assertEquals(expectedSchema, reader.schema()); + + try (ArrowReader batches = reader.readAll(100)) { + assertTrue(batches.loadNextBatch()); + VectorSchemaRoot batch = batches.getVectorSchemaRoot(); + assertEquals(100, batch.getRowCount()); + assertEquals(2, batch.getSchema().getFields().size()); + assertFalse(batches.loadNextBatch()); + } + + try (ArrowReader batches = reader.readAll(15)) { + for (int i = 0; i < 100; i += 15) { + int expected = Math.min(15, 100 - i); + assertTrue(batches.loadNextBatch()); + VectorSchemaRoot batch = batches.getVectorSchemaRoot(); + assertEquals(expected, batch.getRowCount()); + assertEquals(2, batch.getSchema().getFields().size()); + } + assertFalse(batches.loadNextBatch()); + } + + reader.close(); + try { + reader.numRows(); + fail("Expected LanceException to be thrown"); + } catch (IOException e) { + assertEquals("FileReader has already been closed", e.getMessage()); + } + + // Ok to call schema after close + assertEquals(expectedSchema, reader.schema()); + + // close should be idempotent + reader.close(); + } + + @Test + void testBasicWrite() throws Exception { + String filePath = tempDir.resolve("basic_write.lance").toString(); + createSimpleFile(filePath); + } + + @Test + void testWriteNoData() throws Exception { + String filePath = tempDir.resolve("no_data.lance").toString(); + BufferAllocator allocator = new RootAllocator(); + + LanceFileWriter writer = LanceFileWriter.open(filePath, allocator, null); + + try { + writer.close(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("no data provided")); + } + } + + @Test + void testInvalidPath() { + BufferAllocator allocator = new RootAllocator(); + try { + LanceFileReader.open("/tmp/does_not_exist.lance", allocator); + fail("Expected LanceException to be thrown"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("Not found: tmp/does_not_exist.lance")); + } + try { + LanceFileReader.open("", allocator); + fail("Expected LanceException to be thrown"); + } catch (RuntimeException e) { + // expected, would be nice if it was an IOException, but it's not because + // lance throws a wrapped error :( + } catch (IOException e) { + fail("Expected RuntimeException to be thrown"); + } + } +} diff --git a/java/core/src/test/java/com/lancedb/lance/FilterTest.java b/java/core/src/test/java/com/lancedb/lance/FilterTest.java index 91dfa1cfcb2..c7cd52f17c5 100644 --- a/java/core/src/test/java/com/lancedb/lance/FilterTest.java +++ b/java/core/src/test/java/com/lancedb/lance/FilterTest.java @@ -11,14 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import java.io.IOException; -import java.nio.file.Path; - import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.junit.jupiter.api.AfterAll; @@ -26,11 +23,14 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; + import static org.junit.jupiter.api.Assertions.assertEquals; public class FilterTest { - @TempDir - static Path tempDir; + @TempDir static Path tempDir; private static BufferAllocator allocator; private static Dataset dataset; @@ -38,7 +38,8 @@ public class FilterTest { static void setup() throws IOException { String datasetPath = tempDir.resolve("filter_test_dataset").toString(); allocator = new RootAllocator(); - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); // write id with value from 0 to 39 dataset = testDataset.write(1, 40); @@ -92,7 +93,8 @@ void testFilters() throws Exception { testFilter("(name IS NOT NULL) AND (name == 'Person 1')", 1); testFilter("(name IS NOT NULL) AND (name == 'Person')", 0); - // Not supported, bug?, LanceError(IO): Schema error: No field named person. Valid fields are id, name. + // Not supported, bug?, LanceError(IO): Schema error: No field named person. Valid fields are + // id, name. // testFilter("(name IS NOT NULL) AND (name == Person)", 0); // Not supported @@ -101,7 +103,13 @@ void testFilters() throws Exception { } private void testFilter(String filter, int expectedCount) throws Exception { - try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().filter(filter).build())) { + try (LanceScanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .columns(Arrays.asList()) + .withRowId(true) + .filter(filter) + .build())) { assertEquals(expectedCount, scanner.countRows()); } } diff --git a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java index a9fbe6c0173..0bdf8ba1cb5 100644 --- a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java +++ b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java @@ -11,23 +11,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - import com.lancedb.lance.ipc.LanceScanner; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Optional; + import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class FragmentTest { @TempDir private static Path tempDir; // Temporary directory for the tests @@ -35,9 +37,10 @@ public class FragmentTest { void testFragmentCreateFfiArray() { String datasetPath = tempDir.resolve("new_fragment_array").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - testDataset.createNewFragment(123, 20); + testDataset.createNewFragment(20); } } @@ -45,11 +48,11 @@ void testFragmentCreateFfiArray() { void testFragmentCreate() throws Exception { String datasetPath = tempDir.resolve("new_fragment").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - int fragmentId = 312; int rowCount = 21; - FragmentMetadata fragmentMeta = testDataset.createNewFragment(fragmentId, rowCount); + FragmentMetadata fragmentMeta = testDataset.createNewFragment(rowCount); // Commit fragment FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(fragmentMeta)); @@ -57,9 +60,8 @@ void testFragmentCreate() throws Exception { assertEquals(2, dataset.version()); assertEquals(2, dataset.latestVersion()); assertEquals(rowCount, dataset.countRows()); - DatasetFragment fragment = dataset.getFragments().get(0); - assertEquals(fragmentId, fragment.getId()); - + Fragment fragment = dataset.getFragments().get(0); + try (LanceScanner scanner = fragment.newScan()) { Schema schemaRes = scanner.schema(); assertEquals(testDataset.getSchema(), schemaRes); @@ -72,13 +74,16 @@ void testFragmentCreate() throws Exception { void commitWithoutVersion() { String datasetPath = tempDir.resolve("commit_without_version").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - FragmentMetadata meta = testDataset.createNewFragment(123, 20); + FragmentMetadata meta = testDataset.createNewFragment(20); FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta)); - assertThrows(IllegalArgumentException.class, () -> { - Dataset.commit(allocator, datasetPath, appendOp, Optional.empty()); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + Dataset.commit(allocator, datasetPath, appendOp, Optional.empty()); + }); } } @@ -86,13 +91,16 @@ void commitWithoutVersion() { void commitOldVersion() { String datasetPath = tempDir.resolve("commit_old_version").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - FragmentMetadata meta = testDataset.createNewFragment(123, 20); + FragmentMetadata meta = testDataset.createNewFragment(20); FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(meta)); - assertThrows(IllegalArgumentException.class, () -> { - Dataset.commit(allocator, datasetPath, appendOp, Optional.of(0L)); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + Dataset.commit(allocator, datasetPath, appendOp, Optional.of(0L)); + }); } } @@ -100,11 +108,84 @@ void commitOldVersion() { void appendWithoutFragment() { String datasetPath = tempDir.resolve("append_without_fragment").toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + assertThrows( + IllegalArgumentException.class, + () -> { + new FragmentOperation.Append(new ArrayList<>()); + }); + } + } + + @Test + void testOverwriteCommit() throws Exception { + String datasetPath = tempDir.resolve("testOverwriteCommit").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + // Commit fragment + int rowCount = 20; + FragmentMetadata fragmentMeta = testDataset.createNewFragment(rowCount); + FragmentOperation.Overwrite overwrite = + new FragmentOperation.Overwrite( + Collections.singletonList(fragmentMeta), testDataset.getSchema()); + try (Dataset dataset = Dataset.commit(allocator, datasetPath, overwrite, Optional.of(1L))) { + assertEquals(2, dataset.version()); + assertEquals(2, dataset.latestVersion()); + assertEquals(rowCount, dataset.countRows()); + Fragment fragment = dataset.getFragments().get(0); + + try (LanceScanner scanner = fragment.newScan()) { + Schema schemaRes = scanner.schema(); + assertEquals(testDataset.getSchema(), schemaRes); + } + } + + // Commit fragment again + rowCount = 40; + fragmentMeta = testDataset.createNewFragment(rowCount); + overwrite = + new FragmentOperation.Overwrite( + Collections.singletonList(fragmentMeta), testDataset.getSchema()); + try (Dataset dataset = Dataset.commit(allocator, datasetPath, overwrite, Optional.of(2L))) { + assertEquals(3, dataset.version()); + assertEquals(3, dataset.latestVersion()); + assertEquals(rowCount, dataset.countRows()); + Fragment fragment = dataset.getFragments().get(0); + + try (LanceScanner scanner = fragment.newScan()) { + Schema schemaRes = scanner.schema(); + assertEquals(testDataset.getSchema(), schemaRes); + } + } + } + } + + @Test + void testEmptyFragments() { + String datasetPath = tempDir.resolve("testEmptyFragments").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + List fragments = testDataset.createNewFragment(0, 10); + assertEquals(0, fragments.size()); + } + } + + @Test + void testMultiFragments() { + String datasetPath = tempDir.resolve("testMultiFragments").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - assertThrows(IllegalArgumentException.class, () -> { - new FragmentOperation.Append(new ArrayList<>()); - }); + List fragments = testDataset.createNewFragment(20, 10); + assertEquals(2, fragments.size()); } } } diff --git a/java/core/src/test/java/com/lancedb/lance/JNITest.java b/java/core/src/test/java/com/lancedb/lance/JNITest.java index 60b9731a7e3..afae110e54d 100644 --- a/java/core/src/test/java/com/lancedb/lance/JNITest.java +++ b/java/core/src/test/java/com/lancedb/lance/JNITest.java @@ -11,17 +11,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import java.util.Arrays; -import java.util.Optional; - -import org.junit.jupiter.api.Test; - -import com.lancedb.lance.test.JniTestHelper; import com.lancedb.lance.index.DistanceType; import com.lancedb.lance.index.IndexParams; import com.lancedb.lance.index.vector.HnswBuildParams; @@ -30,6 +21,14 @@ import com.lancedb.lance.index.vector.SQBuildParams; import com.lancedb.lance.index.vector.VectorIndexParams; import com.lancedb.lance.ipc.Query; +import com.lancedb.lance.test.JniTestHelper; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertThrows; public class JNITest { @Test @@ -37,6 +36,11 @@ public void testInts() { JniTestHelper.parseInts(Arrays.asList(1, 2, 3)); } + @Test + public void testLongs() { + JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE)); + } + @Test public void testIntsOpt() { JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3))); @@ -44,96 +48,95 @@ public void testIntsOpt() { @Test public void testQuery() { - JniTestHelper.parseQuery(Optional.of(new Query.Builder() - .setColumn("column") - .setKey(new float[] { 1.0f, 2.0f, 3.0f }) - .setK(10) - .setNprobes(20) - .setEf(30) - .setRefineFactor(40) - .setDistanceType(DistanceType.L2) - .setUseIndex(true) - .build())); + JniTestHelper.parseQuery( + Optional.of( + new Query.Builder() + .setColumn("column") + .setKey(new float[] {1.0f, 2.0f, 3.0f}) + .setK(10) + .setNprobes(20) + .setEf(30) + .setRefineFactor(40) + .setDistanceType(DistanceType.L2) + .setUseIndex(true) + .build())); } @Test public void testIvfFlatIndexParams() { - JniTestHelper.parseIndexParams(new IndexParams.Builder() - .setVectorIndexParams( - VectorIndexParams.ivfFlat(10, DistanceType.L2)) - .build()); + JniTestHelper.parseIndexParams( + new IndexParams.Builder() + .setVectorIndexParams(VectorIndexParams.ivfFlat(10, DistanceType.L2)) + .build()); } @Test public void testIvfPqIndexParams() { - JniTestHelper.parseIndexParams(new IndexParams.Builder() - .setVectorIndexParams( - VectorIndexParams.ivfPq(10, 8, 4, DistanceType.L2, 50)) - .build()); + JniTestHelper.parseIndexParams( + new IndexParams.Builder() + .setVectorIndexParams(VectorIndexParams.ivfPq(10, 8, 4, DistanceType.L2, 50)) + .build()); } @Test public void testIvfPqWithCustomParamsIndexParams() { - IvfBuildParams ivf = new IvfBuildParams.Builder() - .setNumPartitions(20) - .setMaxIters(100) - .setSampleRate(512) - .build(); - PQBuildParams pq = new PQBuildParams.Builder() - .setNumSubVectors(8) - .setNumBits(8) - .setMaxIters(100) - .setKmeansRedos(3) - .setSampleRate(1024) - .build(); - - JniTestHelper.parseIndexParams(new IndexParams.Builder() - .setVectorIndexParams( - VectorIndexParams.withIvfPqParams(DistanceType.Cosine, ivf, pq)) - .build()); + IvfBuildParams ivf = + new IvfBuildParams.Builder() + .setNumPartitions(20) + .setMaxIters(100) + .setSampleRate(512) + .build(); + PQBuildParams pq = + new PQBuildParams.Builder() + .setNumSubVectors(8) + .setNumBits(8) + .setMaxIters(100) + .setKmeansRedos(3) + .setSampleRate(1024) + .build(); + + JniTestHelper.parseIndexParams( + new IndexParams.Builder() + .setVectorIndexParams(VectorIndexParams.withIvfPqParams(DistanceType.Cosine, ivf, pq)) + .build()); } @Test public void testIvfHnswPqIndexParams() { - IvfBuildParams ivf = new IvfBuildParams.Builder() - .setNumPartitions(15) - .build(); - HnswBuildParams hnsw = new HnswBuildParams.Builder() - .setMaxLevel((short) 10) - .setM(30) - .setEfConstruction(200) - .setPrefetchDistance(3) - .build(); - PQBuildParams pq = new PQBuildParams.Builder() - .setNumSubVectors(16) - .setNumBits(8) - .build(); - - JniTestHelper.parseIndexParams(new IndexParams.Builder() - .setVectorIndexParams( - VectorIndexParams.withIvfHnswPqParams(DistanceType.L2, ivf, hnsw, pq)) - .build()); + IvfBuildParams ivf = new IvfBuildParams.Builder().setNumPartitions(15).build(); + HnswBuildParams hnsw = + new HnswBuildParams.Builder() + .setMaxLevel((short) 10) + .setM(30) + .setEfConstruction(200) + .setPrefetchDistance(3) + .build(); + PQBuildParams pq = new PQBuildParams.Builder().setNumSubVectors(16).setNumBits(8).build(); + + JniTestHelper.parseIndexParams( + new IndexParams.Builder() + .setVectorIndexParams( + VectorIndexParams.withIvfHnswPqParams(DistanceType.L2, ivf, hnsw, pq)) + .build()); } @Test public void testIvfHnswSqIndexParams() { - IvfBuildParams ivf = new IvfBuildParams.Builder() - .setNumPartitions(25) - .build(); - HnswBuildParams hnsw = new HnswBuildParams.Builder() - .setMaxLevel((short) 8) - .setM(25) - .setEfConstruction(175) - .build(); - SQBuildParams sq = new SQBuildParams.Builder() - .setNumBits((short) 16) - .setSampleRate(512) - .build(); - - JniTestHelper.parseIndexParams(new IndexParams.Builder() - .setVectorIndexParams( - VectorIndexParams.withIvfHnswSqParams(DistanceType.Dot, ivf, hnsw, sq)) - .build()); + IvfBuildParams ivf = new IvfBuildParams.Builder().setNumPartitions(25).build(); + HnswBuildParams hnsw = + new HnswBuildParams.Builder() + .setMaxLevel((short) 8) + .setM(25) + .setEfConstruction(175) + .build(); + SQBuildParams sq = + new SQBuildParams.Builder().setNumBits((short) 16).setSampleRate(512).build(); + + JniTestHelper.parseIndexParams( + new IndexParams.Builder() + .setVectorIndexParams( + VectorIndexParams.withIvfHnswSqParams(DistanceType.Dot, ivf, hnsw, sq)) + .build()); } @Test @@ -142,13 +145,15 @@ public void testInvalidCombinationPqAndSq() { PQBuildParams pq = new PQBuildParams.Builder().build(); SQBuildParams sq = new SQBuildParams.Builder().build(); - assertThrows(IllegalArgumentException.class, () -> { - new VectorIndexParams.Builder(ivf) - .setDistanceType(DistanceType.L2) - .setPqParams(pq) - .setSqParams(sq) - .build(); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + new VectorIndexParams.Builder(ivf) + .setDistanceType(DistanceType.L2) + .setPqParams(pq) + .setSqParams(sq) + .build(); + }); } @Test @@ -156,12 +161,14 @@ public void testInvalidCombinationHnswWithoutPqOrSq() { IvfBuildParams ivf = new IvfBuildParams.Builder().setNumPartitions(10).build(); HnswBuildParams hnsw = new HnswBuildParams.Builder().build(); - assertThrows(IllegalArgumentException.class, () -> { - new VectorIndexParams.Builder(ivf) - .setDistanceType(DistanceType.L2) - .setHnswParams(hnsw) - .build(); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + new VectorIndexParams.Builder(ivf) + .setDistanceType(DistanceType.L2) + .setHnswParams(hnsw) + .build(); + }); } @Test @@ -169,11 +176,13 @@ public void testInvalidCombinationSqWithoutHnsw() { IvfBuildParams ivf = new IvfBuildParams.Builder().setNumPartitions(10).build(); SQBuildParams sq = new SQBuildParams.Builder().build(); - assertThrows(IllegalArgumentException.class, () -> { - new VectorIndexParams.Builder(ivf) - .setDistanceType(DistanceType.L2) - .setSqParams(sq) - .build(); - }); + assertThrows( + IllegalArgumentException.class, + () -> { + new VectorIndexParams.Builder(ivf) + .setDistanceType(DistanceType.L2) + .setSqParams(sq) + .build(); + }); } } diff --git a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java index fc46a95c52c..d575a7b7dbe 100644 --- a/java/core/src/test/java/com/lancedb/lance/ScannerTest.java +++ b/java/core/src/test/java/com/lancedb/lance/ScannerTest.java @@ -11,23 +11,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import java.io.IOException; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - +import com.lancedb.lance.ipc.ColumnOrdering; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; + import org.apache.arrow.dataset.scanner.Scanner; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -38,13 +33,19 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; public class ScannerTest { - @TempDir - static Path tempDir; // Temporary directory for the tests + @TempDir static Path tempDir; // Temporary directory for the tests private static Dataset dataset; @BeforeAll @@ -62,7 +63,8 @@ static void tearDown() { void testDatasetScanner() throws IOException { String datasetPath = tempDir.resolve("dataset_scanner").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 40; int batchRows = 20; @@ -77,11 +79,13 @@ void testDatasetScanner() throws IOException { void testDatasetScannerFilter() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_filter").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); // write id with value from 0 to 39 try (Dataset dataset = testDataset.write(1, 40)) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) { + try (Scanner scanner = + dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) { testDataset.validateScanResults(dataset, scanner, 20, 20); } } @@ -92,13 +96,18 @@ void testDatasetScannerFilter() throws Exception { void testDatasetScannerColumns() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_columns").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 40; int batchRows = 20; try (Dataset dataset = testDataset.write(1, totalRows)) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder() - .batchSize(batchRows).columns(Arrays.asList("id")).build())) { + try (Scanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .batchSize(batchRows) + .columns(Arrays.asList("id")) + .build())) { try (ArrowReader reader = scanner.scanBatches()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); int index = 0; @@ -124,15 +133,19 @@ void testDatasetScannerColumns() throws Exception { void testDatasetScannerSchema() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_schema").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 40; try (Dataset dataset = testDataset.write(1, totalRows)) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder() - .batchSize(totalRows).columns(Arrays.asList("id")).build())) { - Schema expectedSchema = new Schema(Arrays.asList( - Field.nullable("id", new ArrowType.Int(32, true)) - )); + try (Scanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .batchSize(totalRows) + .columns(Arrays.asList("id")) + .build())) { + Schema expectedSchema = + new Schema(Arrays.asList(Field.nullable("id", new ArrowType.Int(32, true)))); assertEquals(expectedSchema, scanner.schema()); } } @@ -143,11 +156,18 @@ void testDatasetScannerSchema() throws Exception { void testDatasetScannerCountRows() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_count").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); // write id with value from 0 to 39 try (Dataset dataset = testDataset.write(1, 40)) { - try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder().filter("id < 20").build())) { + try (LanceScanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .columns(Arrays.asList()) + .withRowId(true) + .filter("id < 20") + .build())) { assertEquals(20, scanner.countRows()); } } @@ -158,12 +178,13 @@ void testDatasetScannerCountRows() throws Exception { void testFragmentScanner() throws Exception { String datasetPath = tempDir.resolve("fragment_scanner").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 40; int batchRows = 20; try (Dataset dataset = testDataset.write(1, totalRows)) { - DatasetFragment fragment = dataset.getFragments().get(0); + Fragment fragment = dataset.getFragments().get(0); try (Scanner scanner = fragment.newScan(batchRows)) { testDataset.validateScanResults(dataset, scanner, totalRows, batchRows); } @@ -175,12 +196,14 @@ void testFragmentScanner() throws Exception { void testFragmentScannerFilter() throws Exception { String datasetPath = tempDir.resolve("fragment_scanner_filter").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); // write id with value from 0 to 39 try (Dataset dataset = testDataset.write(1, 40)) { - DatasetFragment fragment = dataset.getFragments().get(0); - try (Scanner scanner = fragment.newScan(new ScanOptions.Builder().filter("id < 20").build())) { + Fragment fragment = dataset.getFragments().get(0); + try (Scanner scanner = + fragment.newScan(new ScanOptions.Builder().filter("id < 20").build())) { testDataset.validateScanResults(dataset, scanner, 20, 20); } } @@ -191,13 +214,19 @@ void testFragmentScannerFilter() throws Exception { void testFragmentScannerColumns() throws Exception { String datasetPath = tempDir.resolve("fragment_scanner_columns").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 40; int batchRows = 20; try (Dataset dataset = testDataset.write(1, totalRows)) { - DatasetFragment fragment = dataset.getFragments().get(0); - try (Scanner scanner = fragment.newScan(new ScanOptions.Builder().batchSize(batchRows).columns(Arrays.asList("id")).build())) { + Fragment fragment = dataset.getFragments().get(0); + try (Scanner scanner = + fragment.newScan( + new ScanOptions.Builder() + .batchSize(batchRows) + .columns(Arrays.asList("id")) + .build())) { try (ArrowReader reader = scanner.scanBatches()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); int index = 0; @@ -223,19 +252,20 @@ void testFragmentScannerColumns() throws Exception { void testScanFragment() throws Exception { String datasetPath = tempDir.resolve("fragment_scanner_single_fragment").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - int[] fragment0 = new int[]{0, 3}; - int[] fragment1 = new int[]{1, 5}; - int[] fragment2 = new int[]{2, 7}; - FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]); - FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]); - FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]); - FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2)); + FragmentMetadata metadata0 = testDataset.createNewFragment(3); + FragmentMetadata metadata1 = testDataset.createNewFragment(5); + FragmentMetadata metadata2 = testDataset.createNewFragment(7); + FragmentOperation.Append appendOp = + new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2)); try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) { - validScanResult(dataset, fragment0[0], fragment0[1]); - validScanResult(dataset, fragment1[0], fragment1[1]); - validScanResult(dataset, fragment2[0], fragment2[1]); + List frags = dataset.getFragments(); + assertEquals(3, frags.size()); + validScanResult(dataset, frags.get(0).getId(), 3); + validScanResult(dataset, frags.get(1).getId(), 5); + validScanResult(dataset, frags.get(2).getId(), 7); } } } @@ -244,19 +274,27 @@ void testScanFragment() throws Exception { void testScanFragments() throws Exception { String datasetPath = tempDir.resolve("fragments_scanner").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); - int[] fragment0 = new int[]{0, 3}; - int[] fragment1 = new int[]{1, 5}; - int[] fragment2 = new int[]{2, 7}; - FragmentMetadata metadata0 = testDataset.createNewFragment(fragment0[0], fragment0[1]); - FragmentMetadata metadata1 = testDataset.createNewFragment(fragment1[0], fragment1[1]); - FragmentMetadata metadata2 = testDataset.createNewFragment(fragment2[0], fragment2[1]); - FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2)); + FragmentMetadata metadata0 = testDataset.createNewFragment(3); + FragmentMetadata metadata1 = testDataset.createNewFragment(5); + FragmentMetadata metadata2 = testDataset.createNewFragment(7); + FragmentOperation.Append appendOp = + new FragmentOperation.Append(Arrays.asList(metadata0, metadata1, metadata2)); try (Dataset dataset = Dataset.commit(allocator, datasetPath, appendOp, Optional.of(1L))) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(1, 2)).build())) { + List frags = dataset.getFragments(); + assertEquals(3, frags.size()); + try (Scanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .batchSize(1024) + .fragmentIds(Arrays.asList(frags.get(1).getId(), frags.get(2).getId())) + .build())) { try (ArrowReader reader = scanner.scanBatches()) { - assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + assertEquals( + dataset.getSchema().getFields(), + reader.getVectorSchemaRoot().getSchema().getFields()); int rowcount = 0; reader.loadNextBatch(); int currentRowCount = reader.getVectorSchemaRoot().getRowCount(); @@ -277,7 +315,8 @@ void testScanFragments() throws Exception { void testDatasetScannerLimit() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_limit").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 100; int limit = 50; @@ -293,13 +332,15 @@ void testDatasetScannerLimit() throws Exception { void testDatasetScannerOffset() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_offset").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 100; int offset = 50; try (Dataset dataset = testDataset.write(1, totalRows)) { try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().offset(offset).build())) { - testDataset.validateScanResults(dataset, scanner, totalRows - offset, totalRows - offset, offset); + testDataset.validateScanResults( + dataset, scanner, totalRows - offset, totalRows - offset, offset); } } } @@ -307,44 +348,50 @@ void testDatasetScannerOffset() throws Exception { @Test void testDatasetScannerWithRowId() throws Exception { - String datasetPath = tempDir.resolve("dataset_scanner_with_row_id").toString(); - try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); - testDataset.createEmptyDataset().close(); - int totalRows = 50; - try (Dataset dataset = testDataset.write(1, totalRows)) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().withRowId(true).build())) { - try (ArrowReader reader = scanner.scanBatches()) { - VectorSchemaRoot root = reader.getVectorSchemaRoot(); - assertTrue(root.getSchema().getFields().stream().anyMatch(field -> field.getName().equals("_rowid"))); - while (reader.loadNextBatch()) { - List fieldVectors = root.getFieldVectors(); - assertTrue(fieldVectors.stream().anyMatch(vector -> vector.getName().equals("_rowid"))); - } - } - } + String datasetPath = tempDir.resolve("dataset_scanner_with_row_id").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 50; + try (Dataset dataset = testDataset.write(1, totalRows)) { + try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().withRowId(true).build())) { + try (ArrowReader reader = scanner.scanBatches()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertTrue( + root.getSchema().getFields().stream() + .anyMatch(field -> field.getName().equals("_rowid"))); + while (reader.loadNextBatch()) { + List fieldVectors = root.getFieldVectors(); + assertTrue( + fieldVectors.stream().anyMatch(vector -> vector.getName().equals("_rowid"))); + } } + } } + } } @Test void testDatasetScannerBatchReadahead() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_batch_readahead").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 1000; int batchSize = 100; int batchReadahead = 5; try (Dataset dataset = testDataset.write(1, totalRows)) { - try (LanceScanner scanner = dataset.newScan(new ScanOptions.Builder() - .batchSize(batchSize) - .batchReadahead(batchReadahead) - .build())) { + try (LanceScanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .batchSize(batchSize) + .batchReadahead(batchReadahead) + .build())) { // This test is more about ensuring that the batchReadahead parameter is accepted // and doesn't cause errors. The actual effect of batchReadahead might not be // directly observable in this test. - assertEquals(totalRows, scanner.countRows()); try (ArrowReader reader = scanner.scanBatches()) { int rowCount = 0; while (reader.loadNextBatch()) { @@ -357,29 +404,106 @@ void testDatasetScannerBatchReadahead() throws Exception { } } + @Test + void testDatasetScannerSortBy() throws Exception { + String datasetPath = tempDir.resolve("testDatasetScannerSortBy").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + try (Dataset dataset = testDataset.writeSortByDataset(1)) { + ColumnOrdering.Builder nameBuilder = new ColumnOrdering.Builder(); + nameBuilder.setColumnName("name"); + nameBuilder.setAscending(true); + nameBuilder.setNullFirst(false); + + ColumnOrdering.Builder idBuilder = new ColumnOrdering.Builder(); + idBuilder.setColumnName("id"); + idBuilder.setAscending(false); + idBuilder.setNullFirst(true); + + List columnOrderings = + Arrays.asList(nameBuilder.build(), idBuilder.build()); + ScanOptions.Builder scanOptionBuilder = new ScanOptions.Builder(); + scanOptionBuilder + .columns(Arrays.asList("name", "id")) + .limit(10) + .setColumnOrderings(columnOrderings); + ScanOptions scanOptions = scanOptionBuilder.build(); + try (Scanner scanner = dataset.newScan(scanOptions)) { + try (ArrowReader reader = scanner.scanBatches()) { + while (reader.loadNextBatch()) { + List fieldVectors = reader.getVectorSchemaRoot().getFieldVectors(); + VarCharVector nameVector = (VarCharVector) fieldVectors.get(0); + /* dataset context + * i: | id | name | :i + * 1: | 1 | P0 | :0 + * 2: | null | P1 | :1 + * 3: | 2 | P2 | :2 + * 5: | null | P3 | :3 + * 4: | 2 | P3 | :4 + * 7: | 4 | P4 | :5 + * 9: | 5 | P5 | :6 + * 8: | 4 | P5 | :7 + * 6: | 3 | null | :8 + * 0: | 0 | null | :9 + */ + assertEquals("P0", new String(nameVector.get(0))); + assertEquals("P1", new String(nameVector.get(1))); + assertEquals("P2", new String(nameVector.get(2))); + assertEquals("P3", new String(nameVector.get(3))); + assertEquals("P3", new String(nameVector.get(4))); + assertEquals("P4", new String(nameVector.get(5))); + assertEquals("P5", new String(nameVector.get(6))); + assertEquals("P5", new String(nameVector.get(7))); + assertTrue(nameVector.isNull(8)); + assertTrue(nameVector.isNull(9)); + + IntVector idVector = (IntVector) fieldVectors.get(1); + assertEquals(1, idVector.get(0)); + assertTrue(idVector.isNull(1)); + assertEquals(2, idVector.get(2)); + assertTrue(idVector.isNull(3)); + assertEquals(2, idVector.get(4)); + assertEquals(4, idVector.get(5)); + assertEquals(5, idVector.get(6)); + assertEquals(4, idVector.get(7)); + assertEquals(3, idVector.get(8)); + assertEquals(0, idVector.get(9)); + } + } + } + } + } + } + @Test void testDatasetScannerCombinedParams() throws Exception { String datasetPath = tempDir.resolve("dataset_scanner_combined_params").toString(); try (BufferAllocator allocator = new RootAllocator()) { - TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); testDataset.createEmptyDataset().close(); int totalRows = 600; int limit = 200; int offset = 300; int batchSize = 50; try (Dataset dataset = testDataset.write(1, totalRows)) { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder() - .limit(limit) - .offset(offset) - .withRowId(true) - .batchSize(batchSize) - .batchReadahead(3) - .build())) { + try (Scanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .limit(limit) + .offset(offset) + .withRowId(true) + .batchSize(batchSize) + .batchReadahead(3) + .build())) { try (ArrowReader reader = scanner.scanBatches()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); - List fieldNames = root.getSchema().getFields().stream() - .map(Field::getName) - .collect(Collectors.toList()); + List fieldNames = + root.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.toList()); assertTrue(fieldNames.contains("_rowid"), "Schema should contain _rowid column"); assertTrue(fieldNames.contains("id"), "Schema should contain id column"); @@ -387,7 +511,8 @@ void testDatasetScannerCombinedParams() throws Exception { int expectedIdStart = offset; while (reader.loadNextBatch()) { List fieldVectors = root.getFieldVectors(); - assertTrue(fieldVectors.stream().anyMatch(vector -> vector.getName().equals("_rowid"))); + assertTrue( + fieldVectors.stream().anyMatch(vector -> vector.getName().equals("_rowid"))); IntVector idVector = (IntVector) root.getVector("id"); int batchRowCount = root.getRowCount(); rowCount += batchRowCount; @@ -395,9 +520,15 @@ void testDatasetScannerCombinedParams() throws Exception { for (int i = 0; i < batchRowCount; i++) { int expectedId = expectedIdStart + i; - assertEquals(expectedId, idVector.get(i), - "Mismatch at row " + (rowCount - batchRowCount + i) + - ". Expected: " + expectedId + ", Actual: " + idVector.get(i)); + assertEquals( + expectedId, + idVector.get(i), + "Mismatch at row " + + (rowCount - batchRowCount + i) + + ". Expected: " + + expectedId + + ", Actual: " + + idVector.get(i)); } expectedIdStart += batchRowCount; } @@ -409,9 +540,15 @@ void testDatasetScannerCombinedParams() throws Exception { } private void validScanResult(Dataset dataset, int fragmentId, int rowCount) throws Exception { - try (Scanner scanner = dataset.newScan(new ScanOptions.Builder().batchSize(1024).fragmentIds(Arrays.asList(fragmentId)).build())) { + try (Scanner scanner = + dataset.newScan( + new ScanOptions.Builder() + .batchSize(1024) + .fragmentIds(Arrays.asList(fragmentId)) + .build())) { try (ArrowReader reader = scanner.scanBatches()) { - assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + assertEquals( + dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); reader.loadNextBatch(); assertEquals(rowCount, reader.getVectorSchemaRoot().getRowCount()); assertFalse(reader.loadNextBatch()); diff --git a/java/core/src/test/java/com/lancedb/lance/TestUtils.java b/java/core/src/test/java/com/lancedb/lance/TestUtils.java index 461adc47674..dac8db2b628 100644 --- a/java/core/src/test/java/com/lancedb/lance/TestUtils.java +++ b/java/core/src/test/java/com/lancedb/lance/TestUtils.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; import org.apache.arrow.c.ArrowArrayStream; @@ -48,10 +47,12 @@ public class TestUtils { public static class SimpleTestDataset { - private final Schema schema = new Schema(Arrays.asList( - Field.nullable("id", new ArrowType.Int(32, true)), - Field.nullable("name", new ArrowType.Utf8()) - ), null); + private final Schema schema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8())), + null); private final BufferAllocator allocator; private final String datasetPath; @@ -59,25 +60,33 @@ public SimpleTestDataset(BufferAllocator allocator, String datasetPath) { this.allocator = allocator; this.datasetPath = datasetPath; } - + public Schema getSchema() { return schema; } - + public Dataset createEmptyDataset() { - Dataset dataset = Dataset.create(allocator, datasetPath, - schema, new WriteParams.Builder().build()); + Dataset dataset = + Dataset.create(allocator, datasetPath, schema, new WriteParams.Builder().build()); assertEquals(0, dataset.countRows()); assertEquals(schema, dataset.getSchema()); - List fragments = dataset.getFragments(); + List fragments = dataset.getFragments(); assertEquals(0, fragments.size()); assertEquals(1, dataset.version()); assertEquals(1, dataset.latestVersion()); return dataset; } - public FragmentMetadata createNewFragment(int fragmentId, int rowCount) { - FragmentMetadata fragmentMeta; + public FragmentMetadata createNewFragment(int rowCount) { + List fragmentMetas = createNewFragment(rowCount, Integer.MAX_VALUE); + assertEquals(1, fragmentMetas.size()); + FragmentMetadata fragmentMeta = fragmentMetas.get(0); + assertEquals(rowCount, fragmentMeta.getPhysicalRows()); + return fragmentMeta; + } + + public List createNewFragment(int rowCount, int maxRowsPerFile) { + List fragmentMetas; try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { root.allocateNew(); IntVector idVector = (IntVector) root.getVector("id"); @@ -90,24 +99,81 @@ public FragmentMetadata createNewFragment(int fragmentId, int rowCount) { } root.setRowCount(rowCount); - fragmentMeta = Fragment.create(datasetPath, - allocator, root, Optional.of(fragmentId), new WriteParams.Builder().build()); - assertEquals(fragmentId, fragmentMeta.getId()); - assertEquals(rowCount, fragmentMeta.getPhysicalRows()); + fragmentMetas = + Fragment.create( + datasetPath, + allocator, + root, + new WriteParams.Builder().withMaxRowsPerFile(maxRowsPerFile).build()); } - return fragmentMeta; + return fragmentMetas; } public Dataset write(long version, int rowCount) { - FragmentMetadata metadata = createNewFragment(rowCount, rowCount); + FragmentMetadata metadata = createNewFragment(rowCount); FragmentOperation.Append appendOp = new FragmentOperation.Append(Arrays.asList(metadata)); return Dataset.commit(allocator, datasetPath, appendOp, Optional.of(version)); } - + + public Dataset writeSortByDataset(long version) { + List fragmentMetas; + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.allocateNew(); + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + /* dataset context + * i: | id | name | + * 0: | 0 | null | + * 1: | 1 | P0 | + * 2: | null | P1 | + * 3: | 2 | P2 | + * 4: | 2 | P3 | + * 5: | null | P3 | + * 6: | 3 | null | + * 7: | 4 | P4 | + * 8: | 4 | P5 | + * 9: | 5 | P5 | + */ + idVector.set(0, 0); + idVector.set(1, 1); + idVector.setNull(2); + idVector.set(3, 2); + idVector.set(4, 2); + idVector.setNull(5); + idVector.set(6, 3); + idVector.set(7, 4); + idVector.set(8, 4); + idVector.set(9, 5); + + nameVector.setNull(0); + nameVector.set(1, "P0".getBytes()); + nameVector.set(2, "P1".getBytes()); + nameVector.set(3, "P2".getBytes()); + nameVector.set(4, "P3".getBytes()); + nameVector.set(5, "P3".getBytes()); + nameVector.setNull(6); + nameVector.set(7, "P4".getBytes()); + nameVector.set(8, "P5".getBytes()); + nameVector.set(9, "P5".getBytes()); + + root.setRowCount(10); + + fragmentMetas = + Fragment.create( + datasetPath, + allocator, + root, + new WriteParams.Builder().withMaxRowsPerFile(Integer.MAX_VALUE).build()); + } + FragmentOperation.Append appendOp = new FragmentOperation.Append(fragmentMetas); + return Dataset.commit(allocator, datasetPath, appendOp, Optional.of(version)); + } + public void validateScanResults(Dataset dataset, Scanner scanner, int totalRows, int batchRows) throws IOException { try (ArrowReader reader = scanner.scanBatches()) { - assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + assertEquals( + dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowcount = 0; while (reader.loadNextBatch()) { int currentRowCount = reader.getVectorSchemaRoot().getRowCount(); @@ -118,10 +184,12 @@ public void validateScanResults(Dataset dataset, Scanner scanner, int totalRows, } } - public void validateScanResults(Dataset dataset, Scanner scanner, int expectedRows, int batchRows, int offset) + public void validateScanResults( + Dataset dataset, Scanner scanner, int expectedRows, int batchRows, int offset) throws IOException { try (ArrowReader reader = scanner.scanBatches()) { - assertEquals(dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + assertEquals( + dataset.getSchema().getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowcount = 0; while (reader.loadNextBatch()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); @@ -132,7 +200,8 @@ public void validateScanResults(Dataset dataset, Scanner scanner, int expectedRo IntVector idVector = (IntVector) root.getVector("id"); for (int i = 0; i < currentRowCount; i++) { int expectedId = offset + rowcount - currentRowCount + i; - assertEquals(expectedId, idVector.get(i), "Mismatch at row " + (rowcount - currentRowCount + i)); + assertEquals( + expectedId, idVector.get(i), "Mismatch at row " + (rowcount - currentRowCount + i)); } } assertEquals(expectedRows, rowcount); @@ -146,32 +215,33 @@ public static class RandomAccessDataset { private final BufferAllocator allocator; private final String datasetPath; private Schema schema; - + public RandomAccessDataset(BufferAllocator allocator, String datasetPath) { this.allocator = allocator; this.datasetPath = datasetPath; } - + public void createDatasetAndValidate() throws IOException, URISyntaxException { Path path = Paths.get(DatasetTest.class.getResource(DATA_FILE).toURI()); try (BufferAllocator allocator = new RootAllocator(); - ArrowFileReader reader = - new ArrowFileReader( - new SeekableReadChannel( - new ByteArrayReadableSeekableByteChannel(Files.readAllBytes(path))), - allocator); - ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { + ArrowFileReader reader = + new ArrowFileReader( + new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel(Files.readAllBytes(path))), + allocator); + ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, reader, arrowStream); - try (Dataset dataset = Dataset.create( - allocator, - arrowStream, - datasetPath, - new WriteParams.Builder() - .withMaxRowsPerFile(10) - .withMaxRowsPerGroup(20) - .withMode(WriteParams.WriteMode.CREATE) - .withStorageOptions(new HashMap<>()) - .build())) { + try (Dataset dataset = + Dataset.create( + allocator, + arrowStream, + datasetPath, + new WriteParams.Builder() + .withMaxRowsPerFile(10) + .withMaxRowsPerGroup(20) + .withMode(WriteParams.WriteMode.CREATE) + .withStorageOptions(new HashMap<>()) + .build())) { assertEquals(ROW_COUNT, dataset.countRows()); schema = reader.getVectorSchemaRoot().getSchema(); validateFragments(dataset); @@ -198,7 +268,7 @@ public Schema getSchema() { private void validateFragments(Dataset dataset) { assertNotNull(schema); assertNotNull(dataset); - List fragments = dataset.getFragments(); + List fragments = dataset.getFragments(); assertEquals(1, fragments.size()); assertEquals(0, fragments.get(0).getId()); assertEquals(9, fragments.get(0).countRows()); diff --git a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java index 564d47dd253..3b7055e8394 100644 --- a/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java +++ b/java/core/src/test/java/com/lancedb/lance/TestVectorDataset.java @@ -11,9 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; +import com.lancedb.lance.index.DistanceType; +import com.lancedb.lance.index.IndexParams; +import com.lancedb.lance.index.IndexType; +import com.lancedb.lance.index.vector.VectorIndexParams; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.*; @@ -25,11 +29,6 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; -import com.lancedb.lance.index.DistanceType; -import com.lancedb.lance.index.IndexParams; -import com.lancedb.lance.index.IndexType; -import com.lancedb.lance.index.vector.VectorIndexParams; - import java.io.IOException; import java.nio.file.Path; import java.util.*; @@ -55,21 +54,26 @@ private Schema createSchema() { Map metadata = new HashMap<>(); metadata.put("dataset", "vector"); - List fields = Arrays.asList( - new Field("i", FieldType.nullable(new ArrowType.Int(32, true)), null), - new Field("s", FieldType.nullable(new ArrowType.Utf8()), null), - new Field(vectorColumnName, FieldType.nullable(new ArrowType.FixedSizeList(32)), - Collections.singletonList(new Field("item", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null)))); + List fields = + Arrays.asList( + new Field("i", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("s", FieldType.nullable(new ArrowType.Utf8()), null), + new Field( + vectorColumnName, + FieldType.nullable(new ArrowType.FixedSizeList(32)), + Collections.singletonList( + new Field( + "item", + FieldType.nullable( + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), + null)))); return new Schema(fields, metadata); } private Dataset createDataset() throws IOException { - WriteParams writeParams = new WriteParams.Builder() - .withMaxRowsPerGroup(10) - .withMaxRowsPerFile(200) - .build(); + WriteParams writeParams = + new WriteParams.Builder().withMaxRowsPerGroup(10).withMaxRowsPerFile(200).build(); Dataset.create(allocator, datasetPath.toString(), schema, writeParams).close(); @@ -102,7 +106,7 @@ private FragmentMetadata createFragment(int batchIndex) throws IOException { root.setRowCount(80); WriteParams fragmentWriteParams = new WriteParams.Builder().build(); - return Fragment.create(datasetPath.toString(), allocator, root, Optional.of(batchIndex), fragmentWriteParams); + return Fragment.create(datasetPath.toString(), allocator, root, fragmentWriteParams).get(0); } } @@ -127,18 +131,21 @@ public Dataset appendNewData() throws IOException { root.setRowCount(10); WriteParams writeParams = new WriteParams.Builder().build(); - fragmentMetadata = Fragment.create(datasetPath.toString(), allocator, root, Optional.empty(), - writeParams); + fragmentMetadata = + Fragment.create(datasetPath.toString(), allocator, root, writeParams).get(0); } - FragmentOperation.Append appendOp = new FragmentOperation.Append(Collections.singletonList(fragmentMetadata)); + FragmentOperation.Append appendOp = + new FragmentOperation.Append(Collections.singletonList(fragmentMetadata)); return Dataset.commit(allocator, datasetPath.toString(), appendOp, Optional.of(2L)); } public void createIndex(Dataset dataset) { - IndexParams params = new IndexParams.Builder() - .setVectorIndexParams(VectorIndexParams.ivfPq(2, 8, 2, DistanceType.L2, 2)) - .build(); - dataset.createIndex(Arrays.asList(vectorColumnName), IndexType.VECTOR, Optional.of(indexName), params, true); + IndexParams params = + new IndexParams.Builder() + .setVectorIndexParams(VectorIndexParams.ivfPq(2, 8, 2, DistanceType.L2, 2)) + .build(); + dataset.createIndex( + Arrays.asList(vectorColumnName), IndexType.VECTOR, Optional.of(indexName), params, true); } @Override @@ -147,4 +154,4 @@ public void close() { allocator.close(); } } -} \ No newline at end of file +} diff --git a/java/core/src/test/java/com/lancedb/lance/VectorSearchTest.java b/java/core/src/test/java/com/lancedb/lance/VectorSearchTest.java index e7492a2c536..c07f8efc7b3 100644 --- a/java/core/src/test/java/com/lancedb/lance/VectorSearchTest.java +++ b/java/core/src/test/java/com/lancedb/lance/VectorSearchTest.java @@ -11,30 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance; -import org.apache.arrow.dataset.scanner.Scanner; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.ArrowReader; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import com.lancedb.lance.ipc.Query; -import com.lancedb.lance.ipc.ScanOptions; - import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; import java.util.Optional; -import java.util.Set; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import static org.junit.jupiter.api.Assertions.*; @@ -48,14 +30,14 @@ // // An IVF-PQ index with 2 partitions is trained on this data public class VectorSearchTest { - @TempDir - Path tempDir; + @TempDir Path tempDir; // TODO: fix in https://github.com/lancedb/lance/issues/2956 // @Test // void test_create_index() throws Exception { - // try (TestVectorDataset testVectorDataset = new TestVectorDataset(tempDir.resolve("test_create_index"))) { + // try (TestVectorDataset testVectorDataset = new + // TestVectorDataset(tempDir.resolve("test_create_index"))) { // try (Dataset dataset = testVectorDataset.create()) { // testVectorDataset.createIndex(dataset); // List indexes = dataset.listIndexes(); @@ -70,7 +52,8 @@ public class VectorSearchTest { // Directly panic instead of throwing an exception // @Test // void search_invalid_vector() throws Exception { - // try (TestVectorDataset testVectorDataset = new TestVectorDataset(tempDir.resolve("test_create_index"))) { + // try (TestVectorDataset testVectorDataset = new + // TestVectorDataset(tempDir.resolve("test_create_index"))) { // try (Dataset dataset = testVectorDataset.create()) { // float[] key = new float[30]; // for (int i = 0; i < 30; i++) { @@ -97,7 +80,8 @@ public class VectorSearchTest { // @ParameterizedTest // @ValueSource(booleans = { false, true }) // void test_knn(boolean createVectorIndex) throws Exception { - // try (TestVectorDataset testVectorDataset = new TestVectorDataset(tempDir.resolve("test_knn"))) { + // try (TestVectorDataset testVectorDataset = new + // TestVectorDataset(tempDir.resolve("test_knn"))) { // try (Dataset dataset = testVectorDataset.create()) { // if (createVectorIndex) { @@ -126,7 +110,8 @@ public class VectorSearchTest { // assertEquals(4, root.getSchema().getFields().size(), "Expected 4 columns"); // assertEquals("i", root.getSchema().getFields().get(0).getName()); // assertEquals("s", root.getSchema().getFields().get(1).getName()); - // assertEquals(TestVectorDataset.vectorColumnName, root.getSchema().getFields().get(2).getName()); + // assertEquals(TestVectorDataset.vectorColumnName, + // root.getSchema().getFields().get(2).getName()); // assertEquals("_distance", root.getSchema().getFields().get(3).getName()); // IntVector iVector = (IntVector) root.getVector("i"); @@ -154,7 +139,8 @@ public class VectorSearchTest { // @Test // void test_knn_with_new_data() throws Exception { - // try (TestVectorDataset testVectorDataset = new TestVectorDataset(tempDir.resolve("test_knn_with_new_data"))) { + // try (TestVectorDataset testVectorDataset = new + // TestVectorDataset(tempDir.resolve("test_knn_with_new_data"))) { // try (Dataset dataset = testVectorDataset.create()) { // testVectorDataset.createIndex(dataset); // } @@ -201,7 +187,8 @@ public class VectorSearchTest { // int resultRows = root.getRowCount(); // int expectedRows = testCase.limit.orElse(k); // assertTrue(resultRows <= expectedRows, - // "Expected less than or equal to " + expectedRows + " rows, got " + resultRows); + // "Expected less than or equal to " + expectedRows + " rows, got " + + // resultRows); // } else { // assertEquals(testCase.limit.orElse(k), root.getRowCount(), // "Unexpected number of rows"); @@ -209,7 +196,8 @@ public class VectorSearchTest { // // Top one should be the first value of new data // IntVector iVector = (IntVector) root.getVector("i"); - // assertEquals(400, iVector.get(0), "First result should be the first value of new data"); + // assertEquals(400, iVector.get(0), "First result should be the first value of new + // data"); // // Check if distances are in ascending order // Float4Vector distanceVector = (Float4Vector) root.getVector("_distance"); diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index caf80f6cba2..3cdf2e0edc9 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -3508,7 +3508,7 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.25.3" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ diff --git a/java/mvnw b/java/mvnw new file mode 100755 index 00000000000..19529ddf8c6 --- /dev/null +++ b/java/mvnw @@ -0,0 +1,259 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Apache Maven Wrapper startup batch script, version 3.3.2 +# +# Optional ENV vars +# ----------------- +# JAVA_HOME - location of a JDK home dir, required when download maven via java source +# MVNW_REPOURL - repo url base for downloading maven distribution +# MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven +# MVNW_VERBOSE - true: enable verbose log; debug: trace the mvnw script; others: silence the output +# ---------------------------------------------------------------------------- + +set -euf +[ "${MVNW_VERBOSE-}" != debug ] || set -x + +# OS specific support. +native_path() { printf %s\\n "$1"; } +case "$(uname)" in +CYGWIN* | MINGW*) + [ -z "${JAVA_HOME-}" ] || JAVA_HOME="$(cygpath --unix "$JAVA_HOME")" + native_path() { cygpath --path --windows "$1"; } + ;; +esac + +# set JAVACMD and JAVACCMD +set_java_home() { + # For Cygwin and MinGW, ensure paths are in Unix format before anything is touched + if [ -n "${JAVA_HOME-}" ]; then + if [ -x "$JAVA_HOME/jre/sh/java" ]; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACCMD="$JAVA_HOME/jre/sh/javac" + else + JAVACMD="$JAVA_HOME/bin/java" + JAVACCMD="$JAVA_HOME/bin/javac" + + if [ ! -x "$JAVACMD" ] || [ ! -x "$JAVACCMD" ]; then + echo "The JAVA_HOME environment variable is not defined correctly, so mvnw cannot run." >&2 + echo "JAVA_HOME is set to \"$JAVA_HOME\", but \"\$JAVA_HOME/bin/java\" or \"\$JAVA_HOME/bin/javac\" does not exist." >&2 + return 1 + fi + fi + else + JAVACMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v java + )" || : + JAVACCMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v javac + )" || : + + if [ ! -x "${JAVACMD-}" ] || [ ! -x "${JAVACCMD-}" ]; then + echo "The java/javac command does not exist in PATH nor is JAVA_HOME set, so mvnw cannot run." >&2 + return 1 + fi + fi +} + +# hash string like Java String::hashCode +hash_string() { + str="${1:-}" h=0 + while [ -n "$str" ]; do + char="${str%"${str#?}"}" + h=$(((h * 31 + $(LC_CTYPE=C printf %d "'$char")) % 4294967296)) + str="${str#?}" + done + printf %x\\n $h +} + +verbose() { :; } +[ "${MVNW_VERBOSE-}" != true ] || verbose() { printf %s\\n "${1-}"; } + +die() { + printf %s\\n "$1" >&2 + exit 1 +} + +trim() { + # MWRAPPER-139: + # Trims trailing and leading whitespace, carriage returns, tabs, and linefeeds. + # Needed for removing poorly interpreted newline sequences when running in more + # exotic environments such as mingw bash on Windows. + printf "%s" "${1}" | tr -d '[:space:]' +} + +# parse distributionUrl and optional distributionSha256Sum, requires .mvn/wrapper/maven-wrapper.properties +while IFS="=" read -r key value; do + case "${key-}" in + distributionUrl) distributionUrl=$(trim "${value-}") ;; + distributionSha256Sum) distributionSha256Sum=$(trim "${value-}") ;; + esac +done <"${0%/*}/.mvn/wrapper/maven-wrapper.properties" +[ -n "${distributionUrl-}" ] || die "cannot read distributionUrl property in ${0%/*}/.mvn/wrapper/maven-wrapper.properties" + +case "${distributionUrl##*/}" in +maven-mvnd-*bin.*) + MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ + case "${PROCESSOR_ARCHITECTURE-}${PROCESSOR_ARCHITEW6432-}:$(uname -a)" in + *AMD64:CYGWIN* | *AMD64:MINGW*) distributionPlatform=windows-amd64 ;; + :Darwin*x86_64) distributionPlatform=darwin-amd64 ;; + :Darwin*arm64) distributionPlatform=darwin-aarch64 ;; + :Linux*x86_64*) distributionPlatform=linux-amd64 ;; + *) + echo "Cannot detect native platform for mvnd on $(uname)-$(uname -m), use pure java version" >&2 + distributionPlatform=linux-amd64 + ;; + esac + distributionUrl="${distributionUrl%-bin.*}-$distributionPlatform.zip" + ;; +maven-mvnd-*) MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ ;; +*) MVN_CMD="mvn${0##*/mvnw}" _MVNW_REPO_PATTERN=/org/apache/maven/ ;; +esac + +# apply MVNW_REPOURL and calculate MAVEN_HOME +# maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ +[ -z "${MVNW_REPOURL-}" ] || distributionUrl="$MVNW_REPOURL$_MVNW_REPO_PATTERN${distributionUrl#*"$_MVNW_REPO_PATTERN"}" +distributionUrlName="${distributionUrl##*/}" +distributionUrlNameMain="${distributionUrlName%.*}" +distributionUrlNameMain="${distributionUrlNameMain%-bin}" +MAVEN_USER_HOME="${MAVEN_USER_HOME:-${HOME}/.m2}" +MAVEN_HOME="${MAVEN_USER_HOME}/wrapper/dists/${distributionUrlNameMain-}/$(hash_string "$distributionUrl")" + +exec_maven() { + unset MVNW_VERBOSE MVNW_USERNAME MVNW_PASSWORD MVNW_REPOURL || : + exec "$MAVEN_HOME/bin/$MVN_CMD" "$@" || die "cannot exec $MAVEN_HOME/bin/$MVN_CMD" +} + +if [ -d "$MAVEN_HOME" ]; then + verbose "found existing MAVEN_HOME at $MAVEN_HOME" + exec_maven "$@" +fi + +case "${distributionUrl-}" in +*?-bin.zip | *?maven-mvnd-?*-?*.zip) ;; +*) die "distributionUrl is not valid, must match *-bin.zip or maven-mvnd-*.zip, but found '${distributionUrl-}'" ;; +esac + +# prepare tmp dir +if TMP_DOWNLOAD_DIR="$(mktemp -d)" && [ -d "$TMP_DOWNLOAD_DIR" ]; then + clean() { rm -rf -- "$TMP_DOWNLOAD_DIR"; } + trap clean HUP INT TERM EXIT +else + die "cannot create temp dir" +fi + +mkdir -p -- "${MAVEN_HOME%/*}" + +# Download and Install Apache Maven +verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." +verbose "Downloading from: $distributionUrl" +verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" + +# select .zip or .tar.gz +if ! command -v unzip >/dev/null; then + distributionUrl="${distributionUrl%.zip}.tar.gz" + distributionUrlName="${distributionUrl##*/}" +fi + +# verbose opt +__MVNW_QUIET_WGET=--quiet __MVNW_QUIET_CURL=--silent __MVNW_QUIET_UNZIP=-q __MVNW_QUIET_TAR='' +[ "${MVNW_VERBOSE-}" != true ] || __MVNW_QUIET_WGET='' __MVNW_QUIET_CURL='' __MVNW_QUIET_UNZIP='' __MVNW_QUIET_TAR=v + +# normalize http auth +case "${MVNW_PASSWORD:+has-password}" in +'') MVNW_USERNAME='' MVNW_PASSWORD='' ;; +has-password) [ -n "${MVNW_USERNAME-}" ] || MVNW_USERNAME='' MVNW_PASSWORD='' ;; +esac + +if [ -z "${MVNW_USERNAME-}" ] && command -v wget >/dev/null; then + verbose "Found wget ... using wget" + wget ${__MVNW_QUIET_WGET:+"$__MVNW_QUIET_WGET"} "$distributionUrl" -O "$TMP_DOWNLOAD_DIR/$distributionUrlName" || die "wget: Failed to fetch $distributionUrl" +elif [ -z "${MVNW_USERNAME-}" ] && command -v curl >/dev/null; then + verbose "Found curl ... using curl" + curl ${__MVNW_QUIET_CURL:+"$__MVNW_QUIET_CURL"} -f -L -o "$TMP_DOWNLOAD_DIR/$distributionUrlName" "$distributionUrl" || die "curl: Failed to fetch $distributionUrl" +elif set_java_home; then + verbose "Falling back to use Java to download" + javaSource="$TMP_DOWNLOAD_DIR/Downloader.java" + targetZip="$TMP_DOWNLOAD_DIR/$distributionUrlName" + cat >"$javaSource" <<-END + public class Downloader extends java.net.Authenticator + { + protected java.net.PasswordAuthentication getPasswordAuthentication() + { + return new java.net.PasswordAuthentication( System.getenv( "MVNW_USERNAME" ), System.getenv( "MVNW_PASSWORD" ).toCharArray() ); + } + public static void main( String[] args ) throws Exception + { + setDefault( new Downloader() ); + java.nio.file.Files.copy( java.net.URI.create( args[0] ).toURL().openStream(), java.nio.file.Paths.get( args[1] ).toAbsolutePath().normalize() ); + } + } + END + # For Cygwin/MinGW, switch paths to Windows format before running javac and java + verbose " - Compiling Downloader.java ..." + "$(native_path "$JAVACCMD")" "$(native_path "$javaSource")" || die "Failed to compile Downloader.java" + verbose " - Running Downloader.java ..." + "$(native_path "$JAVACMD")" -cp "$(native_path "$TMP_DOWNLOAD_DIR")" Downloader "$distributionUrl" "$(native_path "$targetZip")" +fi + +# If specified, validate the SHA-256 sum of the Maven distribution zip file +if [ -n "${distributionSha256Sum-}" ]; then + distributionSha256Result=false + if [ "$MVN_CMD" = mvnd.sh ]; then + echo "Checksum validation is not supported for maven-mvnd." >&2 + echo "Please disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + elif command -v sha256sum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | sha256sum -c >/dev/null 2>&1; then + distributionSha256Result=true + fi + elif command -v shasum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | shasum -a 256 -c >/dev/null 2>&1; then + distributionSha256Result=true + fi + else + echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." >&2 + echo "Please install either command, or disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + fi + if [ $distributionSha256Result = false ]; then + echo "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised." >&2 + echo "If you updated your Maven version, you need to update the specified distributionSha256Sum property." >&2 + exit 1 + fi +fi + +# unzip and move +if command -v unzip >/dev/null; then + unzip ${__MVNW_QUIET_UNZIP:+"$__MVNW_QUIET_UNZIP"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -d "$TMP_DOWNLOAD_DIR" || die "failed to unzip" +else + tar xzf${__MVNW_QUIET_TAR:+"$__MVNW_QUIET_TAR"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -C "$TMP_DOWNLOAD_DIR" || die "failed to untar" +fi +printf %s\\n "$distributionUrl" >"$TMP_DOWNLOAD_DIR/$distributionUrlNameMain/mvnw.url" +mv -- "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain" "$MAVEN_HOME" || [ -d "$MAVEN_HOME" ] || die "fail to move MAVEN_HOME" + +clean || : +exec_maven "$@" diff --git a/java/pom.xml b/java/pom.xml index 84fe148ce2c..1962978dfa1 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -6,7 +6,7 @@ com.lancedb lance-parent - 0.20.0 + 0.26.2 pom Lance Parent @@ -30,6 +30,29 @@ UTF-8 15.0.0 0.28.1 + false + 2.30.0 + 1.7 + 2.12.19 + 2.12 + + 3.7.5 + package + + /* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + @@ -75,11 +98,6 @@ junit-jupiter 5.10.1 - - org.json - json - 20231013 - org.apache.commons commons-lang3 @@ -156,6 +174,10 @@ true + + com.diffplug.spotless + spotless-maven-plugin + @@ -185,6 +207,59 @@ maven-install-plugin 2.5.2 + + com.diffplug.spotless + spotless-maven-plugin + ${spotless.version} + + ${spotless.skip} + + true + + + + src/main/java/**/*.java + src/test/java/**/*.java + + + ${spotless.java.googlejavaformat.version} + + + + + com.lancedb.lance,,javax,java,\# + + + + + + + src/main/scala/**/*.scala + src/main/scala-*/**/*.scala + src/test/scala/**/*.scala + src/test/scala-*/**/*.scala + + + ${spotless.scala.scalafmt.version} + ${scala.binary.version} + ${maven.multiModuleProjectDirectory}/.scalafmt.conf + + + + ${spotless.license.header} + ${spotless.delimiter} + + + + + spotless-check + validate + + apply + + + + @@ -267,4 +342,4 @@ - \ No newline at end of file + diff --git a/java/spark/pom.xml b/java/spark/pom.xml index 681e0a62027..68a4daa3e63 100644 --- a/java/spark/pom.xml +++ b/java/spark/pom.xml @@ -8,7 +8,7 @@ com.lancedb lance-parent - 0.20.0 + 0.26.2 ../pom.xml @@ -23,6 +23,75 @@ 2.12 + + + + net.alchim31.maven + scala-maven-plugin + 3.2.1 + + + scala-compile-first + process-resources + + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + -feature + + + + + org.apache.maven.plugins + maven-dependency-plugin + 3.6.1 + + + copy-dependencies + package + + copy-dependencies + + + ${project.build.directory}/jars + false + false + true + lance-core,arrow-c-data,jar-jni,arrow-dataset + + + + copy-self + package + + copy + + + + + ${project.groupId} + ${project.artifactId} + ${project.version} + jar + ${project.build.directory}/jars + + + + + + + + scala-2.13 @@ -82,17 +151,24 @@ com.lancedb lance-core - 0.20.0 + 0.26.2 org.apache.spark spark-sql_${scala.compat.version} ${spark.version} + provided org.junit.jupiter junit-jupiter test + + org.scalatest + scalatest_2.12 + 3.2.10 + test + diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java index 05a66fa9d0a..69136a1bf1a 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceCatalog.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import com.lancedb.lance.WriteParams; import com.lancedb.lance.spark.internal.LanceDatasetAdapter; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; @@ -33,6 +33,7 @@ public class LanceCatalog implements TableCatalog { private CaseInsensitiveStringMap options; + @Override public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException { throw new UnsupportedOperationException("Please use lancedb catalog for dataset listing"); @@ -49,8 +50,9 @@ public Table loadTable(Identifier ident) throws NoSuchTableException { } @Override - public Table createTable(Identifier ident, StructType schema, Transform[] partitions, - Map properties) throws TableAlreadyExistsException, NoSuchNamespaceException { + public Table createTable( + Identifier ident, StructType schema, Transform[] partitions, Map properties) + throws TableAlreadyExistsException, NoSuchNamespaceException { try { LanceConfig config = LanceConfig.from(options, ident.name()); WriteParams params = SparkOptions.genWriteParamsFromConfig(config); @@ -68,7 +70,9 @@ public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchT @Override public boolean dropTable(Identifier ident) { - throw new UnsupportedOperationException(); + LanceConfig config = LanceConfig.from(options, ident.name()); + LanceDatasetAdapter.dropDataset(config); + return true; } @Override diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java index 8758dabba9e..188b60c054d 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConfig.java @@ -11,16 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + import java.io.Serializable; import java.util.Map; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import java.util.Objects; -/** - * Lance Configuration. - */ +/** Lance Configuration. */ public class LanceConfig implements Serializable { private static final long serialVersionUID = 827364827364823764L; public static final String CONFIG_DATASET_URI = "path"; // Path is default spark option key @@ -35,8 +34,12 @@ public class LanceConfig implements Serializable { private final boolean pushDownFilters; private final Map options; - private LanceConfig(String dbPath, String datasetName, - String datasetUri, boolean pushDownFilters, CaseInsensitiveStringMap options) { + private LanceConfig( + String dbPath, + String datasetName, + String datasetUri, + boolean pushDownFilters, + CaseInsensitiveStringMap options) { this.dbPath = dbPath; this.datasetName = datasetName; this.datasetUri = datasetUri; @@ -64,8 +67,8 @@ public static LanceConfig from(String datasetUri) { } public static LanceConfig from(CaseInsensitiveStringMap options, String datasetUri) { - boolean pushDownFilters = options.getBoolean(CONFIG_PUSH_DOWN_FILTERS, - DEFAULT_PUSH_DOWN_FILTERS); + boolean pushDownFilters = + options.getBoolean(CONFIG_PUSH_DOWN_FILTERS, DEFAULT_PUSH_DOWN_FILTERS); String[] paths = extractDbPathAndDatasetName(datasetUri); return new LanceConfig(paths[0], paths[1], datasetUri, pushDownFilters, options); } @@ -89,9 +92,11 @@ private static String[] extractDbPathAndDatasetName(String datasetUri) { } String datasetNameWithSuffix = datasetUri.substring(lastSlashIndex + 1); - return new String[]{datasetUri.substring(0, lastSlashIndex + 1), - datasetNameWithSuffix.substring(0, - datasetNameWithSuffix.length() - LANCE_FILE_SUFFIX.length())}; + return new String[] { + datasetUri.substring(0, lastSlashIndex + 1), + datasetNameWithSuffix.substring( + 0, datasetNameWithSuffix.length() - LANCE_FILE_SUFFIX.length()) + }; } public String getDbPath() { @@ -113,4 +118,20 @@ public boolean isPushDownFilters() { public Map getOptions() { return options; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + LanceConfig config = (LanceConfig) o; + return pushDownFilters == config.pushDownFilters + && Objects.equals(dbPath, config.dbPath) + && Objects.equals(datasetName, config.datasetName) + && Objects.equals(datasetUri, config.datasetUri) + && Objects.equals(options, config.options); + } + + @Override + public int hashCode() { + return Objects.hash(dbPath, datasetName, datasetUri, pushDownFilters, options); + } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java new file mode 100644 index 00000000000..449c61dd62d --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.spark; + +public class LanceConstant { + public static final String ROW_ID = "_rowid"; + public static final String ROW_ADDRESS = "_rowaddr"; +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java index 0bc5fcbbdd1..a30e83e305b 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import com.lancedb.lance.spark.internal.LanceDatasetAdapter; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.connector.catalog.Identifier; import org.apache.spark.sql.connector.catalog.SupportsCatalogOptions; import org.apache.spark.sql.connector.catalog.Table; @@ -36,8 +36,8 @@ public StructType inferSchema(CaseInsensitiveStringMap options) { } @Override - public Table getTable(StructType schema, Transform[] partitioning, - Map properties) { + public Table getTable( + StructType schema, Transform[] partitioning, Map properties) { return new LanceDataset(LanceConfig.from(properties), schema); } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java index 702b3bdf42a..a6c5d0a3bf1 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java @@ -1,40 +1,70 @@ /* - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ - package com.lancedb.lance.spark; -import com.google.common.collect.ImmutableSet; - -import java.util.Set; - import com.lancedb.lance.spark.read.LanceScanBuilder; import com.lancedb.lance.spark.write.SparkWrite; + +import com.google.common.collect.ImmutableSet; +import org.apache.spark.sql.connector.catalog.MetadataColumn; +import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.SupportsWrite; import org.apache.spark.sql.connector.catalog.TableCapability; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.connector.write.WriteBuilder; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -/** - * Lance Spark Dataset. - */ -public class LanceDataset implements SupportsRead, SupportsWrite { +import java.util.Set; + +/** Lance Spark Dataset. */ +public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns { private static final Set CAPABILITIES = - ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE); + ImmutableSet.of( + TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE); + + public static final MetadataColumn[] METADATA_COLUMNS = + new MetadataColumn[] { + new MetadataColumn() { + @Override + public String name() { + return LanceConstant.ROW_ID; + } + + @Override + public DataType dataType() { + return DataTypes.LongType; + } + }, + new MetadataColumn() { + @Override + public String name() { + return LanceConstant.ROW_ADDRESS; + } - LanceConfig options; + @Override + public DataType dataType() { + return DataTypes.LongType; + } + }, + }; + + LanceConfig config; private final StructType sparkSchema; /** @@ -44,18 +74,18 @@ public class LanceDataset implements SupportsRead, SupportsWrite { * @param sparkSchema spark struct type */ public LanceDataset(LanceConfig config, StructType sparkSchema) { - this.options = config; + this.config = config; this.sparkSchema = sparkSchema; } @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap caseInsensitiveStringMap) { - return new LanceScanBuilder(sparkSchema, options); + return new LanceScanBuilder(sparkSchema, config); } @Override public String name() { - return this.options.getDatasetName(); + return this.config.getDatasetName(); } @Override @@ -70,6 +100,11 @@ public Set capabilities() { @Override public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { - return new SparkWrite.SparkWriteBuilder(sparkSchema, options); + return new SparkWrite.SparkWriteBuilder(sparkSchema, config); + } + + @Override + public MetadataColumn[] metadataColumns() { + return METADATA_COLUMNS; } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java index 4c872721eec..1889f7fb75d 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceIdentifier.java @@ -11,13 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import org.apache.spark.sql.connector.catalog.Identifier; public class LanceIdentifier implements Identifier { - private final String[] namespace = new String[]{"default"}; + private final String[] namespace = new String[] {"default"}; private final String datasetUri; public LanceIdentifier(String datasetUri) { @@ -33,4 +32,4 @@ public String[] namespace() { public String name() { return datasetUri; } -} \ No newline at end of file +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java index 6ccee2c79ef..d91e2dd9dd0 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import com.lancedb.lance.ReadOptions; @@ -21,69 +20,95 @@ import java.util.Map; public class SparkOptions { - private static final String ak = "access_key_id"; - private static final String sk = "secret_access_key"; - private static final String endpoint = "aws_region"; - private static final String region = "aws_endpoint"; - private static final String virtual_hosted_style = "virtual_hosted_style_request"; - private static final String block_size = "block_size"; - private static final String version = "version"; - private static final String index_cache_size = "index_cache_size"; - private static final String metadata_cache_size = "metadata_cache_size"; - private static final String write_mode = "write_mode"; - private static final String max_row_per_file = "max_row_per_file"; - private static final String max_rows_per_group = "max_rows_per_group"; - private static final String max_bytes_per_file = "max_bytes_per_file"; + private static final String ak = "access_key_id"; + private static final String sk = "secret_access_key"; + private static final String endpoint = "aws_endpoint"; + private static final String region = "aws_region"; + private static final String virtual_hosted_style = "virtual_hosted_style_request"; + private static final String allow_http = "allow_http"; + + private static final String block_size = "block_size"; + private static final String version = "version"; + private static final String index_cache_size = "index_cache_size"; + private static final String metadata_cache_size = "metadata_cache_size"; + private static final String write_mode = "write_mode"; + private static final String max_row_per_file = "max_row_per_file"; + private static final String max_rows_per_group = "max_rows_per_group"; + private static final String max_bytes_per_file = "max_bytes_per_file"; + private static final String batch_size = "batch_size"; + private static final String topN_push_down = "topN_push_down"; + + public static ReadOptions genReadOptionFromConfig(LanceConfig config) { + ReadOptions.Builder builder = new ReadOptions.Builder(); + Map maps = config.getOptions(); + if (maps.containsKey(block_size)) { + builder.setBlockSize(Integer.parseInt(maps.get(block_size))); + } + if (maps.containsKey(version)) { + builder.setVersion(Integer.parseInt(maps.get(version))); + } + if (maps.containsKey(index_cache_size)) { + builder.setIndexCacheSize(Integer.parseInt(maps.get(index_cache_size))); + } + if (maps.containsKey(metadata_cache_size)) { + builder.setMetadataCacheSize(Integer.parseInt(maps.get(metadata_cache_size))); + } + builder.setStorageOptions(genStorageOptions(config)); + return builder.build(); + } - public static ReadOptions genReadOptionFromConfig(LanceConfig config) { - ReadOptions.Builder builder = new ReadOptions.Builder(); - Map maps = config.getOptions(); - if (maps.containsKey(block_size)) { - builder.setBlockSize(Integer.parseInt(maps.get(block_size))); - } - if (maps.containsKey(version)) { - builder.setVersion(Integer.parseInt(maps.get(version))); - } - if (maps.containsKey(index_cache_size)) { - builder.setIndexCacheSize(Integer.parseInt(maps.get(index_cache_size))); - } - if (maps.containsKey(metadata_cache_size)) { - builder.setMetadataCacheSize(Integer.parseInt(maps.get(metadata_cache_size))); - } - builder.setStorageOptions(genStorageOptions(config)); - return builder.build(); + public static WriteParams genWriteParamsFromConfig(LanceConfig config) { + WriteParams.Builder builder = new WriteParams.Builder(); + Map maps = config.getOptions(); + if (maps.containsKey(write_mode)) { + builder.withMode(WriteParams.WriteMode.valueOf(maps.get(write_mode))); } + if (maps.containsKey(max_row_per_file)) { + builder.withMaxRowsPerFile(Integer.parseInt(maps.get(max_row_per_file))); + } + if (maps.containsKey(max_rows_per_group)) { + builder.withMaxRowsPerGroup(Integer.parseInt(maps.get(max_rows_per_group))); + } + if (maps.containsKey(max_bytes_per_file)) { + builder.withMaxBytesPerFile(Long.parseLong(maps.get(max_bytes_per_file))); + } + builder.withStorageOptions(genStorageOptions(config)); + return builder.build(); + } - public static WriteParams genWriteParamsFromConfig(LanceConfig config) { - WriteParams.Builder builder = new WriteParams.Builder(); - Map maps = config.getOptions(); - if (maps.containsKey(write_mode)) { - builder.withMode(WriteParams.WriteMode.valueOf(maps.get(write_mode))); - } - if (maps.containsKey(max_row_per_file)) { - builder.withMaxRowsPerFile(Integer.parseInt(maps.get(max_row_per_file))); - } - if (maps.containsKey(max_rows_per_group)) { - builder.withMaxRowsPerGroup(Integer.parseInt(maps.get(max_rows_per_group))); - } - if (maps.containsKey(max_bytes_per_file)) { - builder.withMaxBytesPerFile(Long.parseLong(maps.get(max_bytes_per_file))); - } - builder.withStorageOptions(genStorageOptions(config)); - return builder.build(); + private static Map genStorageOptions(LanceConfig config) { + Map maps = config.getOptions(); + Map storageOptions = new HashMap<>(); + if (maps.containsKey(ak) && maps.containsKey(sk) && maps.containsKey(endpoint)) { + storageOptions.put(ak, maps.get(ak)); + storageOptions.put(sk, maps.get(sk)); + storageOptions.put(endpoint, maps.get(endpoint)); + } + if (maps.containsKey(region)) { + storageOptions.put(region, maps.get(region)); } + if (maps.containsKey(virtual_hosted_style)) { + storageOptions.put(virtual_hosted_style, maps.get(virtual_hosted_style)); + } + if (maps.containsKey(allow_http)) { + storageOptions.put(allow_http, maps.get(allow_http)); + } + return storageOptions; + } - private static Map genStorageOptions(LanceConfig config) { - Map maps = config.getOptions(); - Map storageOptions = new HashMap<>(); - if (maps.containsKey(ak) && maps.containsKey(sk) && maps.containsKey(endpoint)) { - storageOptions.put(ak, maps.get(ak)); - storageOptions.put(sk, maps.get(sk)); - storageOptions.put(endpoint, maps.get(endpoint)); - storageOptions.put(region, maps.get(region)); - storageOptions.put(virtual_hosted_style, maps.get(virtual_hosted_style)); - } - return storageOptions; + public static int getBatchSize(LanceConfig config) { + Map options = config.getOptions(); + if (options.containsKey(batch_size)) { + return Integer.parseInt(options.get(batch_size)); } + return 512; + } + + public static boolean enableTopNPushDown(LanceConfig config) { + return Boolean.parseBoolean(config.getOptions().getOrDefault(topN_push_down, "true")); + } + public static boolean overwrite(LanceConfig config) { + return config.getOptions().getOrDefault(write_mode, "append").equalsIgnoreCase("overwrite"); + } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java index ff87744b6c1..589b275cd09 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java @@ -11,37 +11,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.internal; import com.lancedb.lance.*; import com.lancedb.lance.spark.LanceConfig; -import com.lancedb.lance.spark.read.LanceInputPartition; import com.lancedb.lance.spark.SparkOptions; +import com.lancedb.lance.spark.read.LanceInputPartition; import com.lancedb.lance.spark.utils.Optional; import com.lancedb.lance.spark.write.LanceArrowWriter; + import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.ArrowUtils; +import org.apache.spark.sql.util.LanceArrowUtils; import java.time.ZoneId; import java.util.List; import java.util.stream.Collectors; public class LanceDatasetAdapter { - private static final BufferAllocator allocator = new RootAllocator( - RootAllocator.configBuilder().from(RootAllocator.defaultConfig()) - .maxAllocation(64 * 1024 * 1024).build()); + public static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); public static Optional getSchema(LanceConfig config) { String uri = config.getDatasetUri(); ReadOptions options = SparkOptions.genReadOptionFromConfig(config); try (Dataset dataset = Dataset.open(allocator, uri, options)) { - return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema())); + return Optional.of(LanceArrowUtils.fromArrowSchema(dataset.getSchema())); } catch (IllegalArgumentException e) { // dataset not found return Optional.empty(); @@ -50,7 +49,29 @@ public static Optional getSchema(LanceConfig config) { public static Optional getSchema(String datasetUri) { try (Dataset dataset = Dataset.open(datasetUri, allocator)) { - return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema())); + return Optional.of(LanceArrowUtils.fromArrowSchema(dataset.getSchema())); + } catch (IllegalArgumentException e) { + // dataset not found + return Optional.empty(); + } + } + + public static Optional getDatasetRowCount(LanceConfig config) { + String uri = config.getDatasetUri(); + ReadOptions options = SparkOptions.genReadOptionFromConfig(config); + try (Dataset dataset = Dataset.open(allocator, uri, options)) { + return Optional.of(dataset.countRows()); + } catch (IllegalArgumentException e) { + // dataset not found + return Optional.empty(); + } + } + + public static Optional getDatasetDataSize(LanceConfig config) { + String uri = config.getDatasetUri(); + ReadOptions options = SparkOptions.genReadOptionFromConfig(config); + try (Dataset dataset = Dataset.open(allocator, uri, options)) { + return Optional.of(dataset.calculateDataSize()); } catch (IllegalArgumentException e) { // dataset not found return Optional.empty(); @@ -61,14 +82,13 @@ public static List getFragmentIds(LanceConfig config) { String uri = config.getDatasetUri(); ReadOptions options = SparkOptions.genReadOptionFromConfig(config); try (Dataset dataset = Dataset.open(allocator, uri, options)) { - return dataset.getFragments().stream() - .map(DatasetFragment::getId).collect(Collectors.toList()); + return dataset.getFragments().stream().map(Fragment::getId).collect(Collectors.toList()); } } - public static LanceFragmentScanner getFragmentScanner(int fragmentId, - LanceInputPartition inputPartition) { - return LanceFragmentScanner.create(fragmentId, inputPartition, allocator); + public static LanceFragmentScanner getFragmentScanner( + int fragmentId, LanceInputPartition inputPartition) { + return LanceFragmentScanner.create(fragmentId, inputPartition); } public static void appendFragments(LanceConfig config, List fragments) { @@ -76,30 +96,58 @@ public static void appendFragments(LanceConfig config, List fr String uri = config.getDatasetUri(); ReadOptions options = SparkOptions.genReadOptionFromConfig(config); try (Dataset datasetRead = Dataset.open(allocator, uri, options)) { + Dataset.commit( + allocator, + config.getDatasetUri(), + appendOp, + java.util.Optional.of(datasetRead.version()), + options.getStorageOptions()) + .close(); + } + } - Dataset.commit(allocator, config.getDatasetUri(), - appendOp, java.util.Optional.of(datasetRead.version()), options.getStorageOptions()) - .close(); + public static void overwriteFragments( + LanceConfig config, List fragments, StructType sparkSchema) { + Schema schema = LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false); + FragmentOperation.Overwrite overwrite = new FragmentOperation.Overwrite(fragments, schema); + String uri = config.getDatasetUri(); + ReadOptions options = SparkOptions.genReadOptionFromConfig(config); + try (Dataset datasetRead = Dataset.open(allocator, uri, options)) { + Dataset.commit( + allocator, + config.getDatasetUri(), + overwrite, + java.util.Optional.of(datasetRead.version()), + options.getStorageOptions()) + .close(); } } public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchSize) { - return new LanceArrowWriter(allocator, - ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize); + return new LanceArrowWriter( + allocator, LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize); } - public static FragmentMetadata createFragment(String datasetUri, ArrowReader reader, - WriteParams params) { + public static List createFragment( + String datasetUri, ArrowReader reader, WriteParams params) { try (ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, reader, arrowStream); - return Fragment.create(datasetUri, arrowStream, - java.util.Optional.empty(), params); + return Fragment.create(datasetUri, arrowStream, params); } } public static void createDataset(String datasetUri, StructType sparkSchema, WriteParams params) { - Dataset.create(allocator, datasetUri, - ArrowUtils.toArrowSchema(sparkSchema, ZoneId.systemDefault().getId(), true, false), - params).close(); + Dataset.create( + allocator, + datasetUri, + LanceArrowUtils.toArrowSchema(sparkSchema, ZoneId.systemDefault().getId(), true, false), + params) + .close(); + } + + public static void dropDataset(LanceConfig config) { + String uri = config.getDatasetUri(); + ReadOptions options = SparkOptions.genReadOptionFromConfig(config); + Dataset.drop(uri, options.getStorageOptions()); } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java index 660ec557706..e6b38682168 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java @@ -11,14 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.internal; import com.lancedb.lance.spark.read.LanceInputPartition; + import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; -import org.apache.spark.sql.vectorized.ArrowColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.LanceArrowColumnVector; import java.io.IOException; @@ -27,16 +27,16 @@ public class LanceFragmentColumnarBatchScanner implements AutoCloseable { private final ArrowReader arrowReader; private ColumnarBatch currentColumnarBatch; - public LanceFragmentColumnarBatchScanner(LanceFragmentScanner fragmentScanner, - ArrowReader arrowReader) { + public LanceFragmentColumnarBatchScanner( + LanceFragmentScanner fragmentScanner, ArrowReader arrowReader) { this.fragmentScanner = fragmentScanner; this.arrowReader = arrowReader; } public static LanceFragmentColumnarBatchScanner create( int fragmentId, LanceInputPartition inputPartition) { - LanceFragmentScanner fragmentScanner = LanceDatasetAdapter - .getFragmentScanner(fragmentId, inputPartition); + LanceFragmentScanner fragmentScanner = + LanceDatasetAdapter.getFragmentScanner(fragmentId, inputPartition); return new LanceFragmentColumnarBatchScanner(fragmentScanner, fragmentScanner.getArrowReader()); } @@ -47,16 +47,18 @@ public boolean loadNextBatch() throws IOException { } if (arrowReader.loadNextBatch()) { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); - currentColumnarBatch = new ColumnarBatch(root.getFieldVectors().stream() - .map(ArrowColumnVector::new).toArray(ArrowColumnVector[]::new), root.getRowCount()); + currentColumnarBatch = + new ColumnarBatch( + root.getFieldVectors().stream() + .map(LanceArrowColumnVector::new) + .toArray(LanceArrowColumnVector[]::new), + root.getRowCount()); return true; } return false; } - /** - * @return the current batch, the caller responsible for closing the batch - */ + /** @return the current batch, the caller responsible for closing the batch */ public ColumnarBatch getCurrentBatch() { return currentColumnarBatch; } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java index e71cf33b7e3..41555896c40 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java @@ -11,17 +11,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.internal; import com.lancedb.lance.Dataset; -import com.lancedb.lance.DatasetFragment; +import com.lancedb.lance.Fragment; import com.lancedb.lance.ReadOptions; import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; import com.lancedb.lance.spark.LanceConfig; -import com.lancedb.lance.spark.read.LanceInputPartition; +import com.lancedb.lance.spark.LanceConstant; import com.lancedb.lance.spark.SparkOptions; +import com.lancedb.lance.spark.read.LanceInputPartition; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.spark.sql.types.StructField; @@ -30,34 +34,53 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; public class LanceFragmentScanner implements AutoCloseable { - private Dataset dataset; - private DatasetFragment fragment; + private static LoadingCache> LOADING_CACHE = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterAccess(1, TimeUnit.HOURS) + .build( + new CacheLoader>() { + @Override + public List load(LanceConfig config) throws Exception { + BufferAllocator allocator = LanceDatasetAdapter.allocator; + ReadOptions options = SparkOptions.genReadOptionFromConfig(config); + Dataset dataset = Dataset.open(allocator, config.getDatasetUri(), options); + return dataset.getFragments(); + } + }); private LanceScanner scanner; - private LanceFragmentScanner(Dataset dataset, DatasetFragment fragment, LanceScanner scanner) { - this.dataset = dataset; - this.fragment = fragment; + private LanceFragmentScanner(LanceScanner scanner) { this.scanner = scanner; } - public static LanceFragmentScanner create(int fragmentId, - LanceInputPartition inputPartition, BufferAllocator allocator) { - Dataset dataset = null; - DatasetFragment fragment = null; + public static LanceFragmentScanner create(int fragmentId, LanceInputPartition inputPartition) { LanceScanner scanner = null; try { LanceConfig config = inputPartition.getConfig(); - ReadOptions options = SparkOptions.genReadOptionFromConfig(config); - dataset = Dataset.open(allocator, config.getDatasetUri(), options); - fragment = dataset.getFragments().get(fragmentId); + List cachedFragments = LOADING_CACHE.get(config); + Fragment fragment = cachedFragments.get(fragmentId); ScanOptions.Builder scanOptions = new ScanOptions.Builder(); scanOptions.columns(getColumnNames(inputPartition.getSchema())); if (inputPartition.getWhereCondition().isPresent()) { scanOptions.filter(inputPartition.getWhereCondition().get()); } + scanOptions.batchSize(SparkOptions.getBatchSize(config)); + scanOptions.withRowId(getWithRowId(inputPartition.getSchema())); + scanOptions.withRowAddress(getWithRowAddress(inputPartition.getSchema())); + if (inputPartition.getLimit().isPresent()) { + scanOptions.limit(inputPartition.getLimit().get()); + } + if (inputPartition.getOffset().isPresent()) { + scanOptions.offset(inputPartition.getOffset().get()); + } + if (inputPartition.getTopNSortOrders().isPresent()) { + scanOptions.setColumnOrderings(inputPartition.getTopNSortOrders().get()); + } scanner = fragment.newScan(scanOptions.build()); } catch (Throwable t) { if (scanner != null) { @@ -67,21 +90,12 @@ public static LanceFragmentScanner create(int fragmentId, t.addSuppressed(it); } } - if (dataset != null) { - try { - dataset.close(); - } catch (Throwable it) { - t.addSuppressed(it); - } - } - throw t; + throw new RuntimeException(t); } - return new LanceFragmentScanner(dataset, fragment, scanner); + return new LanceFragmentScanner(scanner); } - /** - * @return the arrow reader. The caller is responsible for closing the reader - */ + /** @return the arrow reader. The caller is responsible for closing the reader */ public ArrowReader getArrowReader() { return scanner.scanBatches(); } @@ -95,14 +109,25 @@ public void close() throws IOException { throw new IOException(e); } } - if (dataset != null) { - dataset.close(); - } } private static List getColumnNames(StructType schema) { return Arrays.stream(schema.fields()) .map(StructField::name) + .filter( + name -> !name.equals(LanceConstant.ROW_ID) && !name.equals(LanceConstant.ROW_ADDRESS)) .collect(Collectors.toList()); } -} \ No newline at end of file + + private static boolean getWithRowId(StructType schema) { + return Arrays.stream(schema.fields()) + .map(StructField::name) + .anyMatch(name -> name.equals(LanceConstant.ROW_ID)); + } + + private static boolean getWithRowAddress(StructType schema) { + return Arrays.stream(schema.fields()) + .map(StructField::name) + .anyMatch(name -> name.equals(LanceConstant.ROW_ADDRESS)); + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java index 7cc30dc74a5..76f9e92cf85 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/FilterPushDown.java @@ -11,10 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.sources.And; import org.apache.spark.sql.sources.EqualNullSafe; import org.apache.spark.sql.sources.EqualTo; @@ -35,6 +35,7 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -53,9 +54,10 @@ public static Optional compileFiltersToSqlWhereClause(Filter[] filters) for (Filter filter : filters) { compileFilter(filter).ifPresent(compiledFilters::add); } - String whereClause = compiledFilters.stream() - .map(filter -> "(" + filter + ")") - .collect(Collectors.joining(" AND ")); + String whereClause = + compiledFilters.stream() + .map(filter -> "(" + filter + ")") + .collect(Collectors.joining(" AND ")); return Optional.of(whereClause); } @@ -78,7 +80,7 @@ public static Filter[][] processFilters(Filter[] filters) { Filter[] acceptedArray = acceptedFilters.toArray(new Filter[0]); Filter[] rejectedArray = rejectedFilters.toArray(new Filter[0]); - return new Filter[][]{acceptedArray, rejectedArray}; + return new Filter[][] {acceptedArray, rejectedArray}; } public static boolean isFilterSupported(Filter filter) { @@ -87,7 +89,7 @@ public static boolean isFilterSupported(Filter filter) { } else if (filter instanceof EqualNullSafe) { return false; } else if (filter instanceof In) { - return false; + return true; } else if (filter instanceof LessThan) { return true; } else if (filter instanceof LessThanOrEqual) { @@ -149,12 +151,11 @@ private static Optional compileFilter(Filter filter) { Optional right = compileFilter(f.right()); if (left.isEmpty()) return right; if (right.isEmpty()) return left; - return Optional.of(String.format("(%s) AND (%s)", - left.get(), right.get())); + return Optional.of(String.format("(%s) AND (%s)", left.get(), right.get())); } else if (filter instanceof IsNull) { IsNull f = (IsNull) filter; return Optional.of(String.format("%s IS NULL", f.attribute())); - } else if (filter instanceof IsNotNull) { + } else if (filter instanceof IsNotNull) { IsNotNull f = (IsNotNull) filter; return Optional.of(String.format("%s IS NOT NULL", f.attribute())); } else if (filter instanceof Not) { @@ -162,6 +163,13 @@ private static Optional compileFilter(Filter filter) { Optional child = compileFilter(f.child()); if (child.isEmpty()) return child; return Optional.of(String.format("NOT (%s)", child.get())); + } else if (filter instanceof In) { + In in = (In) filter; + String values = + Arrays.stream(in.values()) + .map(FilterPushDown::compileValue) + .collect(Collectors.joining(",")); + return Optional.of(String.format("%s IN (%s)", in.attribute(), values)); } return Optional.empty(); diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java index 5745709823d..15f96c72094 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReader.java @@ -11,10 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.internal.LanceFragmentColumnarBatchScanner; + import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -40,9 +40,9 @@ public boolean next() throws IOException { if (fragmentReader != null) { fragmentReader.close(); } - fragmentReader = LanceFragmentColumnarBatchScanner.create( - inputPartition.getLanceSplit().getFragments().get(fragmentIndex), - inputPartition); + fragmentReader = + LanceFragmentColumnarBatchScanner.create( + inputPartition.getLanceSplit().getFragments().get(fragmentIndex), inputPartition); fragmentIndex++; if (loadNextBatchFromCurrentReader()) { return true; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java index 3525502a63b..d0e72009cb1 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceInputPartition.java @@ -11,14 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; +import com.lancedb.lance.ipc.ColumnOrdering; import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.types.StructType; +import java.util.List; + public class LanceInputPartition implements InputPartition { private static final long serialVersionUID = 4723894723984723984L; @@ -27,14 +30,61 @@ public class LanceInputPartition implements InputPartition { private final LanceSplit lanceSplit; private final LanceConfig config; private final Optional whereCondition; + private final Optional limit; + private final Optional offset; + private final Optional> topNSortOrders; + + public LanceInputPartition( + StructType schema, + int partitionId, + LanceSplit lanceSplit, + LanceConfig config, + Optional whereCondition) { + this.schema = schema; + this.partitionId = partitionId; + this.lanceSplit = lanceSplit; + this.config = config; + this.whereCondition = whereCondition; + this.limit = Optional.empty(); + this.offset = Optional.empty(); + this.topNSortOrders = Optional.empty(); + } - public LanceInputPartition(StructType schema, int partitionId, - LanceSplit lanceSplit, LanceConfig config, Optional whereCondition) { + public LanceInputPartition( + StructType schema, + int partitionId, + LanceSplit lanceSplit, + LanceConfig config, + Optional whereCondition, + Optional limit, + Optional offset) { this.schema = schema; this.partitionId = partitionId; this.lanceSplit = lanceSplit; this.config = config; this.whereCondition = whereCondition; + this.limit = limit; + this.offset = offset; + this.topNSortOrders = Optional.empty(); + } + + public LanceInputPartition( + StructType schema, + int partitionId, + LanceSplit lanceSplit, + LanceConfig config, + Optional whereCondition, + Optional limit, + Optional offset, + Optional> topNSortOrders) { + this.schema = schema; + this.partitionId = partitionId; + this.lanceSplit = lanceSplit; + this.config = config; + this.whereCondition = whereCondition; + this.limit = limit; + this.offset = offset; + this.topNSortOrders = topNSortOrders; } public StructType getSchema() { @@ -56,4 +106,16 @@ public LanceConfig getConfig() { public Optional getWhereCondition() { return whereCondition; } + + public Optional getLimit() { + return limit; + } + + public Optional getOffset() { + return offset; + } + + public Optional> getTopNSortOrders() { + return topNSortOrders; + } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java index 88c105e7c77..f847365185d 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceRowPartitionReader.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; -import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.vectorized.ColumnarBatch; import java.io.IOException; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java index 382cae20d30..3913b1bcf5a 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScan.java @@ -11,11 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; +import com.lancedb.lance.ipc.ColumnOrdering; import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.utils.Optional; + import org.apache.arrow.util.Preconditions; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.Batch; @@ -23,24 +24,42 @@ import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.Statistics; +import org.apache.spark.sql.connector.read.SupportsReportStatistics; +import org.apache.spark.sql.internal.connector.SupportsMetadata; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; +import scala.collection.immutable.Map; +import scala.collection.mutable.HashMap; import java.io.Serializable; import java.util.List; import java.util.stream.IntStream; -public class LanceScan implements Batch, Scan, Serializable { +public class LanceScan + implements Batch, Scan, SupportsMetadata, SupportsReportStatistics, Serializable { private static final long serialVersionUID = 947284762748623947L; private final StructType schema; - private final LanceConfig options; + private final LanceConfig config; private final Optional whereConditions; + private final Optional limit; + private final Optional offset; + private final Optional> topNSortOrders; - public LanceScan(StructType schema, LanceConfig options, Optional whereConditions) { + public LanceScan( + StructType schema, + LanceConfig config, + Optional whereConditions, + Optional limit, + Optional offset, + Optional> topNSortOrders) { this.schema = schema; - this.options = options; + this.config = config; this.whereConditions = whereConditions; + this.limit = limit; + this.offset = offset; + this.topNSortOrders = topNSortOrders; } @Override @@ -50,9 +69,19 @@ public Batch toBatch() { @Override public InputPartition[] planInputPartitions() { - List splits = LanceSplit.generateLanceSplits(options); + List splits = LanceSplit.generateLanceSplits(config); return IntStream.range(0, splits.size()) - .mapToObj(i -> new LanceInputPartition(schema, i, splits.get(i), options, whereConditions)) + .mapToObj( + i -> + new LanceInputPartition( + schema, + i, + splits.get(i), + config, + whereConditions, + limit, + offset, + topNSortOrders)) .toArray(InputPartition[]::new); } @@ -66,17 +95,34 @@ public StructType readSchema() { return schema; } + @Override + public Map getMetaData() { + HashMap hashMap = new HashMap<>(); + hashMap.put("whereConditions", whereConditions.toString()); + hashMap.put("limit", limit.toString()); + hashMap.put("offset", offset.toString()); + hashMap.put("topNSortOrders", topNSortOrders.toString()); + return hashMap.toMap(scala.Predef.conforms()); + } + + @Override + public Statistics estimateStatistics() { + return new LanceStatistics(config); + } + private class LanceReaderFactory implements PartitionReaderFactory { @Override public PartitionReader createReader(InputPartition partition) { - Preconditions.checkArgument(partition instanceof LanceInputPartition, + Preconditions.checkArgument( + partition instanceof LanceInputPartition, "Unknown InputPartition type. Expecting LanceInputPartition"); return LanceRowPartitionReader.create((LanceInputPartition) partition); } @Override public PartitionReader createColumnarReader(InputPartition partition) { - Preconditions.checkArgument(partition instanceof LanceInputPartition, + Preconditions.checkArgument( + partition instanceof LanceInputPartition, "Unknown InputPartition type. Expecting LanceInputPartition"); return new LanceColumnarPartitionReader((LanceInputPartition) partition); } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java index 9fba4601c33..b1507fbfe86 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceScanBuilder.java @@ -11,33 +11,53 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; +import com.lancedb.lance.ipc.ColumnOrdering; import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.SparkOptions; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; import com.lancedb.lance.spark.utils.Optional; + +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NullOrdering; +import org.apache.spark.sql.connector.expressions.SortDirection; +import org.apache.spark.sql.connector.expressions.SortOrder; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownLimit; +import org.apache.spark.sql.connector.read.SupportsPushDownOffset; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.connector.read.SupportsPushDownTopN; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.StructType; -public class LanceScanBuilder implements - SupportsPushDownRequiredColumns, SupportsPushDownFilters { - private final LanceConfig options; +import java.util.ArrayList; +import java.util.List; + +public class LanceScanBuilder + implements SupportsPushDownRequiredColumns, + SupportsPushDownFilters, + SupportsPushDownLimit, + SupportsPushDownOffset, + SupportsPushDownTopN { + private final LanceConfig config; private StructType schema; private Filter[] pushedFilters = new Filter[0]; + private Optional limit = Optional.empty(); + private Optional offset = Optional.empty(); + private Optional> topNSortOrders = Optional.empty(); - public LanceScanBuilder(StructType schema, LanceConfig options) { + public LanceScanBuilder(StructType schema, LanceConfig config) { this.schema = schema; - this.options = options; + this.config = config; } @Override public Scan build() { Optional whereCondition = FilterPushDown.compileFiltersToSqlWhereClause(pushedFilters); - return new LanceScan(schema, options, whereCondition); + return new LanceScan(schema, config, whereCondition, limit, offset, topNSortOrders); } @Override @@ -50,7 +70,7 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - if (!options.isPushDownFilters()) { + if (!config.isPushDownFilters()) { return filters; } Filter[][] processFilters = FilterPushDown.processFilters(filters); @@ -62,4 +82,50 @@ public Filter[] pushFilters(Filter[] filters) { public Filter[] pushedFilters() { return pushedFilters; } + + @Override + public boolean pushLimit(int limit) { + this.limit = Optional.of(limit); + return true; + } + + @Override + public boolean pushOffset(int offset) { + // Only one data file can be pushed down the offset. + if (LanceDatasetAdapter.getFragmentIds(config).size() == 1) { + this.offset = Optional.of(offset); + return true; + } else { + return false; + } + } + + @Override + public boolean isPartiallyPushed() { + return true; + } + + @Override + public boolean pushTopN(SortOrder[] orders, int limit) { + // The Order by operator will use compute thread in lance. + // So it's better to have an option to enable it. + if (!SparkOptions.enableTopNPushDown(this.config)) { + return false; + } + this.limit = Optional.of(limit); + List topNSortOrders = new ArrayList<>(); + for (SortOrder sortOrder : orders) { + ColumnOrdering.Builder builder = new ColumnOrdering.Builder(); + builder.setNullFirst(sortOrder.nullOrdering() == NullOrdering.NULLS_FIRST); + builder.setAscending(sortOrder.direction() == SortDirection.ASCENDING); + if (!(sortOrder.expression() instanceof FieldReference)) { + return false; + } + FieldReference reference = (FieldReference) sortOrder.expression(); + builder.setColumnName(reference.fieldNames()[0]); + topNSortOrders.add(builder.build()); + } + this.topNSortOrders = Optional.of(topNSortOrders); + return true; + } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java index 4e46b464df8..d3e15e73c0f 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceSplit.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.LanceConfig; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java new file mode 100644 index 00000000000..cb098caf42e --- /dev/null +++ b/java/spark/src/main/java/com/lancedb/lance/spark/read/LanceStatistics.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.spark.read; + +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import com.lancedb.lance.spark.utils.Optional; + +import org.apache.spark.sql.connector.read.Statistics; + +import java.util.OptionalLong; + +public class LanceStatistics implements Statistics { + private final Optional rowNumber; + private final Optional dataBytesSize; + + public LanceStatistics(LanceConfig config) { + this.rowNumber = LanceDatasetAdapter.getDatasetRowCount(config); + this.dataBytesSize = LanceDatasetAdapter.getDatasetDataSize(config); + } + + @Override + public OptionalLong sizeInBytes() { + if (dataBytesSize.isPresent()) { + return OptionalLong.of(dataBytesSize.get()); + } else { + return OptionalLong.empty(); + } + } + + @Override + public OptionalLong numRows() { + if (rowNumber.isPresent()) { + return OptionalLong.of(rowNumber.get()); + } else { + return OptionalLong.empty(); + } + } +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/utils/Optional.java b/java/spark/src/main/java/com/lancedb/lance/spark/utils/Optional.java index 5c4df517317..c2ed19de23c 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/utils/Optional.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/utils/Optional.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.utils; import java.io.Serializable; diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java index f7fa2e6f450..ca4ea36ed0f 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.google.common.base.Preconditions; @@ -22,22 +21,17 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter; import javax.annotation.concurrent.GuardedBy; + import java.io.IOException; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -/** - * A custom arrow reader that supports writes Spark internal rows while reading data in batches. - */ +/** A custom arrow reader that supports writes Spark internal rows while reading data in batches. */ public class LanceArrowWriter extends ArrowReader { private final Schema schema; private final int batchSize; - private final Object monitor = new Object(); - @GuardedBy("monitor") - private final Queue rowQueue = new ConcurrentLinkedQueue<>(); + @GuardedBy("monitor") private volatile boolean finished; @@ -52,7 +46,6 @@ public LanceArrowWriter(BufferAllocator allocator, Schema schema, int batchSize) Preconditions.checkNotNull(schema); Preconditions.checkArgument(batchSize > 0); this.schema = schema; - // TODO(lu) batch size as config? this.batchSize = batchSize; this.writeToken = new Semaphore(0); this.loadToken = new Semaphore(0); @@ -69,7 +62,7 @@ void write(InternalRow row) { loadToken.release(); } } catch (InterruptedException e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } } @@ -109,7 +102,7 @@ public boolean loadNextBatch() throws IOException { } } } catch (InterruptedException e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } } diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceBatchWrite.java similarity index 75% rename from java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java rename to java/spark/src/main/java/com/lancedb/lance/spark/write/LanceBatchWrite.java index bf41dcefe8c..40e176fcb3d 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/write/BatchAppend.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceBatchWrite.java @@ -11,12 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.FragmentMetadata; import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.SparkOptions; import com.lancedb.lance.spark.internal.LanceDatasetAdapter; + import org.apache.spark.sql.connector.write.BatchWrite; import org.apache.spark.sql.connector.write.DataWriterFactory; import org.apache.spark.sql.connector.write.PhysicalWriteInfo; @@ -27,13 +28,15 @@ import java.util.List; import java.util.stream.Collectors; -public class BatchAppend implements BatchWrite { +public class LanceBatchWrite implements BatchWrite { private final StructType schema; private final LanceConfig config; + private final boolean overwrite; - public BatchAppend(StructType schema, LanceConfig config) { + public LanceBatchWrite(StructType schema, LanceConfig config, boolean overwrite) { this.schema = schema; this.config = config; + this.overwrite = overwrite; } @Override @@ -48,12 +51,17 @@ public boolean useCommitCoordinator() { @Override public void commit(WriterCommitMessage[] messages) { - List fragments = Arrays.stream(messages) - .map(m -> (TaskCommit) m) - .map(TaskCommit::getFragments) - .flatMap(List::stream) - .collect(Collectors.toList()); - LanceDatasetAdapter.appendFragments(config, fragments); + List fragments = + Arrays.stream(messages) + .map(m -> (TaskCommit) m) + .map(TaskCommit::getFragments) + .flatMap(List::stream) + .collect(Collectors.toList()); + if (overwrite || SparkOptions.overwrite(this.config)) { + LanceDatasetAdapter.overwriteFragments(config, fragments, schema); + } else { + LanceDatasetAdapter.appendFragments(config, fragments); + } } @Override diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java index 706b6144d19..618837c98b4 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.FragmentMetadata; @@ -19,6 +18,7 @@ import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.SparkOptions; import com.lancedb.lance.spark.internal.LanceDatasetAdapter; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.DataWriterFactory; @@ -26,19 +26,20 @@ import org.apache.spark.sql.types.StructType; import java.io.IOException; -import java.util.Arrays; +import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; public class LanceDataWriter implements DataWriter { private LanceArrowWriter arrowWriter; - private FutureTask fragmentCreationTask; + private FutureTask> fragmentCreationTask; private Thread fragmentCreationThread; - private LanceDataWriter(LanceArrowWriter arrowWriter, - FutureTask fragmentCreationTask, Thread fragmentCreationThread) { - // TODO support write to multiple fragments + private LanceDataWriter( + LanceArrowWriter arrowWriter, + FutureTask> fragmentCreationTask, + Thread fragmentCreationThread) { this.arrowWriter = arrowWriter; this.fragmentCreationThread = fragmentCreationThread; this.fragmentCreationTask = fragmentCreationTask; @@ -53,8 +54,8 @@ public void write(InternalRow record) throws IOException { public WriterCommitMessage commit() throws IOException { arrowWriter.setFinished(); try { - FragmentMetadata fragmentMetadata = fragmentCreationTask.get(); - return new BatchAppend.TaskCommit(Arrays.asList(fragmentMetadata)); + List fragmentMetadata = fragmentCreationTask.get(); + return new LanceBatchWrite.TaskCommit(fragmentMetadata); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException("Interrupted while waiting for reader thread to finish", e); @@ -91,15 +92,16 @@ protected WriterFactory(StructType schema, LanceConfig config) { @Override public DataWriter createWriter(int partitionId, long taskId) { - LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024); + int batch_size = SparkOptions.getBatchSize(config); + LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batch_size); WriteParams params = SparkOptions.genWriteParamsFromConfig(config); - Callable fragmentCreator - = () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params); - FutureTask fragmentCreationTask = new FutureTask<>(fragmentCreator); + Callable> fragmentCreator = + () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params); + FutureTask> fragmentCreationTask = new FutureTask<>(fragmentCreator); Thread fragmentCreationThread = new Thread(fragmentCreationTask); fragmentCreationThread.start(); return new LanceDataWriter(arrowWriter, fragmentCreationTask, fragmentCreationThread); } } -} \ No newline at end of file +} diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java index 857387d018d..3fefef2a022 100644 --- a/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java +++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/SparkWrite.java @@ -11,31 +11,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.spark.LanceConfig; + import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.SupportsTruncate; import org.apache.spark.sql.connector.write.Write; import org.apache.spark.sql.connector.write.WriteBuilder; import org.apache.spark.sql.connector.write.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; -/** - * Spark write builder. - */ +/** Spark write builder. */ public class SparkWrite implements Write { private final LanceConfig config; private final StructType schema; + private final boolean overwrite; - SparkWrite(StructType schema, LanceConfig config) { + SparkWrite(StructType schema, LanceConfig config, boolean overwrite) { this.schema = schema; this.config = config; + this.overwrite = overwrite; } @Override public BatchWrite toBatch() { - return new BatchAppend(schema, config); + return new LanceBatchWrite(schema, config, overwrite); } @Override @@ -44,19 +45,25 @@ public StreamingWrite toStreaming() { } /** Task commit. */ - - public static class SparkWriteBuilder implements WriteBuilder { - private final LanceConfig options; + public static class SparkWriteBuilder implements SupportsTruncate, WriteBuilder { + private final LanceConfig config; private final StructType schema; + private boolean overwrite = false; - public SparkWriteBuilder(StructType schema, LanceConfig options) { + public SparkWriteBuilder(StructType schema, LanceConfig config) { this.schema = schema; - this.options = options; + this.config = config; } @Override public Write build() { - return new SparkWrite(schema, options); + return new SparkWrite(schema, config, overwrite); + } + + @Override + public WriteBuilder truncate() { + this.overwrite = true; + return this; } } -} \ No newline at end of file +} diff --git a/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java b/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java new file mode 100644 index 00000000000..7b4eb9efd1d --- /dev/null +++ b/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.util.LanceArrowUtils; +import org.apache.spark.unsafe.types.UTF8String; + +public class LanceArrowColumnVector extends ColumnVector { + private UInt8Accessor uInt8Accessor; + private ArrowColumnVector arrowColumnVector; + + public LanceArrowColumnVector(ValueVector vector) { + super(LanceArrowUtils.fromArrowField(vector.getField())); + if (vector instanceof UInt8Vector) { + uInt8Accessor = new UInt8Accessor((UInt8Vector) vector); + } else { + arrowColumnVector = new ArrowColumnVector(vector); + } + } + + @Override + public void close() { + if (uInt8Accessor != null) { + uInt8Accessor.close(); + } + if (arrowColumnVector != null) { + arrowColumnVector.close(); + } + } + + @Override + public boolean hasNull() { + if (uInt8Accessor != null) { + return uInt8Accessor.getNullCount() > 0; + } + if (arrowColumnVector != null) { + return arrowColumnVector.hasNull(); + } + return false; + } + + @Override + public int numNulls() { + if (uInt8Accessor != null) { + return uInt8Accessor.getNullCount(); + } + if (arrowColumnVector != null) { + return arrowColumnVector.numNulls(); + } + return 0; + } + + @Override + public boolean isNullAt(int rowId) { + if (uInt8Accessor != null) { + return uInt8Accessor.isNullAt(rowId); + } + if (arrowColumnVector != null) { + return arrowColumnVector.isNullAt(rowId); + } + return false; + } + + @Override + public boolean getBoolean(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getBoolean(rowId); + } + return false; + } + + @Override + public byte getByte(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getByte(rowId); + } + return 0; + } + + @Override + public short getShort(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getShort(rowId); + } + return 0; + } + + @Override + public int getInt(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getInt(rowId); + } + return 0; + } + + @Override + public long getLong(int rowId) { + if (uInt8Accessor != null) { + return uInt8Accessor.getLong(rowId); + } + if (arrowColumnVector != null) { + return arrowColumnVector.getLong(rowId); + } + return 0L; + } + + @Override + public float getFloat(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getFloat(rowId); + } + return 0; + } + + @Override + public double getDouble(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getDouble(rowId); + } + return 0; + } + + @Override + public ColumnarArray getArray(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getArray(rowId); + } + return null; + } + + @Override + public ColumnarMap getMap(int ordinal) { + if (arrowColumnVector != null) { + return arrowColumnVector.getMap(ordinal); + } + return null; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (arrowColumnVector != null) { + return arrowColumnVector.getDecimal(rowId, precision, scale); + } + return null; + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getUTF8String(rowId); + } + return null; + } + + @Override + public byte[] getBinary(int rowId) { + if (arrowColumnVector != null) { + return arrowColumnVector.getBinary(rowId); + } + return new byte[0]; + } + + @Override + public ColumnVector getChild(int ordinal) { + if (arrowColumnVector != null) { + return arrowColumnVector.getChild(ordinal); + } + return null; + } +} diff --git a/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java b/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java new file mode 100644 index 00000000000..f3809df93a6 --- /dev/null +++ b/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import org.apache.arrow.vector.UInt8Vector; + +// UInt8Accessor can't extend the ArrowVectorAccessor since it's package private. +public class UInt8Accessor { + private final UInt8Vector accessor; + + UInt8Accessor(UInt8Vector vector) { + this.accessor = vector; + } + + final long getLong(int rowId) { + return accessor.getObjectNoOverflow(rowId).longValueExact(); + } + + final boolean isNullAt(int rowId) { + return accessor.isNull(rowId); + } + + final int getNullCount() { + return accessor.getNullCount(); + } + + final void close() { + accessor.close(); + } +} diff --git a/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala b/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala new file mode 100644 index 00000000000..1a93f0d8221 --- /dev/null +++ b/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.util + +/* + * The following code is originally from https://github.com/apache/spark/blob/master/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala + * and is licensed under the Apache license: + * + * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation. + * https://github.com/apache/spark/blob/master/LICENSE + * + * It has been modified by the Lance developers to fit the needs of the Lance project. + */ + +import com.lancedb.lance.spark.LanceConstant + +import org.apache.arrow.vector.complex.MapVector +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.types._ + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ + +object LanceArrowUtils { + def fromArrowField(field: Field): DataType = { + field.getType match { + case int: ArrowType.Int if !int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case _ => ArrowUtils.fromArrowField(field) + } + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + + def toArrowSchema( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { + new Schema(schema.map { field => + toArrowField( + field.name, + deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), + field.nullable, + timeZoneId, + largeVarTypes) + }.asJava) + } + + def toArrowField( + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field( + name, + fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field( + name, + fieldType, + fields.map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + }.toSeq.asJava) + case MapType(keyType, valueType, valueContainsNull) => + val mapType = new FieldType(nullable, new ArrowType.Map(false), null) + // Note: Map Type struct can not be null, Struct Type key field can not be null + new Field( + name, + mapType, + Seq(toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) + case udt: UserDefinedType[_] => + toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case dataType => + val fieldType = + new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes, name), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + private def toArrowType( + dt: DataType, + timeZoneId: String, + largeVarTypes: Boolean = false, + name: String): ArrowType = dt match { + case LongType if name.equals(LanceConstant.ROW_ID) => new ArrowType.Int(8 * 8, false) + case _ => ArrowUtils.toArrowType(dt, timeZoneId, largeVarTypes) + } + + private def deduplicateFieldNames( + dt: DataType, + errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case st @ StructType(fields) => + val newNames = if (st.names.toSet.size == st.names.length) { + st.names + } else { + if (errorOnDuplicatedFieldNames) { + throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + } + val genNawName = st.names.groupBy(identity).map { + case (name, names) if names.length > 1 => + val i = new AtomicInteger() + name -> { () => s"${name}_${i.getAndIncrement()}" } + case (name, _) => name -> { () => name } + } + st.names.map(genNawName(_)()) + } + val newFields = + fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => + StructField( + name, + deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), + nullable, + metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType( + deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), + deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), + valueContainsNull) + case _ => dt + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java index 56713f5c39d..06aca8733fc 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/LanceConfigTest.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -28,9 +27,13 @@ public void testLanceConfigFromCaseInsensitiveStringMap() { String dbPath = "file://path/to/db/"; String datasetName = "testDatasetName"; String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); - CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(new HashMap() {{ - put(LanceConfig.CONFIG_DATASET_URI, datasetUri); - }}); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap( + new HashMap() { + { + put(LanceConfig.CONFIG_DATASET_URI, datasetUri); + } + }); LanceConfig config = LanceConfig.from(options); @@ -44,9 +47,13 @@ public void testLanceConfigFromCaseInsensitiveStringMap2() { String dbPath = "s3://bucket/folder/"; String datasetName = "testDatasetName"; String datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); - CaseInsensitiveStringMap options = new CaseInsensitiveStringMap(new HashMap() {{ - put(LanceConfig.CONFIG_DATASET_URI, datasetUri); - }}); + CaseInsensitiveStringMap options = + new CaseInsensitiveStringMap( + new HashMap() { + { + put(LanceConfig.CONFIG_DATASET_URI, datasetUri); + } + }); LanceConfig config = LanceConfig.from(options); diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java index fb89f166569..85e4661289a 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark; import com.lancedb.lance.spark.read.LanceInputPartition; import com.lancedb.lance.spark.read.LanceSplit; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -30,21 +30,35 @@ public static class TestTable1Config { public static final String dbPath; public static final String datasetName = "test_dataset1"; public static final String datasetUri; - public static final List> expectedValues = Arrays.asList( - Arrays.asList(0L, 0L, 0L, 0L), - Arrays.asList(1L, 2L, 3L, -1L), - Arrays.asList(2L, 4L, 6L, -2L), - Arrays.asList(3L, 6L, 9L, -3L) - ); + public static final List> expectedValues = + Arrays.asList( + Arrays.asList(0L, 0L, 0L, 0L), + Arrays.asList(1L, 2L, 3L, -1L), + Arrays.asList(2L, 4L, 6L, -2L), + Arrays.asList(3L, 6L, 9L, -3L)); + public static final List> expectedValuesWithRowId = + Arrays.asList( + Arrays.asList(0L, 0L, 0L, 0L, 0L), + Arrays.asList(1L, 2L, 3L, -1L, 1L), + Arrays.asList(2L, 4L, 6L, -2L, (1L << 32) + 0L), + Arrays.asList(3L, 6L, 9L, -3L, (1L << 32) + 1L)); + public static final List> expectedValuesWithRowAddress = + Arrays.asList( + Arrays.asList(0L, 0L, 0L, 0L, 0L), + Arrays.asList(1L, 2L, 3L, -1L, 1L), + Arrays.asList(2L, 4L, 6L, -2L, (1L << 32) + 0L), + Arrays.asList(3L, 6L, 9L, -3L, (1L << 32) + 1L)); public static final LanceConfig lanceConfig; - public static final StructType schema = new StructType(new StructField[]{ - DataTypes.createStructField("x", DataTypes.LongType, true), - DataTypes.createStructField("y", DataTypes.LongType, true), - DataTypes.createStructField("b", DataTypes.LongType, true), - DataTypes.createStructField("c", DataTypes.LongType, true), - }); - + public static final StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("x", DataTypes.LongType, true), + DataTypes.createStructField("y", DataTypes.LongType, true), + DataTypes.createStructField("b", DataTypes.LongType, true), + DataTypes.createStructField("c", DataTypes.LongType, true), + }); + public static final LanceInputPartition inputPartition; static { @@ -56,7 +70,9 @@ public static class TestTable1Config { } datasetUri = LanceConfig.getDatasetUri(dbPath, datasetName); lanceConfig = LanceConfig.from(datasetUri); - inputPartition = new LanceInputPartition(schema, 0, new LanceSplit(Arrays.asList(0, 1)), lanceConfig, Optional.empty()); + inputPartition = + new LanceInputPartition( + schema, 0, new LanceSplit(Arrays.asList(0, 1)), lanceConfig, Optional.empty()); } } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java index 5376ba0b7ce..2b9c7855084 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/FilterPushDownTest.java @@ -11,10 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.sources.*; import org.junit.jupiter.api.Test; @@ -24,57 +24,83 @@ public class FilterPushDownTest { @Test public void testCompileFiltersToSqlWhereClause() { // Test case 1: GreaterThan, LessThanOrEqual, IsNotNull - Filter[] filters1 = new Filter[]{ - new GreaterThan("age", 30), - new LessThanOrEqual("salary", 100000), - new IsNotNull("name") - }; + Filter[] filters1 = + new Filter[] { + new GreaterThan("age", 30), new LessThanOrEqual("salary", 100000), new IsNotNull("name") + }; Optional whereClause1 = FilterPushDown.compileFiltersToSqlWhereClause(filters1); assertTrue(whereClause1.isPresent()); assertEquals("(age > 30) AND (salary <= 100000) AND (name IS NOT NULL)", whereClause1.get()); // Test case 2: GreaterThan, StringContains, LessThan - Filter[] filters2 = new Filter[]{ - new GreaterThan("age", 30), - new StringContains("name", "John"), - new LessThan("salary", 50000) - }; + Filter[] filters2 = + new Filter[] { + new GreaterThan("age", 30), + new StringContains("name", "John"), + new LessThan("salary", 50000) + }; Optional whereClause2 = FilterPushDown.compileFiltersToSqlWhereClause(filters2); assertTrue(whereClause2.isPresent()); assertEquals("(age > 30) AND (salary < 50000)", whereClause2.get()); // Test case 3: Empty filters array - Filter[] filters3 = new Filter[]{}; + Filter[] filters3 = new Filter[] {}; Optional whereClause3 = FilterPushDown.compileFiltersToSqlWhereClause(filters3); assertFalse(whereClause3.isPresent()); // Test case 4: Mixed supported and unsupported filters - Filter[] filters4 = new Filter[]{ - new GreaterThan("age", 30), - new StringContains("name", "John"), - new IsNull("address"), - new EqualTo("country", "USA") - }; + Filter[] filters4 = + new Filter[] { + new GreaterThan("age", 30), + new StringContains("name", "John"), + new IsNull("address"), + new EqualTo("country", "USA") + }; Optional whereClause4 = FilterPushDown.compileFiltersToSqlWhereClause(filters4); assertTrue(whereClause4.isPresent()); assertEquals("(age > 30) AND (address IS NULL) AND (country == 'USA')", whereClause4.get()); // Test case 5: Not, Or, And combinations - Filter[] filters5 = new Filter[]{ - new Not(new GreaterThan("age", 30)), - new Or(new IsNotNull("name"), new IsNull("address")), - new And(new LessThan("salary", 100000), new GreaterThanOrEqual("salary", 50000)) - }; + Filter[] filters5 = + new Filter[] { + new Not(new GreaterThan("age", 30)), + new Or(new IsNotNull("name"), new IsNull("address")), + new And(new LessThan("salary", 100000), new GreaterThanOrEqual("salary", 50000)) + }; Optional whereClause5 = FilterPushDown.compileFiltersToSqlWhereClause(filters5); assertTrue(whereClause5.isPresent()); - assertEquals("(NOT (age > 30)) AND ((name IS NOT NULL) OR (address IS NULL)) AND ((salary < 100000) AND (salary >= 50000))", whereClause5.get()); + assertEquals( + "(NOT (age > 30)) AND ((name IS NOT NULL) OR (address IS NULL)) AND ((salary < 100000) AND (salary >= 50000))", + whereClause5.get()); } @Test public void testCompileFiltersToSqlWhereClauseWithEmptyFilters() { - Filter[] filters = new Filter[]{}; + Filter[] filters = new Filter[] {}; Optional whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters); assertFalse(whereClause.isPresent()); } -} \ No newline at end of file + + @Test + public void testIntegerInFilterPushDown() { + Object[] values = new Object[2]; + values[0] = 500; + values[1] = 600; + Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)}; + Optional whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters); + assertTrue(whereClause.isPresent()); + assertEquals("(age > 30) AND (salary IN (500,600))", whereClause.get()); + } + + @Test + public void testStringInFilterPushDown() { + Object[] values = new Object[2]; + values[0] = "500"; + values[1] = "600"; + Filter[] filters = new Filter[] {new GreaterThan("age", 30), new In("salary", values)}; + Optional whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters); + assertTrue(whereClause.isPresent()); + assertEquals("(age > 30) AND (salary IN ('500','600'))", whereClause.get()); + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java index 23bfd233fce..55d49b94cbc 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceColumnarPartitionReaderTest.java @@ -11,15 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; +import com.lancedb.lance.ipc.ColumnOrdering; import com.lancedb.lance.spark.TestUtils; import com.lancedb.lance.spark.utils.Optional; + import org.apache.spark.sql.vectorized.ColumnarBatch; import org.junit.jupiter.api.Test; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -29,8 +31,13 @@ public class LanceColumnarPartitionReaderTest { @Test public void test() throws Exception { LanceSplit split = new LanceSplit(Arrays.asList(0, 1)); - LanceInputPartition partition = new LanceInputPartition( - TestUtils.TestTable1Config.schema, 0, split, TestUtils.TestTable1Config.lanceConfig, Optional.empty()); + LanceInputPartition partition = + new LanceInputPartition( + TestUtils.TestTable1Config.schema, + 0, + split, + TestUtils.TestTable1Config.lanceConfig, + Optional.empty()); try (LanceColumnarPartitionReader reader = new LanceColumnarPartitionReader(partition)) { List> expectedValues = TestUtils.TestTable1Config.expectedValues; int rowIndex = 0; @@ -43,7 +50,8 @@ public void test() throws Exception { for (int j = 0; j < batch.numCols(); j++) { long actualValue = batch.column(j).getLong(i); long expectedValue = expectedValues.get(rowIndex).get(j); - assertEquals(expectedValue, actualValue, "Mismatch at row " + rowIndex + " column " + j); + assertEquals( + expectedValue, actualValue, "Mismatch at row " + rowIndex + " column " + j); } rowIndex++; } @@ -53,4 +61,77 @@ public void test() throws Exception { assertEquals(expectedValues.size(), rowIndex); } } + + @Test + public void testOffsetAndLimit() throws Exception { + LanceSplit split = new LanceSplit(Collections.singletonList(0)); + LanceInputPartition partition = + new LanceInputPartition( + TestUtils.TestTable1Config.schema, + 0, + split, + TestUtils.TestTable1Config.lanceConfig, + Optional.empty(), + Optional.of(1), + Optional.of(1)); + try (LanceColumnarPartitionReader reader = new LanceColumnarPartitionReader(partition)) { + List> expectedValues = TestUtils.TestTable1Config.expectedValues; + int rowIndex = 1; + + while (reader.next()) { + ColumnarBatch batch = reader.get(); + assertNotNull(batch); + assertEquals(1, batch.numRows()); + for (int i = 0; i < batch.numRows(); i++) { + for (int j = 0; j < batch.numCols(); j++) { + long actualValue = batch.column(j).getLong(i); + long expectedValue = expectedValues.get(rowIndex).get(j); + assertEquals( + expectedValue, actualValue, "Mismatch at row " + rowIndex + " column " + j); + } + rowIndex++; + } + batch.close(); + } + } + } + + @Test + public void testTopN() throws Exception { + LanceSplit split = new LanceSplit(Collections.singletonList(1)); + ColumnOrdering.Builder builder = new ColumnOrdering.Builder(); + builder.setNullFirst(true); + builder.setAscending(false); + builder.setColumnName("b"); + LanceInputPartition partition = + new LanceInputPartition( + TestUtils.TestTable1Config.schema, + 0, + split, + TestUtils.TestTable1Config.lanceConfig, + Optional.empty(), + Optional.of(1), + Optional.empty(), + Optional.of(Collections.singletonList(builder.build()))); + try (LanceColumnarPartitionReader reader = new LanceColumnarPartitionReader(partition)) { + List> expectedValues = TestUtils.TestTable1Config.expectedValues; + + // Only get the 4th row + int rowIndex = 3; + while (reader.next()) { + ColumnarBatch batch = reader.get(); + assertNotNull(batch); + assertEquals(1, batch.numRows()); + for (int i = 0; i < batch.numRows(); i++) { + for (int j = 0; j < batch.numCols(); j++) { + long actualValue = batch.column(j).getLong(i); + long expectedValue = expectedValues.get(rowIndex).get(j); + assertEquals( + expectedValue, actualValue, "Mismatch at row " + rowIndex + " column " + j); + } + } + batch.close(); + } + } + } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java index 6423a13ce03..a64689ed137 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceDatasetReadTest.java @@ -11,13 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.TestUtils; -import com.lancedb.lance.spark.internal.LanceFragmentScanner; import com.lancedb.lance.spark.internal.LanceDatasetAdapter; +import com.lancedb.lance.spark.internal.LanceFragmentScanner; import com.lancedb.lance.spark.utils.Optional; + import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.spark.sql.types.DataTypes; @@ -37,14 +37,16 @@ public class LanceDatasetReadTest { @Test public void testSchema() { StructType expectedSchema = TestUtils.TestTable1Config.schema; - Optional schema = LanceDatasetAdapter.getSchema(TestUtils.TestTable1Config.lanceConfig); + Optional schema = + LanceDatasetAdapter.getSchema(TestUtils.TestTable1Config.lanceConfig); assertTrue(schema.isPresent()); assertEquals(expectedSchema, schema.get()); } @Test public void testFragmentIds() { - List fragments = LanceDatasetAdapter.getFragmentIds(TestUtils.TestTable1Config.lanceConfig); + List fragments = + LanceDatasetAdapter.getFragmentIds(TestUtils.TestTable1Config.lanceConfig); assertEquals(2, fragments.size()); assertEquals(0, fragments.get(0)); assertEquals(1, fragments.get(1)); @@ -52,52 +54,60 @@ public void testFragmentIds() { @Test public void getFragmentScanner() throws IOException { - List> expectedValues = Arrays.asList( - Arrays.asList(0L, 0L, 0L, 0L), - Arrays.asList(1L, 2L, 3L, -1L) - ); + List> expectedValues = + Arrays.asList(Arrays.asList(0L, 0L, 0L, 0L), Arrays.asList(1L, 2L, 3L, -1L)); validateFragment(expectedValues, 0, TestUtils.TestTable1Config.schema); - List> expectedValues1 = Arrays.asList( - Arrays.asList(2L, 4L, 6L, -2L), - Arrays.asList(3L, 6L, 9L, -3L) - ); + List> expectedValues1 = + Arrays.asList(Arrays.asList(2L, 4L, 6L, -2L), Arrays.asList(3L, 6L, 9L, -3L)); validateFragment(expectedValues1, 1, TestUtils.TestTable1Config.schema); - List> expectedValuesColumnsyb = Arrays.asList( - Arrays.asList(4L, 6L), - Arrays.asList(6L, 9L) - ); - validateFragment(expectedValuesColumnsyb, 1, new StructType(new StructField[]{ - DataTypes.createStructField("y", DataTypes.LongType, true), - DataTypes.createStructField("b", DataTypes.LongType, true) - })); - List> expectedValuesColumnsbc = Arrays.asList( - Arrays.asList(0L, 0L), - Arrays.asList(3L, -1L) - ); - validateFragment(expectedValuesColumnsbc, 0, new StructType(new StructField[]{ - DataTypes.createStructField("b", DataTypes.LongType, true), - DataTypes.createStructField("c", DataTypes.LongType, true) - })); + List> expectedValuesColumnsyb = + Arrays.asList(Arrays.asList(4L, 6L), Arrays.asList(6L, 9L)); + validateFragment( + expectedValuesColumnsyb, + 1, + new StructType( + new StructField[] { + DataTypes.createStructField("y", DataTypes.LongType, true), + DataTypes.createStructField("b", DataTypes.LongType, true) + })); + List> expectedValuesColumnsbc = + Arrays.asList(Arrays.asList(0L, 0L), Arrays.asList(3L, -1L)); + validateFragment( + expectedValuesColumnsbc, + 0, + new StructType( + new StructField[] { + DataTypes.createStructField("b", DataTypes.LongType, true), + DataTypes.createStructField("c", DataTypes.LongType, true) + })); } - - public void validateFragment(List> expectedValues, int fragment, StructType schema) throws IOException { - try (LanceFragmentScanner scanner = LanceDatasetAdapter.getFragmentScanner(fragment, - new LanceInputPartition(schema, 0, new LanceSplit(Arrays.asList(fragment)), - TestUtils.TestTable1Config.lanceConfig, Optional.empty()))) { + + public void validateFragment(List> expectedValues, int fragment, StructType schema) + throws IOException { + try (LanceFragmentScanner scanner = + LanceDatasetAdapter.getFragmentScanner( + fragment, + new LanceInputPartition( + schema, + 0, + new LanceSplit(Arrays.asList(fragment)), + TestUtils.TestTable1Config.lanceConfig, + Optional.empty()))) { try (ArrowReader reader = scanner.getArrowReader()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); assertNotNull(root); - + while (reader.loadNextBatch()) { for (int i = 0; i < root.getRowCount(); i++) { for (int j = 0; j < root.getFieldVectors().size(); j++) { - assertEquals(expectedValues.get(i).get(j), root.getFieldVectors().get(j).getObject(i)); + assertEquals( + expectedValues.get(i).get(j), root.getFieldVectors().get(j).getObject(i)); } } } } } } - + // TODO test_dataset4 [UNSUPPORTED_ARROWTYPE] Unsupported arrow type FixedSizeList(128). } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java index cda163db712..c517ef3c720 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/LanceFragmentColumnarBatchScannerTest.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.TestUtils; import com.lancedb.lance.spark.internal.LanceFragmentColumnarBatchScanner; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; import org.junit.jupiter.api.Test; @@ -28,15 +28,16 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; public class LanceFragmentColumnarBatchScannerTest { - + @Test public void scanner() throws IOException { List> expectedValues = TestUtils.TestTable1Config.expectedValues; int rowIndex = 0; int fragmentId = 0; while (fragmentId <= 1) { - try (LanceFragmentColumnarBatchScanner scanner = LanceFragmentColumnarBatchScanner.create( - fragmentId, TestUtils.TestTable1Config.inputPartition)) { + try (LanceFragmentColumnarBatchScanner scanner = + LanceFragmentColumnarBatchScanner.create( + fragmentId, TestUtils.TestTable1Config.inputPartition)) { while (scanner.loadNextBatch()) { try (ColumnarBatch batch = scanner.getCurrentBatch()) { Iterator rows = batch.rowIterator(); @@ -46,10 +47,13 @@ public void scanner() throws IOException { for (int colIndex = 0; colIndex < row.numFields(); colIndex++) { long actualValue = row.getLong(colIndex); long expectedValue = expectedValues.get(rowIndex).get(colIndex); - assertEquals(expectedValue, actualValue, "Mismatch at row " + rowIndex + " column " + colIndex); + assertEquals( + expectedValue, + actualValue, + "Mismatch at row " + rowIndex + " column " + colIndex); } rowIndex++; - } + } } } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java index 2aa779ae753..523a12e194d 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorLineItemTest.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.LanceDataSource; + import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -41,16 +41,23 @@ static void setup() { dbPath = System.getenv("DB_PATH"); parquetPath = System.getenv("PARQUET_PATH"); assumeTrue(dbPath != null && !dbPath.isEmpty(), "DB_PATH environment variable is not set"); - assumeTrue(parquetPath != null && !parquetPath.isEmpty(), "PARQUET_PATH environment variable is not set"); - - spark = SparkSession.builder() - .appName("spark-lance-connector-test") - .master("local") - .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") - .getOrCreate(); - lanceData = spark.read().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath, "lineitem_10")) - .load(); + assumeTrue( + parquetPath != null && !parquetPath.isEmpty(), + "PARQUET_PATH environment variable is not set"); + + spark = + SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate(); + lanceData = + spark + .read() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath, "lineitem_10")) + .load(); lanceData.createOrReplaceTempView("lance_dataset"); parquetData = spark.read().parquet(parquetPath); parquetData.createOrReplaceTempView("parquet_dataset"); @@ -69,25 +76,35 @@ public void test() { validateResults(data -> data.filter("l_shipmode = 'TRUCK'").limit(10)); validateResults(data -> data.filter("l_shipmode IS NULL").selectExpr("count(*) as count")); validateResults(data -> data.select("l_shipmode").limit(100)); - validateResults(data -> data.select("l_orderkey", "l_partkey", "l_quantity", "l_extendedprice").limit(10)); + validateResults( + data -> data.select("l_orderkey", "l_partkey", "l_quantity", "l_extendedprice").limit(10)); validateResults(data -> data.groupBy("l_linestatus").avg("l_discount")); - validateResults(data -> data.groupBy("l_partkey").sum("l_quantity").orderBy(desc("sum(l_quantity)")).limit(5)); + validateResults( + data -> + data.groupBy("l_partkey").sum("l_quantity").orderBy(desc("sum(l_quantity)")).limit(5)); validateResults(data -> data.select("l_shipmode").distinct()); - validateResults(data -> data.select("l_orderkey", "l_comment").filter("l_comment LIKE '%express%'")); + validateResults( + data -> data.select("l_orderkey", "l_comment").filter("l_comment LIKE '%express%'")); // OOM in java test, pass in spark, need to enlarge java memory - validateResults(data -> data.select("l_orderkey", "l_partkey", "l_quantity")); - validateResults(data -> data.filter("l_quantity > 30").select("l_orderkey", "l_partkey", "l_quantity")); - validateResults(data -> data.groupBy("l_returnflag").count()); - validateResults(data -> data.filter("l_quantity BETWEEN 5 AND 30")); + validateResults(data -> data.select("l_orderkey", "l_partkey", "l_quantity")); + validateResults( + data -> data.filter("l_quantity > 30").select("l_orderkey", "l_partkey", "l_quantity")); + validateResults(data -> data.groupBy("l_returnflag").count()); + validateResults(data -> data.filter("l_quantity BETWEEN 5 AND 30")); // Not exact same result, but result is correct - Function, Dataset> function = data -> data.select("l_orderkey", "l_commitdate").orderBy("l_commitdate").limit(10); + Function, Dataset> function = + data -> data.select("l_orderkey", "l_commitdate").orderBy("l_commitdate").limit(10); function.apply(lanceData).show(); function.apply(parquetData).show(); // Lance much faster than parquet - validateResults(data -> data.groupBy("l_orderkey").sum("l_extendedprice").orderBy(desc("sum(l_extendedprice)"))); + validateResults( + data -> + data.groupBy("l_orderkey") + .sum("l_extendedprice") + .orderBy(desc("sum(l_extendedprice)"))); // Lance performance issue assertEquals(lanceData.count(), parquetData.count()); @@ -99,29 +116,46 @@ public void sql() { validateSQLResults("SELECT * FROM parquet_dataset LIMIT 10"); validateSQLResults("SELECT l_orderkey, l_partkey FROM parquet_dataset LIMIT 10"); validateSQLResults("SELECT l_extendedprice, l_discount, l_tax FROM parquet_dataset LIMIT 10"); - validateSQLResults("SELECT l_shipmode, COUNT(*) AS count FROM parquet_dataset GROUP BY l_shipmode"); - validateSQLResults("SELECT l_orderkey, SUM(l_extendedprice) AS total_extendedprice FROM parquet_dataset GROUP BY l_orderkey ORDER BY total_extendedprice DESC LIMIT 10"); - validateSQLResults("SELECT l_suppkey, SUM(l_tax) AS total_tax FROM parquet_dataset GROUP BY l_suppkey ORDER BY total_tax DESC LIMIT 5"); - validateSQLResults("SELECT l_orderkey, year(l_shipdate) AS ship_year FROM parquet_dataset GROUP BY l_orderkey, ship_year ORDER BY ship_year LIMIT 10"); - validateSQLResults("SELECT l_orderkey, l_partkey, l_quantity FROM parquet_dataset WHERE l_quantity IS NULL"); - - // LanceError(IO): Received literal Float64(100000) and could not convert to literal of type 'Decimal128(15, 2)', rust/lance/src/datafusion/logical_expr.rs:28:17 + validateSQLResults( + "SELECT l_shipmode, COUNT(*) AS count FROM parquet_dataset GROUP BY l_shipmode"); + validateSQLResults( + "SELECT l_orderkey, SUM(l_extendedprice) AS total_extendedprice FROM parquet_dataset GROUP BY l_orderkey ORDER BY total_extendedprice DESC LIMIT 10"); + validateSQLResults( + "SELECT l_suppkey, SUM(l_tax) AS total_tax FROM parquet_dataset GROUP BY l_suppkey ORDER BY total_tax DESC LIMIT 5"); + validateSQLResults( + "SELECT l_orderkey, year(l_shipdate) AS ship_year FROM parquet_dataset GROUP BY l_orderkey, ship_year ORDER BY ship_year LIMIT 10"); + validateSQLResults( + "SELECT l_orderkey, l_partkey, l_quantity FROM parquet_dataset WHERE l_quantity IS NULL"); + + // LanceError(IO): Received literal Float64(100000) and could not convert to literal of type + // 'Decimal128(15, 2)', rust/lance/src/datafusion/logical_expr.rs:28:17 // spark.sql("SELECT * FROM lineitem WHERE (l_extendedprice <= 100000)").show(); - // spark.sql("SELECT * FROM lineitem2 WHERE (l_quantity > 30) AND (l_extendedprice <= 100000) AND (l_comment IS NOT NULL)").show(); - // spark.sql("SELECT * FROM lineitem WHERE (l_quantity > 30) AND (l_extendedprice < 50000)").show(); - // spark.sql("SELECT * FROM lineitem WHERE NOT (l_quantity > 30) AND ((l_comment IS NOT NULL) OR (l_address IS NULL)) AND ((l_extendedprice < 100000) AND (l_extendedprice >= 50000))").show(); - validateSQLResults("SELECT * FROM parquet_dataset WHERE (l_quantity > 30) AND (l_comment IS NOT NULL)"); + // spark.sql("SELECT * FROM lineitem2 WHERE (l_quantity > 30) AND (l_extendedprice <= 100000) + // AND (l_comment IS NOT NULL)").show(); + // spark.sql("SELECT * FROM lineitem WHERE (l_quantity > 30) AND (l_extendedprice < + // 50000)").show(); + // spark.sql("SELECT * FROM lineitem WHERE NOT (l_quantity > 30) AND ((l_comment IS NOT NULL) OR + // (l_address IS NULL)) AND ((l_extendedprice < 100000) AND (l_extendedprice >= + // 50000))").show(); + validateSQLResults( + "SELECT * FROM parquet_dataset WHERE (l_quantity > 30) AND (l_comment IS NOT NULL)"); } private void validateResults(Function, Dataset> operation) { Dataset resultLance = operation.apply(lanceData); Dataset resultParquet = operation.apply(parquetData); - assertEquals(resultParquet.collectAsList(), resultLance.collectAsList(), "Results differ between Lance and Parquet datasets"); + assertEquals( + resultParquet.collectAsList(), + resultLance.collectAsList(), + "Results differ between Lance and Parquet datasets"); } private void validateSQLResults(String sqlQuery) { Dataset resultLance = spark.sql(sqlQuery.replace("parquet_dataset", "lance_dataset")); Dataset resultParquet = spark.sql(sqlQuery); - assertEquals(resultParquet.collectAsList(), resultLance.collectAsList(), "Results differ between Lance and Parquet datasets for query: " + sqlQuery); + assertEquals( + resultParquet.collectAsList(), + resultLance.collectAsList(), + "Results differ between Lance and Parquet datasets for query: " + sqlQuery); } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java index fe5a82a6427..3bbdcd5e465 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.read; import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.LanceDataSource; import com.lancedb.lance.spark.TestUtils; + import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -29,6 +29,7 @@ import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class SparkConnectorReadTest { private static SparkSession spark; @@ -37,15 +38,22 @@ public class SparkConnectorReadTest { @BeforeAll static void setup() { - spark = SparkSession.builder() - .appName("spark-lance-connector-test") - .master("local") - .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") - .getOrCreate(); + spark = + SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate(); dbPath = TestUtils.TestTable1Config.dbPath; - data = spark.read().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)) - .load(); + data = + spark + .read() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)) + .load(); + data.createOrReplaceTempView("test_dataset1"); } @AfterAll @@ -79,56 +87,102 @@ public void readAll() { @Test public void filter() { - validateData(data.filter("x > 1"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(0) > 1) - .collect(Collectors.toList())); - validateData(data.filter("y == 4"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(1) == 4) - .collect(Collectors.toList())); - validateData(data.filter("b >= 6"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(2) >= 6) - .collect(Collectors.toList())); - validateData(data.filter("c < -1"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(3) < -1) - .collect(Collectors.toList())); - validateData(data.filter("c <= -1"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(3) <= -1) - .collect(Collectors.toList())); - validateData(data.filter("c == -2"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(3) == -2) - .collect(Collectors.toList())); - validateData(data.filter("x > 1").filter("y < 6"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(0) > 1) - .filter(row -> row.get(1) < 6) - .collect(Collectors.toList())); - validateData(data.filter("x > 1 and y < 6"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> row.get(0) > 1) - .filter(row -> row.get(1) < 6) - .collect(Collectors.toList())); - validateData(data.filter("x > 1 or y < 6"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> (row.get(0) > 1) || (row.get(1) < 6)) - .collect(Collectors.toList())); - validateData(data.filter("(x >= 1 and x <= 2) or (c >= -2 and c < 0)"), TestUtils.TestTable1Config.expectedValues.stream() - .filter(row -> (row.get(0) >= 1 && row.get(0) <= 2) || (row.get(3) >= -2 && row.get(3) < 0)) - .collect(Collectors.toList())); + validateData( + data.filter("x > 1"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(0) > 1) + .collect(Collectors.toList())); + validateData( + data.filter("y == 4"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(1) == 4) + .collect(Collectors.toList())); + validateData( + data.filter("b >= 6"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(2) >= 6) + .collect(Collectors.toList())); + validateData( + data.filter("c < -1"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(3) < -1) + .collect(Collectors.toList())); + validateData( + data.filter("c <= -1"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(3) <= -1) + .collect(Collectors.toList())); + validateData( + data.filter("c == -2"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(3) == -2) + .collect(Collectors.toList())); + validateData( + data.filter("x > 1").filter("y < 6"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(0) > 1) + .filter(row -> row.get(1) < 6) + .collect(Collectors.toList())); + validateData( + data.filter("x > 1 and y < 6"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> row.get(0) > 1) + .filter(row -> row.get(1) < 6) + .collect(Collectors.toList())); + validateData( + data.filter("x > 1 or y < 6"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter(row -> (row.get(0) > 1) || (row.get(1) < 6)) + .collect(Collectors.toList())); + validateData( + data.filter("(x >= 1 and x <= 2) or (c >= -2 and c < 0)"), + TestUtils.TestTable1Config.expectedValues.stream() + .filter( + row -> (row.get(0) >= 1 && row.get(0) <= 2) || (row.get(3) >= -2 && row.get(3) < 0)) + .collect(Collectors.toList())); } @Test public void select() { - validateData(data.select("y", "b"), TestUtils.TestTable1Config.expectedValues.stream() - .map(row -> Arrays.asList(row.get(1), row.get(2))) - .collect(Collectors.toList())); + validateData( + data.select("y", "b"), + TestUtils.TestTable1Config.expectedValues.stream() + .map(row -> Arrays.asList(row.get(1), row.get(2))) + .collect(Collectors.toList())); } @Test public void filterSelect() { - validateData(data.select("y", "b").filter("y > 3"), + validateData( + data.select("y", "b").filter("y > 3"), TestUtils.TestTable1Config.expectedValues.stream() - .map(row -> Arrays.asList(row.get(1), row.get(2))) // "y" is at index 1, "b" is at index 2 + .map( + row -> + Arrays.asList(row.get(1), row.get(2))) // "y" is at index 1, "b" is at index 2 .filter(row -> row.get(0) > 3) .collect(Collectors.toList())); } - - // TODO(lu) support spark.read().format("lance") - // .load(dbPath.resolve(datasetName).toString()); + + @Test + public void supportDataSourceLoadPath() { + Dataset df = + spark + .read() + .format("lance") + .load(LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)); + validateData(df, TestUtils.TestTable1Config.expectedValues); + } + + @Test + public void supportBroadcastJoin() { + Dataset df = + spark.read().format("lance").load(LanceConfig.getDatasetUri(dbPath, "test_dataset3")); + df.createOrReplaceTempView("test_dataset3"); + List desc = + spark + .sql("explain select t1.* from test_dataset1 t1 join test_dataset3 t3 on t1.x = t3.x") + .collectAsList(); + assertEquals(1, desc.size()); + assertTrue(desc.get(0).getString(0).contains("BroadcastHashJoin")); + } } diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowAddress.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowAddress.java new file mode 100644 index 00000000000..8a426b5a830 --- /dev/null +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowAddress.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.spark.read; + +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.LanceDataSource; +import com.lancedb.lance.spark.TestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SparkConnectorReadWithRowAddress { + private static SparkSession spark; + private static String dbPath; + private static Dataset data; + + @BeforeAll + static void setup() { + spark = + SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate(); + dbPath = TestUtils.TestTable1Config.dbPath; + data = + spark + .read() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)) + .load(); + } + + @AfterAll + static void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + private void validateData(Dataset data, List> expectedValues) { + List rows = data.collectAsList(); + assertEquals(expectedValues.size(), rows.size()); + + for (int i = 0; i < rows.size(); i++) { + Row row = rows.get(i); + List expectedRow = expectedValues.get(i); + assertEquals(expectedRow.size(), row.size()); + + for (int j = 0; j < expectedRow.size(); j++) { + long expectedValue = expectedRow.get(j); + long actualValue = row.getLong(j); + assertEquals(expectedValue, actualValue, "Mismatch at row " + i + " column " + j); + } + } + } + + @Test + public void readAllWithoutRowAddr() { + validateData(data, TestUtils.TestTable1Config.expectedValues); + } + + @Test + public void readAllWithRowAddr() { + validateData( + data.select("x", "y", "b", "c", "_rowaddr"), + TestUtils.TestTable1Config.expectedValuesWithRowAddress); + } + + @Test + public void select() { + validateData( + data.select("y", "b", "_rowaddr"), + TestUtils.TestTable1Config.expectedValuesWithRowAddress.stream() + .map(row -> Arrays.asList(row.get(1), row.get(2), row.get(4))) + .collect(Collectors.toList())); + } + + @Test + public void filterSelect() { + validateData( + data.select("y", "b", "_rowaddr").filter("y > 3"), + TestUtils.TestTable1Config.expectedValuesWithRowAddress.stream() + .map( + row -> + Arrays.asList( + row.get(1), + row.get(2), + row.get( + 4))) // "y" is at index 1, "b" is at index 2, "_rowaddr" is at index 4 + .filter(row -> row.get(0) > 3) + .collect(Collectors.toList())); + } + + @Test + public void filterSelectByRowAddr() { + validateData( + data.select("y", "b", "_rowaddr").filter("_rowaddr > 3"), + TestUtils.TestTable1Config.expectedValuesWithRowAddress.stream() + .map( + row -> + Arrays.asList( + row.get(1), + row.get(2), + row.get( + 4))) // "y" is at index 1, "b" is at index 2, "_rowaddr" is at index 4 + .filter(row -> row.get(2) > 3) + .collect(Collectors.toList())); + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java new file mode 100644 index 00000000000..68135fad752 --- /dev/null +++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.spark.read; + +import com.lancedb.lance.spark.LanceConfig; +import com.lancedb.lance.spark.LanceDataSource; +import com.lancedb.lance.spark.TestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SparkConnectorReadWithRowId { + private static SparkSession spark; + private static String dbPath; + private static Dataset data; + + @BeforeAll + static void setup() { + spark = + SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .getOrCreate(); + dbPath = TestUtils.TestTable1Config.dbPath; + data = + spark + .read() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName)) + .load(); + } + + @AfterAll + static void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + private void validateData(Dataset data, List> expectedValues) { + List rows = data.collectAsList(); + assertEquals(expectedValues.size(), rows.size()); + + for (int i = 0; i < rows.size(); i++) { + Row row = rows.get(i); + List expectedRow = expectedValues.get(i); + assertEquals(expectedRow.size(), row.size()); + + for (int j = 0; j < expectedRow.size(); j++) { + long expectedValue = expectedRow.get(j); + long actualValue = row.getLong(j); + assertEquals(expectedValue, actualValue, "Mismatch at row " + i + " column " + j); + } + } + } + + @Test + public void readAllWithoutRowId() { + validateData(data, TestUtils.TestTable1Config.expectedValues); + } + + @Test + public void readAllWithRowId() { + validateData( + data.select("x", "y", "b", "c", "_rowid"), + TestUtils.TestTable1Config.expectedValuesWithRowId); + } + + @Test + public void select() { + validateData( + data.select("y", "b", "_rowid"), + TestUtils.TestTable1Config.expectedValuesWithRowId.stream() + .map(row -> Arrays.asList(row.get(1), row.get(2), row.get(4))) + .collect(Collectors.toList())); + } + + @Test + public void filterSelect() { + validateData( + data.select("y", "b", "_rowid").filter("y > 3"), + TestUtils.TestTable1Config.expectedValuesWithRowId.stream() + .map( + row -> + Arrays.asList( + row.get(1), + row.get(2), + row.get(4))) // "y" is at index 1, "b" is at index 2, "_rowid" is at index 4 + .filter(row -> row.get(0) > 3) + .collect(Collectors.toList())); + } + + @Test + public void filterSelectByRowId() { + validateData( + data.select("y", "b", "_rowid").filter("_rowid > 3"), + TestUtils.TestTable1Config.expectedValuesWithRowId.stream() + .map( + row -> + Arrays.asList( + row.get(1), + row.get(2), + row.get(4))) // "y" is at index 1, "b" is at index 2, "_rowid" is at index 4 + .filter(row -> row.get(2) > 3) + .collect(Collectors.toList())); + } +} diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java index 1dbc63ca60b..05734986dff 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java @@ -11,7 +11,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import org.apache.arrow.memory.BufferAllocator; @@ -36,7 +35,11 @@ public class LanceArrowWriterTest { @Test public void test() throws Exception { try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - Field field = new Field("column1", FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.INT.getType()), null); + Field field = + new Field( + "column1", + FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.INT.getType()), + null); Schema schema = new Schema(Collections.singletonList(field)); final int totalRows = 125; @@ -47,37 +50,42 @@ public void test() throws Exception { AtomicInteger rowsRead = new AtomicInteger(0); AtomicLong expectedBytesRead = new AtomicLong(0); - Thread writerThread = new Thread(() -> { - try { - for (int i = 0; i < totalRows; i++) { - InternalRow row = new GenericInternalRow(new Object[]{rowsWritten.incrementAndGet()}); - arrowWriter.write(row); - } - arrowWriter.setFinished(); - } catch (Exception e) { - e.printStackTrace(); - throw e; - } - }); + Thread writerThread = + new Thread( + () -> { + try { + for (int i = 0; i < totalRows; i++) { + InternalRow row = + new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); + arrowWriter.write(row); + } + arrowWriter.setFinished(); + } catch (Exception e) { + e.printStackTrace(); + throw e; + } + }); - Thread readerThread = new Thread(() -> { - try { - while (arrowWriter.loadNextBatch()) { - VectorSchemaRoot root = arrowWriter.getVectorSchemaRoot(); - int rowCount = root.getRowCount(); - rowsRead.addAndGet(rowCount); - try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { - expectedBytesRead.addAndGet(recordBatch.computeBodyLength()); - } - for (int i = 0; i < rowCount; i++) { - int value = (int) root.getVector("column1").getObject(i); - assertEquals(value, rowsRead.get() - rowCount + i + 1); - } - } - } catch (Exception e) { - e.printStackTrace(); - } - }); + Thread readerThread = + new Thread( + () -> { + try { + while (arrowWriter.loadNextBatch()) { + VectorSchemaRoot root = arrowWriter.getVectorSchemaRoot(); + int rowCount = root.getRowCount(); + rowsRead.addAndGet(rowCount); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + expectedBytesRead.addAndGet(recordBatch.computeBodyLength()); + } + for (int i = 0; i < rowCount; i++) { + int value = (int) root.getVector("column1").getObject(i); + assertEquals(value, rowsRead.get() - rowCount + i + 1); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + }); writerThread.start(); readerThread.start(); diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceBatchWriteTest.java similarity index 86% rename from java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java rename to java/spark/src/test/java/com/lancedb/lance/spark/write/LanceBatchWriteTest.java index 45c04de4a33..b3bb276b71a 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceBatchWriteTest.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.Dataset; import com.lancedb.lance.WriteParams; import com.lancedb.lance.spark.LanceConfig; + import org.apache.arrow.dataset.scanner.Scanner; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -30,23 +30,20 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.DataWriterFactory; -import org.apache.spark.sql.connector.write.PhysicalWriteInfo; import org.apache.spark.sql.connector.write.WriterCommitMessage; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.ArrowUtils; +import org.apache.spark.sql.util.LanceArrowUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.io.TempDir; -import java.io.IOException; import java.nio.file.Path; import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; -public class BatchAppendTest { - @TempDir - static Path tempDir; +public class LanceBatchWriteTest { + @TempDir static Path tempDir; @Test public void testLanceDataWriter(TestInfo testInfo) throws Exception { @@ -60,20 +57,20 @@ public void testLanceDataWriter(TestInfo testInfo) throws Exception { // Append data to lance dataset LanceConfig config = LanceConfig.from(datasetUri); - StructType sparkSchema = ArrowUtils.fromArrowSchema(schema); - BatchAppend batchAppend = new BatchAppend(sparkSchema, config); - DataWriterFactory factor = batchAppend.createBatchWriterFactory(() -> 1); + StructType sparkSchema = LanceArrowUtils.fromArrowSchema(schema); + LanceBatchWrite lanceBatchWrite = new LanceBatchWrite(sparkSchema, config, false); + DataWriterFactory factor = lanceBatchWrite.createBatchWriterFactory(() -> 1); int rows = 132; WriterCommitMessage message; try (DataWriter writer = factor.createWriter(0, 0)) { for (int i = 0; i < rows; i++) { - InternalRow row = new GenericInternalRow(new Object[]{i}); + InternalRow row = new GenericInternalRow(new Object[] {i}); writer.write(row); } message = writer.commit(); } - batchAppend.commit(new WriterCommitMessage[]{message}); + lanceBatchWrite.commit(new WriterCommitMessage[] {message}); // Validate lance dataset data try (Dataset dataset = Dataset.open(datasetUri, allocator)) { diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java index 8ea2c47cd69..211cd6ece8a 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.FragmentMetadata; import com.lancedb.lance.spark.LanceConfig; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.ArrowUtils; +import org.apache.spark.sql.util.LanceArrowUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.io.TempDir; @@ -38,8 +38,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class LanceDataWriterTest { - @TempDir - static Path tempDir; + @TempDir static Path tempDir; @Test public void testLanceDataWriter(TestInfo testInfo) throws IOException { @@ -47,18 +46,20 @@ public void testLanceDataWriter(TestInfo testInfo) throws IOException { try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Field field = new Field("column1", FieldType.nullable(new ArrowType.Int(32, true)), null); Schema schema = new Schema(Collections.singletonList(field)); - LanceConfig config = LanceConfig.from(tempDir.resolve(datasetName + LanceConfig.LANCE_FILE_SUFFIX).toString()); - StructType sparkSchema = ArrowUtils.fromArrowSchema(schema); - LanceDataWriter.WriterFactory writerFactory = new LanceDataWriter.WriterFactory(sparkSchema, config); + LanceConfig config = + LanceConfig.from(tempDir.resolve(datasetName + LanceConfig.LANCE_FILE_SUFFIX).toString()); + StructType sparkSchema = LanceArrowUtils.fromArrowSchema(schema); + LanceDataWriter.WriterFactory writerFactory = + new LanceDataWriter.WriterFactory(sparkSchema, config); LanceDataWriter dataWriter = (LanceDataWriter) writerFactory.createWriter(0, 0); int rows = 132; for (int i = 0; i < rows; i++) { - InternalRow row = new GenericInternalRow(new Object[]{i}); + InternalRow row = new GenericInternalRow(new Object[] {i}); dataWriter.write(row); } - BatchAppend.TaskCommit commitMessage = (BatchAppend.TaskCommit) dataWriter.commit(); + LanceBatchWrite.TaskCommit commitMessage = (LanceBatchWrite.TaskCommit) dataWriter.commit(); dataWriter.close(); List fragments = commitMessage.getFragments(); assertEquals(1, fragments.size()); diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java index 78c5f9cb12f..e3d64859b1d 100644 --- a/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java +++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/SparkWriteTest.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package com.lancedb.lance.spark.write; import com.lancedb.lance.spark.LanceConfig; import com.lancedb.lance.spark.LanceDataSource; + import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -27,11 +27,11 @@ import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.api.io.TempDir; +import java.io.File; import java.nio.file.Path; import java.util.Arrays; import java.util.List; @@ -43,26 +43,30 @@ public class SparkWriteTest { private static SparkSession spark; private static Dataset testData; - @TempDir - static Path dbPath; + @TempDir static Path dbPath; @BeforeAll static void setup() { - spark = SparkSession.builder() - .appName("spark-lance-connector-test") - .master("local") - .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") - .getOrCreate(); - StructType schema = new StructType(new StructField[]{ - DataTypes.createStructField("id", DataTypes.IntegerType, false), - DataTypes.createStructField("name", DataTypes.StringType, false) - }); + spark = + SparkSession.builder() + .appName("spark-lance-connector-test") + .master("local") + .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog") + .config("spark.sql.catalog.lance.max_row_per_file", "1") + .getOrCreate(); + StructType schema = + new StructType( + new StructField[] { + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("name", DataTypes.StringType, false) + }); Row row1 = RowFactory.create(1, "Alice"); Row row2 = RowFactory.create(2, "Bob"); List data = Arrays.asList(row1, row2); testData = spark.createDataFrame(data, schema); + testData.createOrReplaceTempView("tmp_view"); } @AfterAll @@ -75,8 +79,12 @@ static void tearDown() { @Test public void defaultWrite(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .save(); validateData(datasetName, 1); @@ -85,25 +93,43 @@ public void defaultWrite(TestInfo testInfo) { @Test public void errorIfExists(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .save(); - assertThrows(TableAlreadyExistsException.class, () -> { - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) - .save(); - }); + assertThrows( + TableAlreadyExistsException.class, + () -> { + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + }); } @Test public void append(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .save(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .mode("append") .save(); validateData(datasetName, 2); @@ -112,42 +138,121 @@ public void append(TestInfo testInfo) { @Test public void appendErrorIfNotExist(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - assertThrows(NoSuchTableException.class, () -> { - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) - .mode("append") - .save(); - }); + assertThrows( + NoSuchTableException.class, + () -> { + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("append") + .save(); + }); } @Test public void saveToPath(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - testData.write().format(LanceDataSource.name) + testData + .write() + .format(LanceDataSource.name) .save(LanceConfig.getDatasetUri(dbPath.toString(), datasetName)); validateData(datasetName, 1); } - @Disabled("Do not support overwrite") @Test public void overwrite(TestInfo testInfo) { String datasetName = testInfo.getTestMethod().get().getName(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .save(); - testData.write().format(LanceDataSource.name) - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) .mode("overwrite") .save(); validateData(datasetName, 1); } + @Test + public void appendAfterOverwrite(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .save(); + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("overwrite") + .save(); + testData + .write() + .format(LanceDataSource.name) + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .mode("append") + .save(); + validateData(datasetName, 2); + } + + @Test + public void writeMultiFiles(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName); + testData + .write() + .format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, filePath) + .save(); + + validateData(datasetName, 1); + File directory = new File(filePath + "/data"); + assertEquals(2, directory.listFiles().length); + } + + @Test + public void writeEmptyTaskFiles(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + String filePath = LanceConfig.getDatasetUri(dbPath.toString(), datasetName); + testData + .repartition(4) + .write() + .format(LanceDataSource.name) + .option(LanceConfig.CONFIG_DATASET_URI, filePath) + .save(); + + File directory = new File(filePath + "/data"); + assertEquals(2, directory.listFiles().length); + } + private void validateData(String datasetName, int iteration) { - Dataset data = spark.read().format("lance") - .option(LanceConfig.CONFIG_DATASET_URI, LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) - .load(); + Dataset data = + spark + .read() + .format("lance") + .option( + LanceConfig.CONFIG_DATASET_URI, + LanceConfig.getDatasetUri(dbPath.toString(), datasetName)) + .load(); assertEquals(2 * iteration, data.count()); assertEquals(iteration, data.filter(col("id").equalTo(1)).count()); @@ -164,4 +269,13 @@ private void validateData(String datasetName, int iteration) { assertEquals("Bob", row.getString(0)); } } -} \ No newline at end of file + + @Test + public void dropAndReplaceTable(TestInfo testInfo) { + String datasetName = testInfo.getTestMethod().get().getName(); + String path = LanceConfig.getDatasetUri(dbPath.toString(), datasetName); + spark.sql("CREATE OR REPLACE TABLE lance.`" + path + "` AS SELECT * FROM tmp_view"); + spark.sql("CREATE OR REPLACE TABLE lance.`" + path + "` AS SELECT * FROM tmp_view"); + spark.sql("DROP TABLE lance.`" + path + "`"); + } +} diff --git a/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala b/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala new file mode 100644 index 00000000000..1dd337feca1 --- /dev/null +++ b/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.util + +/* + * The following code is originally from https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala + * and is licensed under the Apache license: + * + * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation. + * https://github.com/apache/spark/blob/master/LICENSE + * + * It has been modified by the Lance developers to fit the needs of the Lance project. + */ + +import com.lancedb.lance.spark.LanceConstant + +import org.apache.arrow.vector.types.pojo.ArrowType +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite + +import java.time.ZoneId + +class LanceArrowUtilsSuite extends AnyFunSuite { + def roundtrip(dt: DataType, fieldName: String = "value"): Unit = { + dt match { + case schema: StructType => + assert(LanceArrowUtils.fromArrowSchema( + LanceArrowUtils.toArrowSchema(schema, null, true)) === schema) + case _ => + roundtrip(new StructType().add(fieldName, dt)) + } + } + + test("unsigned long") { + roundtrip(BooleanType, LanceConstant.ROW_ID) + val arrowType = LanceArrowUtils.toArrowField(LanceConstant.ROW_ID, LongType, true, "Beijing") + assert(arrowType.getType.asInstanceOf[ArrowType.Int].getBitWidth === 64) + assert(!arrowType.getType.asInstanceOf[ArrowType.Int].getIsSigned) + } + + test("simple") { + roundtrip(BooleanType) + roundtrip(ByteType) + roundtrip(ShortType) + roundtrip(IntegerType) + roundtrip(LongType) + roundtrip(FloatType) + roundtrip(DoubleType) + roundtrip(StringType) + roundtrip(BinaryType) + roundtrip(DecimalType.SYSTEM_DEFAULT) + roundtrip(DateType) + roundtrip(YearMonthIntervalType()) + roundtrip(DayTimeIntervalType()) + } + + test("timestamp") { + + def roundtripWithTz(timeZoneId: String): Unit = { + val schema = new StructType().add("value", TimestampType) + val arrowSchema = LanceArrowUtils.toArrowSchema(schema, timeZoneId, true) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === timeZoneId) + assert(LanceArrowUtils.fromArrowSchema(arrowSchema) === schema) + } + + roundtripWithTz(ZoneId.systemDefault().getId) + roundtripWithTz("Asia/Tokyo") + roundtripWithTz("UTC") + } + + test("array") { + roundtrip(ArrayType(IntegerType, containsNull = true)) + roundtrip(ArrayType(IntegerType, containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false)) + } + + test("struct") { + roundtrip(new StructType()) + roundtrip(new StructType().add("i", IntegerType)) + roundtrip(new StructType().add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add( + "struct", + new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))) + } + + test("struct with duplicated field names") { + + def check(dt: DataType, expected: DataType): Unit = { + val schema = new StructType().add("value", dt) + intercept[SparkUnsupportedOperationException] { + LanceArrowUtils.toArrowSchema(schema, null, true) + } + assert(LanceArrowUtils.fromArrowSchema(LanceArrowUtils.toArrowSchema(schema, null, false)) + === new StructType().add("value", expected)) + } + + roundtrip(new StructType().add("i", IntegerType).add("i", StringType)) + + check( + new StructType().add("i", IntegerType).add("i", StringType), + new StructType().add("i_0", IntegerType).add("i_1", StringType)) + check( + ArrayType(new StructType().add("i", IntegerType).add("i", StringType)), + ArrayType(new StructType().add("i_0", IntegerType).add("i_1", StringType))) + check( + MapType(StringType, new StructType().add("i", IntegerType).add("i", StringType)), + MapType(StringType, new StructType().add("i_0", IntegerType).add("i_1", StringType))) + } + +} diff --git a/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala b/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala new file mode 100644 index 00000000000..5e555808822 --- /dev/null +++ b/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala @@ -0,0 +1,522 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized + +/* + * The following code is originally from https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala + * and is licensed under the Apache license: + * + * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation. + * https://github.com/apache/spark/blob/master/LICENSE + * + * It has been modified by the Lance developers to fit the needs of the Lance project. + */ + +import com.lancedb.lance.spark.LanceConstant + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.{ArrowUtils, LanceArrowUtils} +import org.apache.spark.unsafe.types.UTF8String +import org.scalatest.funsuite.AnyFunSuite + +class LanceArrowColumnVectorSuite extends AnyFunSuite { + test("boolean") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) + .createVector(allocator).asInstanceOf[BitVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, if (i % 2 == 0) 1 else 0) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBoolean(i) === (i % 2 == 0)) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getBooleans(0, 10) === (0 until 10).map(i => (i % 2 == 0))) + + columnVector.close() + allocator.close() + } + + test("byte") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("byte", ByteType, nullable = true, null) + .createVector(allocator).asInstanceOf[TinyIntVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toByte) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getByte(i) === i.toByte) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getBytes(0, 10) === (0 until 10).map(i => i.toByte)) + + columnVector.close() + allocator.close() + } + + test("short") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("short", ShortType, nullable = true, null) + .createVector(allocator).asInstanceOf[SmallIntVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toShort) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getShort(i) === i.toShort) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getShorts(0, 10) === (0 until 10).map(i => i.toShort)) + + columnVector.close() + allocator.close() + } + + test("int") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("int", IntegerType, nullable = true, null) + .createVector(allocator).asInstanceOf[IntVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getInt(i) === i) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getInts(0, 10) === (0 until 10)) + + columnVector.close() + allocator.close() + } + + test("long") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("long", LongType, nullable = true, null) + .createVector(allocator).asInstanceOf[BigIntVector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toLong) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getLong(i) === i.toLong) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong)) + + columnVector.close() + allocator.close() + } + + test("unsigned long") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("unsigned long", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField(LanceConstant.ROW_ID, LongType, nullable = true, null) + .createVector(allocator).asInstanceOf[UInt8Vector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toLong) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getLong(i) === i.toLong) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong)) + + columnVector.close() + allocator.close() + } + + test("float") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("float", FloatType, nullable = true, null) + .createVector(allocator).asInstanceOf[Float4Vector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toFloat) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getFloat(i) === i.toFloat) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getFloats(0, 10) === (0 until 10).map(i => i.toFloat)) + + columnVector.close() + allocator.close() + } + + test("double") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("double", DoubleType, nullable = true, null) + .createVector(allocator).asInstanceOf[Float8Vector] + vector.allocateNew() + + (0 until 10).foreach { i => + vector.setSafe(i, i.toDouble) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getDouble(i) === i.toDouble) + } + assert(columnVector.isNullAt(10)) + + assert(columnVector.getDoubles(0, 10) === (0 until 10).map(i => i.toDouble)) + + columnVector.close() + allocator.close() + } + + test("string") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("string", StringType, nullable = true, null) + .createVector(allocator).asInstanceOf[VarCharVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("large_string") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("string", StringType, nullable = true, null, true) + .createVector(allocator).asInstanceOf[LargeVarCharVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("binary") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, false) + .createVector(allocator).asInstanceOf[VarBinaryVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("large_binary") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) + val vector = LanceArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, true) + .createVector(allocator).asInstanceOf[LargeVarBinaryVector] + vector.allocateNew() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + vector.setSafe(i, utf8, 0, utf8.length) + } + vector.setNull(10) + vector.setValueCount(11) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("array") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) + val vector = + LanceArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) + .createVector(allocator).asInstanceOf[ListVector] + vector.allocateNew() + val elementVector = vector.getDataVector().asInstanceOf[IntVector] + + // [1, 2] + vector.startNewValue(0) + elementVector.setSafe(0, 1) + elementVector.setSafe(1, 2) + vector.endValue(0, 2) + + // [3, null, 5] + vector.startNewValue(1) + elementVector.setSafe(2, 3) + elementVector.setNull(3) + elementVector.setSafe(4, 5) + vector.endValue(1, 3) + + // null + + // [] + vector.startNewValue(3) + vector.endValue(3, 0) + + elementVector.setValueCount(5) + vector.setValueCount(4) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + val array0 = columnVector.getArray(0) + assert(array0.numElements() === 2) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + + val array1 = columnVector.getArray(1) + assert(array1.numElements() === 3) + assert(array1.getInt(0) === 3) + assert(array1.isNullAt(1)) + assert(array1.getInt(2) === 5) + + assert(columnVector.isNullAt(2)) + + val array3 = columnVector.getArray(3) + assert(array3.numElements() === 0) + + columnVector.close() + allocator.close() + } + + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = LanceArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[StructVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + + test("struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = LanceArrowUtils.toArrowField("struct", schema, nullable = true, null) + .createVector(allocator).asInstanceOf[StructVector] + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + // (1, 1L) + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + // (2, null) + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + // (null, 3L) + vector.setIndexDefined(2) + intVector.setNull(2) + longVector.setSafe(2, 3L) + + // null + vector.setNull(3) + + // (5, 5L) + vector.setIndexDefined(4) + intVector.setSafe(4, 5) + longVector.setSafe(4, 5L) + + intVector.setValueCount(5) + longVector.setValueCount(5) + vector.setValueCount(5) + + val columnVector = new LanceArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.hasNull) + assert(columnVector.numNulls === 1) + + val row0 = columnVector.getStruct(0) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + val row2 = columnVector.getStruct(2) + assert(row2.isNullAt(0)) + assert(row2.getLong(1) === 3L) + + assert(columnVector.isNullAt(3)) + + val row4 = columnVector.getStruct(4) + assert(row4.getInt(0) === 5) + assert(row4.getLong(1) === 5L) + + columnVector.close() + allocator.close() + } +} diff --git a/notebooks/quickstart.ipynb b/notebooks/quickstart.ipynb index 45ace154470..5abd79db2d7 100644 --- a/notebooks/quickstart.ipynb +++ b/notebooks/quickstart.ipynb @@ -1,5 +1,13 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "7980c1ca", + "metadata": {}, + "source": [ + "# Quickstart" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -981,7 +989,7 @@ } ], "source": [ - "%%time \n", + "%%time\n", "\n", "sift1m.create_index(\n", " \"vector\",\n", diff --git a/protos/encodings.proto b/protos/encodings.proto index cac0d0d5e5f..185a0963083 100644 --- a/protos/encodings.proto +++ b/protos/encodings.proto @@ -152,6 +152,8 @@ message List { message FixedSizeList { /// The number of items in each list uint32 dimension = 1; + /// True if the list is nullable + bool has_validity = 3; /// The items in the list ArrayEncoding items = 2; } @@ -180,8 +182,6 @@ message Flat { message Constant { // The value (TODO: define encoding for literals?) bytes value = 1; - // The number of values - uint64 num_values = 2; } // Items are bitpacked in a buffer @@ -211,11 +211,20 @@ message BitpackedForNonNeg { Buffer buffer = 3; } -message Bitpack2 { +// Opaque bitpacking variant where the bits per value are stored inline in the chunks themselves +message InlineBitpacking { // the number of bits of the uncompressed value. e.g. for a u32, this will be 32 uint64 uncompressed_bits_per_value = 2; } +// Transparent bitpacking variant where the number of bits per value is fixed through the whole buffer +message OutOfLineBitpacking { + // the number of bits of the uncompressed value. e.g. for a u32, this will be 32 + uint64 uncompressed_bits_per_value = 2; + // The number of compressed bits per value, fixed across the entire buffer + uint64 compressed_bits_per_value = 3; +} + // An array encoding for shredded structs that will never be null // // There is no actual data in this column. @@ -230,15 +239,8 @@ message Binary { uint64 null_adjustment = 3; } -message BinaryMiniBlock { -} - -message BinaryBlock { -} - -message FsstMiniBlock { - ArrayEncoding BinaryMiniBlock = 1; - bytes symbol_table = 2; +message Variable { + uint32 bits_per_offset = 1; } message Fsst { @@ -258,11 +260,20 @@ message PackedStruct { Buffer buffer = 2; } +message PackedStructFixedWidthMiniBlock { + ArrayEncoding Flat = 1; + repeated uint32 bits_per_values = 2; +} + message FixedSizeBinary { ArrayEncoding bytes = 1; uint32 byte_width = 2; } +message Block { + string scheme = 1; +} + // Encodings that decode into an Arrow array message ArrayEncoding { oneof array_encoding { @@ -279,10 +290,11 @@ message ArrayEncoding { FixedSizeBinary fixed_size_binary = 11; BitpackedForNonNeg bitpacked_for_non_neg = 12; Constant constant = 13; - Bitpack2 bitpack2 = 14; - BinaryMiniBlock binary_mini_block = 15; - FsstMiniBlock fsst_mini_block = 16; - BinaryBlock binary_block = 17; + InlineBitpacking inline_bitpacking = 14; + OutOfLineBitpacking out_of_line_bitpacking = 15; + Variable variable = 16; + PackedStructFixedWidthMiniBlock packed_struct_fixed_width_mini_block = 17; + Block block = 18; } } @@ -310,6 +322,51 @@ message ColumnEncoding { } } +// # Standardized Interpretation of Counting Terms +// +// When working with 2.1 encodings we have a number of different "counting terms" and it can be +// difficult to understand what we mean when we are talking about a "number of values". Here is +// a standard interpretation of these terms: +// +// TODO: This is a newly added standardization and hasn't yet been applied to all code. +// +// To understand these definitions consider a data type FIXED_SIZE_LIST>. +// +// A "value" is an abstract term when we aren't being specific. +// +// - num_rows: This is the highest level counting term. A single row includes everything in the +// fixed size list. This is what the user asks for when they asks for a range of rows. +// - num_elements: The number of elements is the number of rows multiplied by the dimension of any +// fixed size list wrappers. This is what you get when you flatten the FSL layer and +// is the starting point for structural encoding. Note that an element can be a list +// value or a single primitive value. +// - num_items: The number of items is the number of values in the repetition and definition vectors +// after everything has been flattened. +// - num_visible_items: The number of visible items is the number of items after invisible items +// have been removed. Invisible items are rep/def levels that don't correspond to an +// actual value. +// +// Note that we haven't exactly defined LIST> yet. Both FIXED_SIZE_LIST> +// and LIST> haven't been fully implemented and tested. + +/// Describes the meaning of each repdef layer in a mini-block layout +enum RepDefLayer { + // Should never be used, included for debugging purporses and general protobuf best practice + REPDEF_UNSPECIFIED = 0; + // All values are valid (can be primitive or struct) + REPDEF_ALL_VALID_ITEM = 1; + // All list values are valid + REPDEF_ALL_VALID_LIST = 2; + // There are one or more null items (can be primitive or struct) + REPDEF_NULLABLE_ITEM = 3; + // A list layer with null lists but no empty lists + REPDEF_NULLABLE_LIST = 4; + // A list layer with empty lists but no null lists + REPDEF_EMPTYABLE_LIST = 5; + // A list layer with both empty lists and null lists + REPDEF_NULL_AND_EMPTY_LIST = 6; +} + /// A layout used for pages where the data is small /// /// In this case we can fit many values into a single disk sector and transposing buffers is @@ -317,12 +374,40 @@ message ColumnEncoding { /// chunks (called mini blocks) which are roughly the size of a disk sector. message MiniBlockLayout { // Description of the compression of repetition levels (e.g. how many bits per rep) + // + // Optional, if there is no repetition then this field is not present ArrayEncoding rep_compression = 1; // Description of the compression of definition levels (e.g. how many bits per def) + // + // Optional, if there is no definition then this field is not present ArrayEncoding def_compression = 2; // Description of the compression of values ArrayEncoding value_compression = 3; + // Dictionary data ArrayEncoding dictionary = 4; + // Number of items in the dictionary + uint64 num_dictionary_items = 5; + // The meaning of each repdef layer, used to interpret repdef buffers correctly + repeated RepDefLayer layers = 6; + // The number of buffers in each mini-block, this is determined by the compression and does + // NOT include the repetition or definition buffers (the presence of these buffers can be determined + // by looking at the rep_compression and def_compression fields) + uint64 num_buffers = 7; + // The depth of the repetition index. + // + // If there is repetition then the depth must be at least 1. If there are many layers + // of repetition then deeper repetition indices will support deeper nested random access. For + // example, given 5 layers of repetition then the repetition index depth must be at least + // 3 to support access like rows[50][17][3]. + // + // We require `repetition_index_depth + 1` u64 values per mini-block to store the repetition + // index if the `repetition_index_depth` is greater than 0. The +1 is because we need to store + // the number of "leftover items" at the end of the chunk. Otherwise, we wouldn't have any way + // to know if the final item in a chunk is valid or not. + uint32 repetition_index_depth = 8; + // The page already records how many rows are in the page. For mini-block we also need to know how + // many "items" are in the page. A row and an item are the same thing unless the page has lists. + uint64 num_items = 9; } /// A layout used for pages where the data is large @@ -334,17 +419,34 @@ message FullZipLayout { uint32 bits_rep = 1; // The number of bits of definition info (0 if there is no definition) uint32 bits_def = 2; + // The number of bits of value info + // + // Note: we use bits here (and not bytes) for consistency with other encodings. However, in practice, + // there is never a reason to use a bits per value that is not a multiple of 8. The complexity is not + // worth the small savings in space since this encoding is typically used with large values already. + oneof details { + // If this is a fixed width block then we need to have a fixed number of bits per value + uint32 bits_per_value = 3; + // If this is a variable width block then we need to have a fixed number of bits per offset + uint32 bits_per_offset = 4; + } + // The number of items in the page + uint32 num_items = 5; + // The number of visible items in the page + uint32 num_visible_items = 6; // Description of the compression of values - ArrayEncoding value_compression = 3; + ArrayEncoding value_compression = 7; + // The meaning of each repdef layer, used to interpret repdef buffers correctly + repeated RepDefLayer layers = 8; } /// A layout used for pages where all values are null /// -/// In addition, there can be no repetition levels and only a single definition level -/// -/// If the data is all-null but we have non-trivial rep-def then MiniBlockLayout is used +/// There may be buffers of repetition and definition information +/// if required in order to interpret what kind of nulls are present message AllNullLayout { - + // The meaning of each repdef layer, used to interpret repdef buffers correctly + repeated RepDefLayer layers = 5; } message PageLayout { diff --git a/protos/index.proto b/protos/index.proto index 0db16566d2d..e7eb7f4818f 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -66,6 +66,9 @@ message IVF { // Tensor of centroids. `num_partitions * dimension` of float32s. Tensor centroids_tensor = 4; + + // KMeans loss. + optional double loss = 5; } // Product Quantization. diff --git a/protos/table.proto b/protos/table.proto index 84e49b98cab..f200dd33a7d 100644 --- a/protos/table.proto +++ b/protos/table.proto @@ -361,4 +361,5 @@ message BTreeIndexDetails {} message BitmapIndexDetails {} message LabelListIndexDetails {} message InvertedIndexDetails {} +message NGramIndexDetails {} message VectorIndexDetails {} \ No newline at end of file diff --git a/protos/transaction.proto b/protos/transaction.proto index 3aee36995eb..5cf7b52b2fa 100644 --- a/protos/transaction.proto +++ b/protos/transaction.proto @@ -67,7 +67,9 @@ message Transaction { // Add or replace a new secondary index. // - // - new_indices: the modified indices + // This is also used to remove an index (we are replacing it with nothing) + // + // - new_indices: the modified indices, empty if dropping indices only // - removed_indices: the indices that are being replaced message CreateIndex { repeated IndexMetadata new_indices = 1; @@ -163,6 +165,22 @@ message Transaction { message UpdateConfig { map upsert_values = 1; repeated string delete_keys = 2; + map schema_metadata = 3; + map field_metadata = 4; + + message FieldMetadataUpdate { + map metadata = 5; + } + } + + message DataReplacementGroup { + uint64 fragment_id = 1; + DataFile new_file = 2; + } + + // An operation that replaces the data in a region of the table with new data. + message DataReplacement { + repeated DataReplacementGroup replacements = 1; } // The operation of this transaction. @@ -178,6 +196,7 @@ message Transaction { Update update = 108; Project project = 109; UpdateConfig update_config = 110; + DataReplacement data_replacement = 111; } // An operation to apply to the blob dataset @@ -185,4 +204,4 @@ message Transaction { Append blob_append = 200; Overwrite blob_overwrite = 202; } -} \ No newline at end of file +} diff --git a/python/Cargo.lock b/python/Cargo.lock index 4bbf63f81be..dc372d4c4b0 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -17,6 +17,12 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + [[package]] name = "ahash" version = "0.8.11" @@ -25,7 +31,7 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -57,9 +63,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android-tzdata" @@ -76,11 +82,61 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +dependencies = [ + "anstyle", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" [[package]] name = "arc-swap" @@ -102,9 +158,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" +checksum = "dc208515aa0151028e464cc94a692156e945ce5126abd3537bb7fd6ba2143ed1" dependencies = [ "arrow-arith", "arrow-array", @@ -124,24 +180,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" +checksum = "e07e726e2b3f7816a85c6a45b6ec118eeeabf0b2a8c208122ad949437181f49a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half", "num", ] [[package]] name = "arrow-array" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" +checksum = "a2262eba4f16c78496adfd559a29fe4b24df6088efc9985a873d58e92be022d5" dependencies = [ "ahash", "arrow-buffer", @@ -150,15 +205,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "num", ] [[package]] name = "arrow-buffer" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" +checksum = "4e899dade2c3b7f5642eb8366cfd898958bcca099cde6dfea543c7e8d3ad88d4" dependencies = [ "bytes", "half", @@ -167,9 +222,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" +checksum = "4103d88c5b441525ed4ac23153be7458494c2b0c9a11115848fdb9b81f6f886a" dependencies = [ "arrow-array", "arrow-buffer", @@ -188,28 +243,25 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c13c36dc5ddf8c128df19bab27898eea64bf9da2b555ec1cd17a8ff57fba9ec2" +checksum = "43d3cb0914486a3cae19a5cad2598e44e225d53157926d0ada03c20521191a65" dependencies = [ "arrow-array", - "arrow-buffer", "arrow-cast", - "arrow-data", "arrow-schema", "chrono", "csv", "csv-core", "lazy_static", - "lexical-core", "regex", ] [[package]] name = "arrow-data" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" +checksum = "0a329fb064477c9ec5f0870d2f5130966f91055c7c5bce2b3a084f116bc28c3b" dependencies = [ "arrow-buffer", "arrow-schema", @@ -219,13 +271,12 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e786e1cdd952205d9a8afc69397b317cfbb6e0095e445c69cda7e8da5c1eeb0f" +checksum = "ddecdeab02491b1ce88885986e25002a3da34dd349f682c7cfe67bab7cc17b86" dependencies = [ "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-schema", "flatbuffers", @@ -235,9 +286,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb22284c5a2a01d73cebfd88a33511a3234ab45d66086b2ca2d1228c3498e445" +checksum = "d03b9340013413eb84868682ace00a1098c81a5ebc96d279f7ebf9a4cac3c0fd" dependencies = [ "arrow-array", "arrow-buffer", @@ -255,26 +306,23 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" +checksum = "f841bfcc1997ef6ac48ee0305c4dfceb1f7c786fe31e67c1186edf775e1f1160" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "half", - "num", ] [[package]] name = "arrow-row" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" +checksum = "1eeb55b0a0a83851aa01f2ca5ee5648f607e8506ba6802577afdda9d75cdedcd" dependencies = [ - "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -284,18 +332,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" +checksum = "85934a9d0261e0fa5d4e2a5295107d743b543a6e0484a835d4b8db2da15306f9" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", ] [[package]] name = "arrow-select" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" +checksum = "7e2932aece2d0c869dd2125feb9bd1709ef5c445daa3838ac4112dcfa0fda52c" dependencies = [ "ahash", "arrow-array", @@ -307,9 +355,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" +checksum = "912e38bd6a7a7714c1d9b61df80315685553b7455e8a6045c27531d8ecd5b458" dependencies = [ "arrow-array", "arrow-buffer", @@ -319,7 +367,7 @@ dependencies = [ "memchr", "num", "regex", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -345,24 +393,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-compression" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" -dependencies = [ - "bzip2", - "flate2", - "futures-core", - "futures-io", - "memchr", - "pin-project-lite", - "tokio", - "xz2", - "zstd", - "zstd-safe", -] - [[package]] name = "async-executor" version = "1.13.1" @@ -416,7 +446,7 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 5.3.1", + "event-listener 5.4.0", "event-listener-strategy", "pin-project-lite", ] @@ -438,7 +468,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -475,13 +505,13 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "d556ec1359574147ec0c4fc5eb525f3f23263a592b1a9c07e0a75b427de55c97" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -513,9 +543,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.10" +version = "1.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b49afaa341e8dd8577e1a2200468f98956d6eda50bcf4a53246cc00174ba924" +checksum = "90aff65e86db5fe300752551c1b015ef72b708ac54bded8ef43d0d53cb7cb0b1" dependencies = [ "aws-credential-types", "aws-runtime", @@ -523,7 +553,7 @@ dependencies = [ "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.61.1", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -555,14 +585,14 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.4.3" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" +checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -580,14 +610,14 @@ dependencies = [ [[package]] name = "aws-sdk-dynamodb" -version = "1.54.0" +version = "1.66.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8efdda6a491bb4640d35b99b0a4b93f75ce7d6e3a1937c3e902d3cb23d0a179c" +checksum = "5296daf754d333f51798bff599876c3849394ec3dabe8d1d61cbacb961fdde37" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -603,14 +633,14 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.49.0" +version = "1.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09677244a9da92172c8dc60109b4a9658597d4d298b188dd0018b6a66b410ca4" +checksum = "e65ff295979977039a25f5a0bf067a64bc5e6aa38f3cef4037cf42516265553c" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.61.1", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -625,14 +655,14 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.50.0" +version = "1.62.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fea2f3a8bb3bd10932ae7ad59cc59f65f270fc9183a7e91f501dc5efbef7ee" +checksum = "91430a60f754f235688387b75ee798ef00cfd09709a582be2b7525ebb5306d4f" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.61.1", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -647,14 +677,14 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.50.0" +version = "1.62.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" +checksum = "9276e139d39fff5a0b0c984fc2d30f970f9a202da67234f948fda02e5bea1dbe" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.61.1", "aws-smithy-json", "aws-smithy-query", "aws-smithy-runtime", @@ -670,12 +700,12 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.5" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" +checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" dependencies = [ "aws-credential-types", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -683,7 +713,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "once_cell", "percent-encoding", "sha2", @@ -693,9 +723,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.1" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" +checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" dependencies = [ "futures-util", "pin-project-lite", @@ -704,9 +734,29 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.11" +version = "0.60.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http" +version = "0.61.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" +checksum = "e6f276f21c7921fe902826618d1423ae5bf74cf8c1b8472aee8434f3dfd31824" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -724,9 +774,9 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.7" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" +checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" dependencies = [ "aws-smithy-types", ] @@ -743,12 +793,12 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.3" +version = "1.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" +checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" dependencies = [ "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -758,7 +808,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.31", + "hyper 0.14.32", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -778,7 +828,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "pin-project-lite", "tokio", "tracing", @@ -787,16 +837,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.9" +version = "1.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" +checksum = "c7b8a53819e42f10d0821f56da995e1470b199686a1809168db6ca485665f042" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.1.0", + "http 1.2.0", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -822,9 +872,9 @@ dependencies = [ [[package]] name = "aws-types" -version = "1.3.3" +version = "1.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" +checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -871,6 +921,28 @@ dependencies = [ "vsimd", ] +[[package]] +name = "bigdecimal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f31f3af01c5c65a07985c804d3366560e6fa7883d640a122819b14ec327482c" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -879,9 +951,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "bitpacking" @@ -915,9 +987,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.4" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753" dependencies = [ "arrayref", "arrayvec", @@ -950,9 +1022,9 @@ dependencies = [ [[package]] name = "brotli" -version = "6.0.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -961,9 +1033,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.1" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -971,15 +1043,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" [[package]] name = "byteorder" @@ -989,9 +1061,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" [[package]] name = "bytes-utils" @@ -1004,35 +1076,23 @@ dependencies = [ ] [[package]] -name = "bzip2" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" -dependencies = [ - "bzip2-sys", - "libc", -] - -[[package]] -name = "bzip2-sys" -version = "0.1.11+1.0.8" +name = "cc" +version = "1.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" dependencies = [ - "cc", + "jobserver", "libc", - "pkg-config", + "shlex", ] [[package]] -name = "cc" -version = "1.2.1" +name = "cedarwood" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "6d910bedd62c24733263d0bed247460853c9d22e8956bd4cd964302095e04e90" dependencies = [ - "jobserver", - "libc", - "shlex", + "smallvec", ] [[package]] @@ -1055,9 +1115,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1070,9 +1130,9 @@ dependencies = [ [[package]] name = "chrono-tz" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f" dependencies = [ "chrono", "chrono-tz-build", @@ -1081,23 +1141,27 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "comfy-table" -version = "7.1.3" +version = "7.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" +checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" dependencies = [ - "strum", - "strum_macros", + "unicode-segmentation", "unicode-width", ] @@ -1125,7 +1189,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] @@ -1162,11 +1226,20 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] + [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1197,18 +1270,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -1225,24 +1298,24 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1268,26 +1341,54 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" dependencies = [ "memchr", ] [[package]] -name = "dashmap" -version = "5.5.3" +name = "darling" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ - "cfg-if", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.99", ] +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "dary_heap" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" + [[package]] name = "dashmap" version = "6.1.0" @@ -1304,118 +1405,171 @@ dependencies = [ [[package]] name = "datafusion" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb" +checksum = "914e6f9525599579abbd90b0f7a55afcaaaa40350b9e9ed52563f126dfe45fd3" dependencies = [ - "ahash", "arrow", - "arrow-array", "arrow-ipc", "arrow-schema", - "async-compression", "async-trait", "bytes", - "bzip2", "chrono", - "dashmap 6.1.0", "datafusion-catalog", + "datafusion-catalog-listing", "datafusion-common", "datafusion-common-runtime", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-nested", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", - "flate2", "futures", - "glob", - "half", - "hashbrown 0.14.5", - "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "num_cpus", "object_store", "parking_lot", "parquet", - "paste", - "pin-project-lite", "rand", + "regex", "sqlparser", "tempfile", "tokio", - "tokio-util", "url", "uuid", - "xz2", - "zstd", ] [[package]] name = "datafusion-catalog" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13b3cfbd84c6003594ae1972314e3df303a27ce8ce755fcea3240c90f4c0529" +checksum = "998a6549e6ee4ee3980e05590b2960446a56b343ea30199ef38acd0e0b9036e2" dependencies = [ - "arrow-schema", + "arrow", + "async-trait", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", + "datafusion-sql", + "futures", + "itertools 0.14.0", + "log", + "parking_lot", +] + +[[package]] +name = "datafusion-catalog-listing" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ac10096a5b3c0d8a227176c0e543606860842e943594ccddb45cf42a526e43" +dependencies = [ + "arrow", "async-trait", + "datafusion-catalog", "datafusion-common", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", + "futures", + "log", + "object_store", + "tokio", ] [[package]] name = "datafusion-common" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c" +checksum = "1f53d7ec508e1b3f68bd301cee3f649834fad51eff9240d898a4b2614cfd0a7a" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", - "arrow-schema", - "chrono", + "arrow-ipc", + "base64 0.22.1", "half", "hashbrown 0.14.5", - "instant", + "indexmap", "libc", - "num_cpus", + "log", "object_store", "parquet", + "paste", "sqlparser", + "tokio", + "web-time", ] [[package]] name = "datafusion-common-runtime" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7496d1f664179f6ce3a5cbef6566056ccaf3ea4aa72cc455f80e62c1dd86b1" +checksum = "e0fcf41523b22e14cc349b01526e8b9f59206653037f2949a4adbfde5f8cb668" dependencies = [ + "log", "tokio", ] [[package]] -name = "datafusion-execution" -version = "41.0.0" +name = "datafusion-datasource" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799e70968c815b611116951e3dd876aef04bf217da31b72eec01ee6a959336a1" +checksum = "cf7f37ad8b6e88b46c7eeab3236147d32ea64b823544f498455a8d9042839c92" dependencies = [ "arrow", + "async-trait", + "bytes", "chrono", - "dashmap 6.1.0", + "datafusion-catalog", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "futures", + "glob", + "itertools 0.14.0", + "log", + "object_store", + "rand", + "tokio", + "url", +] + +[[package]] +name = "datafusion-doc" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7db7a0239fd060f359dc56c6e7db726abaa92babaed2fb2e91c3a8b2fff8b256" + +[[package]] +name = "datafusion-execution" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0938f9e5b6bc5782be4111cdfb70c02b7b5451bf34fd57e4de062a7f7c4e31f1" +dependencies = [ + "arrow", + "dashmap", "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.5", "log", "object_store", "parking_lot", @@ -1426,28 +1580,42 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2" +checksum = "b36c28b00b00019a8695ad7f1a53ee1673487b90322ecbd604e2cf32894eb14f" dependencies = [ - "ahash", "arrow", - "arrow-array", - "arrow-buffer", "chrono", "datafusion-common", + "datafusion-doc", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", + "datafusion-functions-window-common", + "datafusion-physical-expr-common", + "indexmap", "paste", "serde_json", "sqlparser", - "strum", - "strum_macros", +] + +[[package]] +name = "datafusion-expr-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18f0a851a436c5a2139189eb4617a54e6a9ccb9edc96c4b3c83b3bb7c58b950e" +dependencies = [ + "arrow", + "datafusion-common", + "indexmap", + "itertools 0.14.0", + "paste", ] [[package]] name = "datafusion-functions" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8e481cf34d2a444bd8fa09b65945f0ce83dc92df8665b761505b3d9f351bebb" +checksum = "e3196e37d7b65469fb79fee4f05e5bb58a456831035f9a38aa5919aeb3298d40" dependencies = [ "arrow", "arrow-buffer", @@ -1456,11 +1624,13 @@ dependencies = [ "blake3", "chrono", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", - "hashbrown 0.14.5", + "datafusion-expr-common", + "datafusion-macros", "hex", - "itertools 0.12.1", + "itertools 0.14.0", "log", "md-5", "rand", @@ -1472,130 +1642,193 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb" +checksum = "adfc2d074d5ee4d9354fdcc9283d5b2b9037849237ddecb8942a29144b77ca05" dependencies = [ "ahash", "arrow", - "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate-common", + "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", + "half", "log", "paste", - "sqlparser", +] + +[[package]] +name = "datafusion-functions-aggregate-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cbceba0f98d921309a9121b702bcd49289d383684cccabf9a92cda1602f3bbb" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr-common", + "datafusion-physical-expr-common", ] [[package]] name = "datafusion-functions-nested" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1474552cc824e8c9c88177d454db5781d4b66757d4aca75719306b8343a5e8d" +checksum = "170e27ce4baa27113ddf5f77f1a7ec484b0dbeda0c7abbd4bad3fc609c8ab71a" dependencies = [ "arrow", - "arrow-array", - "arrow-buffer", "arrow-ord", - "arrow-schema", "datafusion-common", + "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "itertools 0.12.1", + "datafusion-macros", + "datafusion-physical-expr-common", + "itertools 0.14.0", "log", "paste", - "rand", ] [[package]] -name = "datafusion-optimizer" -version = "41.0.0" +name = "datafusion-functions-table" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791ff56f55608bc542d1ea7a68a64bdc86a9413f5a381d06a39fd49c2a3ab906" +checksum = "7d3a06a7f0817ded87b026a437e7e51de7f59d48173b0a4e803aa896a7bd6bb5" dependencies = [ "arrow", "async-trait", + "datafusion-catalog", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-plan", + "parking_lot", + "paste", +] + +[[package]] +name = "datafusion-functions-window" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6c608b66496a1e05e3d196131eb9bebea579eed1f59e88d962baf3dda853bc6" +dependencies = [ + "datafusion-common", + "datafusion-doc", + "datafusion-expr", + "datafusion-functions-window-common", + "datafusion-macros", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "log", + "paste", +] + +[[package]] +name = "datafusion-functions-window-common" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da2f9d83348957b4ad0cd87b5cb9445f2651863a36592fe5484d43b49a5f8d82" +dependencies = [ + "datafusion-common", + "datafusion-physical-expr-common", +] + +[[package]] +name = "datafusion-macros" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4800e1ff7ecf8f310887e9b54c9c444b8e215ccbc7b21c2f244cfae373b1ece7" +dependencies = [ + "datafusion-expr", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "datafusion-optimizer" +version = "46.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "971c51c54cd309001376fae752fb15a6b41750b6d1552345c46afbfb6458801b" +dependencies = [ + "arrow", "chrono", "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.5", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "paste", - "regex-syntax", + "regex", + "regex-syntax 0.8.5", ] [[package]] name = "datafusion-physical-expr" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4" +checksum = "e1447c2c6bc8674a16be4786b4abf528c302803fafa186aa6275692570e64d85" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", - "arrow-ord", - "arrow-schema", - "arrow-string", - "base64 0.22.1", - "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "hex", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", "paste", - "petgraph", - "regex", + "petgraph 0.7.1", ] [[package]] name = "datafusion-physical-expr-common" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6" +checksum = "69f8c25dcd069073a75b3d2840a79d0f81e64bdd2c05f2d3d18939afb36a7dcb" dependencies = [ "ahash", "arrow", "datafusion-common", - "datafusion-expr", + "datafusion-expr-common", "hashbrown 0.14.5", - "rand", + "itertools 0.14.0", ] [[package]] name = "datafusion-physical-optimizer" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb9c78f308e050f5004671039786a925c3fee83b90004e9fcfd328d7febdcc0" +checksum = "68da5266b5b9847c11d1b3404ee96b1d423814e1973e1ad3789131e5ec912763" dependencies = [ + "arrow", "datafusion-common", "datafusion-execution", + "datafusion-expr", + "datafusion-expr-common", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", + "itertools 0.14.0", + "log", ] [[package]] name = "datafusion-physical-plan" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe" +checksum = "88cc160df00e413e370b3b259c8ea7bfbebc134d32de16325950e9e923846b7f" dependencies = [ "ahash", "arrow", - "arrow-array", - "arrow-buffer", "arrow-ord", "arrow-schema", "async-trait", @@ -1604,54 +1837,52 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", "hashbrown 0.14.5", "indexmap", - "itertools 0.12.1", + "itertools 0.14.0", "log", - "once_cell", "parking_lot", "pin-project-lite", - "rand", "tokio", ] [[package]] name = "datafusion-sql" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45d0180711165fe94015d7c4123eb3e1cf5fb60b1506453200b8d1ce666bef0" +checksum = "325a212b67b677c0eb91447bf9a11b630f9fc4f62d8e5d145bf859f5a6b29e64" dependencies = [ "arrow", - "arrow-array", - "arrow-schema", + "bigdecimal", "datafusion-common", "datafusion-expr", + "indexmap", "log", "regex", "sqlparser", - "strum", ] [[package]] name = "datafusion-substrait" -version = "41.0.0" +version = "46.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0a0055aa98246c79f98f0d03df11f16cb7adc87818d02d4413e3f3cdadbbee" +checksum = "2c2be3226a683e02cff65181e66e62eba9f812ed0e9b7ec8fe11ac8dabf1a73f" dependencies = [ - "arrow-buffer", "async-recursion", + "async-trait", "chrono", "datafusion", - "itertools 0.12.1", + "itertools 0.14.0", "object_store", "pbjson-types", - "prost 0.12.6", + "prost 0.13.5", "substrait", + "tokio", "url", ] @@ -1685,6 +1916,37 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.99", +] + [[package]] name = "digest" version = "0.10.7" @@ -1725,15 +1987,9 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] -[[package]] -name = "doc-comment" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" - [[package]] name = "downcast-rs" version = "1.2.1" @@ -1742,43 +1998,135 @@ checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "either" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" [[package]] -name = "env_logger" -version = "0.10.2" +name = "encoding" +version = "0.2.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" +dependencies = [ + "encoding-index-japanese", + "encoding-index-korean", + "encoding-index-simpchinese", + "encoding-index-singlebyte", + "encoding-index-tradchinese", +] + +[[package]] +name = "encoding-index-japanese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-korean" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-simpchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-singlebyte" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-tradchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding_index_tests" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "encoding_rs_io" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3c5651fb62ab8aa3103998dade57efdd028544bd300516baa31840c252a83" +dependencies = [ + "encoding_rs", +] + +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ - "humantime", - "is-terminal", "log", "regex", - "termcolor", +] + +[[package]] +name = "env_logger" +version = "0.11.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3716d7a920fb4fac5d84e9d4bce8ceb321e9414b4409da61b07b75c1e3d0697" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", ] [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1800,9 +2148,9 @@ dependencies = [ [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -1811,11 +2159,11 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" dependencies = [ - "event-listener 5.3.1", + "event-listener 5.4.0", "pin-project-lite", ] @@ -1827,9 +2175,9 @@ checksum = "9afc2bd4d5a73106dd53d10d73d3401c2f32730ba2c0b93ddb888a8983680471" [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "filetime" @@ -1849,11 +2197,17 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" -version = "24.3.25" +version = "24.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" +checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" dependencies = [ "bitflags 1.3.2", "rustc_version", @@ -1861,9 +2215,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.35" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" dependencies = [ "crc32fast", "miniz_oxide", @@ -1877,9 +2231,24 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.3" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" @@ -1902,11 +2271,20 @@ dependencies = [ [[package]] name = "fsst" -version = "0.20.0" +version = "0.26.2" dependencies = [ "rand", ] +[[package]] +name = "fst" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" +dependencies = [ + "utf8-ranges", +] + [[package]] name = "funty" version = "2.0.0" @@ -1963,9 +2341,9 @@ checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" dependencies = [ "fastrand", "futures-core", @@ -1982,7 +2360,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -2015,6 +2393,28 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "generator" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bd114ceda131d3b1d665eba35788690ad37f5916457286b32ab6fd3c438dd" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2034,10 +2434,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gimli" version = "0.31.1" @@ -2046,9 +2458,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "gloo-timers" @@ -2083,16 +2495,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" +checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.1.0", + "http 1.2.0", "indexmap", "slab", "tokio", @@ -2102,9 +2514,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" dependencies = [ "cfg-if", "crunchy", @@ -2123,9 +2535,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ "allocator-api2", "equivalent", @@ -2173,11 +2585,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2210,9 +2622,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" dependencies = [ "bytes", "fnv", @@ -2237,7 +2649,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.1.0", + "http 1.2.0", ] [[package]] @@ -2248,16 +2660,16 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -2267,15 +2679,15 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" [[package]] name = "hyper" -version = "0.14.31" +version = "0.14.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" dependencies = [ "bytes", "futures-channel", @@ -2297,15 +2709,15 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.1" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.7", - "http 1.1.0", + "h2 0.4.8", + "http 1.2.0", "http-body 1.0.1", "httparse", "itoa", @@ -2323,7 +2735,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.31", + "hyper 0.14.32", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2333,19 +2745,35 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.1.0", - "hyper 1.5.1", + "http 1.2.0", + "hyper 1.6.0", "hyper-util", - "rustls 0.23.17", + "rustls 0.23.23", "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", "tower-service", ] @@ -2358,9 +2786,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", - "hyper 1.5.1", + "hyper 1.6.0", "pin-project-lite", "socket2", "tokio", @@ -2388,7 +2816,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -2515,9 +2943,15 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" @@ -2539,21 +2973,44 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "include-flate" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df49c16750695486c1f34de05da5b7438096156466e7f76c38fcdf285cf0113e" +dependencies = [ + "include-flate-codegen", + "lazy_static", + "libflate", +] + +[[package]] +name = "include-flate-codegen" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c5b246c6261be723b85c61ecf87804e8ea4a35cb68be0ff282ed84b95ffe7d7" +dependencies = [ + "libflate", + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "indexmap" -version = "2.6.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", - "hashbrown 0.15.1", + "hashbrown 0.15.2", ] [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "instant" @@ -2581,20 +3038,15 @@ checksum = "0d762194228a2f1c11063e46e32e5acb96e66e906382b9eb5441f2e0504bbd5a" [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] -name = "is-terminal" -version = "0.4.13" +name = "is_terminal_polyfill" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" -dependencies = [ - "hermit-abi 0.4.0", - "libc", - "windows-sys 0.52.0", -] +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" @@ -2632,11 +3084,68 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.13" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jieba-macros" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c676b32a471d3cfae8dac2ad2f8334cd52e53377733cca8c1fb0a5062fec192" +dependencies = [ + "phf_codegen", +] + +[[package]] +name = "jieba-rs" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1bcad6332969e4d48ee568d430e14ee6dea70740c2549d005d87677ebefb0c" +dependencies = [ + "cedarwood", + "fxhash", + "include-flate", + "jieba-macros", + "lazy_static", + "phf", + "regex", +] + +[[package]] +name = "jiff" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d699bc6dfc879fb1bf9bdff0d4c56f0884fc6f0d0eb0fba397a6d00cd9a6b85e" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde", +] + +[[package]] +name = "jiff-static" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d16e75759ee0aa64c57a56acbf43916987b20c77373cb7e808979e02b93c9f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] [[package]] name = "jobserver" @@ -2649,13 +3158,23 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] +[[package]] +name = "kanaria" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0f9d9652540055ac4fded998a73aca97d965899077ab1212587437da44196ff" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "kv-log-macro" version = "1.0.7" @@ -2667,12 +3186,13 @@ dependencies = [ [[package]] name = "lance" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-arith", "arrow-array", "arrow-buffer", + "arrow-ipc", "arrow-ord", "arrow-row", "arrow-schema", @@ -2685,13 +3205,16 @@ dependencies = [ "byteorder", "bytes", "chrono", - "dashmap 5.5.3", + "dashmap", "datafusion", + "datafusion-expr", "datafusion-functions", "datafusion-physical-expr", "deepsize", + "either", "futures", "half", + "humantime", "itertools 0.13.0", "lance-arrow", "lance-core", @@ -2709,8 +3232,8 @@ dependencies = [ "permutation", "pin-project", "prost 0.12.6", - "prost-build 0.12.6", - "prost-types 0.12.6", + "prost 0.13.5", + "prost-types 0.13.5", "rand", "roaring", "serde", @@ -2727,7 +3250,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-buffer", @@ -2735,7 +3258,8 @@ dependencies = [ "arrow-data", "arrow-schema", "arrow-select", - "getrandom", + "bytes", + "getrandom 0.2.15", "half", "num-traits", "rand", @@ -2743,7 +3267,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-buffer", @@ -2765,7 +3289,7 @@ dependencies = [ "num_cpus", "object_store", "pin-project", - "prost 0.12.6", + "prost 0.13.5", "rand", "roaring", "serde_json", @@ -2779,7 +3303,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -2796,16 +3320,20 @@ dependencies = [ "futures", "lance-arrow", "lance-core", + "lance-datagen", "lazy_static", "log", - "prost 0.12.6", + "pin-project", + "prost 0.13.5", "snafu", + "tempfile", "tokio", + "tracing", ] [[package]] name = "lance-datagen" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -2820,7 +3348,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrayref", "arrow", @@ -2843,11 +3371,12 @@ dependencies = [ "lance-core", "lazy_static", "log", + "lz4", "num-traits", "paste", - "prost 0.12.6", - "prost-build 0.12.6", - "prost-types 0.12.6", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "rand", "seq-macro", "snafu", @@ -2858,7 +3387,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-arith", "arrow-array", @@ -2880,9 +3409,9 @@ dependencies = [ "log", "num-traits", "object_store", - "prost 0.12.6", - "prost-build 0.12.6", - "prost-types 0.12.6", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "roaring", "snafu", "tempfile", @@ -2892,7 +3421,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -2910,9 +3439,12 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "deepsize", + "dirs", + "fst", "futures", "half", "itertools 0.13.0", + "jieba-rs", "lance-arrow", "lance-core", "lance-datafusion", @@ -2922,12 +3454,14 @@ dependencies = [ "lance-linalg", "lance-table", "lazy_static", + "lindera", + "lindera-tantivy", "log", "moka", "num-traits", "object_store", - "prost 0.12.6", - "prost-build 0.12.6", + "prost 0.13.5", + "prost-build 0.13.5", "rand", "rayon", "roaring", @@ -2943,7 +3477,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-arith", @@ -2970,8 +3504,7 @@ dependencies = [ "object_store", "path_abs", "pin-project", - "prost 0.12.6", - "prost-build 0.12.6", + "prost 0.13.5", "rand", "shellexpand", "snafu", @@ -2982,7 +3515,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow-array", "arrow-ord", @@ -3005,7 +3538,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -3027,9 +3560,9 @@ dependencies = [ "lazy_static", "log", "object_store", - "prost 0.12.6", - "prost-build 0.12.6", - "prost-types 0.12.6", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "rand", "rangemap", "roaring", @@ -3056,9 +3589,9 @@ checksum = "0c2cdeb66e45e9f36bfad5bbdb4d2384e70936afbee843c6f6543f0c551ebb25" [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -3069,9 +3602,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -3080,9 +3613,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -3090,18 +3623,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -3110,9 +3643,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" dependencies = [ "lexical-util", "static_assertions", @@ -3120,9 +3653,33 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.164" +version = "0.2.170" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" + +[[package]] +name = "libflate" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" +checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" +dependencies = [ + "core2", + "hashbrown 0.14.5", + "rle-decode-fast", +] [[package]] name = "libm" @@ -3136,22 +3693,83 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", "libc", "redox_syscall", ] +[[package]] +name = "lindera" +version = "0.38.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fff887f4b98539fb5f879ede50e17eb7eaafa5622c252cffe8280f42cafc6b7d" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "kanaria", + "lindera-dictionary", + "once_cell", + "regex", + "serde", + "serde_json", + "serde_yaml", + "strum", + "strum_macros", + "unicode-blocks", + "unicode-normalization", + "unicode-segmentation", + "yada", +] + +[[package]] +name = "lindera-dictionary" +version = "0.38.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec716483ceb95aa84ac262cb766eef314b24257c343ca230daa71f856a278fe4" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "derive_builder", + "encoding", + "encoding_rs", + "encoding_rs_io", + "flate2", + "glob", + "log", + "once_cell", + "reqwest", + "serde", + "tar", + "thiserror 2.0.12", + "yada", +] + +[[package]] +name = "lindera-tantivy" +version = "0.38.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261c87882a909fd17db4dd797e4dc2aac3992bdbbb4e2900d1362a1e0746266f" +dependencies = [ + "lindera", + "tantivy", + "tantivy-tokenizer-api", +] + [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" -version = "0.7.3" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" +checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "lock_api" @@ -3165,40 +3783,61 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" dependencies = [ "value-bag", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.1", + "hashbrown 0.15.2", ] [[package]] -name = "lz4_flex" -version = "0.11.3" +name = "lz4" +version = "1.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" dependencies = [ - "twox-hash", + "lz4-sys", ] [[package]] -name = "lzma-sys" -version = "0.1.20" +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" dependencies = [ "cc", "libc", - "pkg-config", +] + +[[package]] +name = "lz4_flex" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +dependencies = [ + "twox-hash", ] [[package]] @@ -3207,6 +3846,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "md-5" version = "0.10.6" @@ -3265,22 +3913,21 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -3295,25 +3942,23 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.8" +version = "0.12.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" dependencies = [ "async-lock", - "async-trait", "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", - "event-listener 5.3.1", + "event-listener 5.4.0", "futures-util", - "once_cell", + "loom", "parking_lot", - "quanta", + "portable-atomic", "rustc_version", "smallvec", "tagptr", "thiserror 1.0.69", - "triomphe", "uuid", ] @@ -3323,18 +3968,29 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "multimap" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" - [[package]] name = "murmurhash32" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "noisy_float" version = "0.2.0" @@ -3456,26 +4112,27 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.10.2" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" +checksum = "3cfccb68961a56facde1163f9319e0d15743352344e7808a11795fb99698dcaf" dependencies = [ "async-trait", "base64 0.22.1", "bytes", "chrono", "futures", + "httparse", "humantime", - "hyper 1.5.1", + "hyper 1.6.0", "itertools 0.13.0", "md-5", "parking_lot", @@ -3496,21 +4153,59 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.2" +version = "1.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "oneshot" -version = "0.1.8" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" + +[[package]] +name = "openssl" +version = "0.10.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" +dependencies = [ + "bitflags 2.9.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] [[package]] name = "option-ext" @@ -3529,9 +4224,9 @@ dependencies = [ [[package]] name = "outref" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "overload" @@ -3579,9 +4274,9 @@ dependencies = [ [[package]] name = "parquet" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e977b9066b4d3b03555c22bdc442f3fadebd96a39111249113087d0edb2691cd" +checksum = "f88838dca3b84d41444a0341b19f347e8098a3898b0f21536654b8b799e11abd" dependencies = [ "ahash", "arrow-array", @@ -3598,13 +4293,14 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "lz4_flex", "num", "num-bigint", "object_store", "paste", "seq-macro", + "simdutf8", "snap", "thrift", "tokio", @@ -3642,9 +4338,9 @@ dependencies = [ [[package]] name = "pbjson" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1030c719b0ec2a2d25a5df729d6cff1acf3cc230bf766f4f97833591f7577b90" +checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" dependencies = [ "base64 0.21.7", "serde", @@ -3652,28 +4348,28 @@ dependencies = [ [[package]] name = "pbjson-build" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2580e33f2292d34be285c5bc3dba5259542b083cfad6037b6d70345f24dcb735" +checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" dependencies = [ - "heck 0.4.1", - "itertools 0.11.0", - "prost 0.12.6", - "prost-types 0.12.6", + "heck 0.5.0", + "itertools 0.13.0", + "prost 0.13.5", + "prost-types 0.13.5", ] [[package]] name = "pbjson-types" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f596653ba4ac51bdecbb4ef6773bc7f56042dc13927910de1684ad3d32aa12" +checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" dependencies = [ "bytes", "chrono", "pbjson", "pbjson-build", - "prost 0.12.6", - "prost-build 0.12.6", + "prost 0.13.5", + "prost-build 0.13.5", "serde", ] @@ -3695,24 +4391,34 @@ version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ - "fixedbitset", + "fixedbitset 0.4.2", + "indexmap", +] + +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset 0.5.7", "indexmap", ] [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -3720,9 +4426,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", "rand", @@ -3730,38 +4436,38 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project" -version = "1.1.7" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.7" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -3782,9 +4488,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "polling" @@ -3803,9 +4509,18 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.9.0" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] [[package]] name = "powerfmt" @@ -3834,19 +4549,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" dependencies = [ "proc-macro2", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] @@ -3871,6 +4586,16 @@ dependencies = [ "prost-derive 0.12.6", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive 0.13.5", +] + [[package]] name = "prost-build" version = "0.11.9" @@ -3882,8 +4607,8 @@ dependencies = [ "itertools 0.10.5", "lazy_static", "log", - "multimap 0.8.3", - "petgraph", + "multimap", + "petgraph 0.6.5", "prettyplease 0.1.25", "prost 0.11.9", "prost-types 0.11.9", @@ -3903,14 +4628,34 @@ dependencies = [ "heck 0.5.0", "itertools 0.12.1", "log", - "multimap 0.10.0", + "multimap", "once_cell", - "petgraph", - "prettyplease 0.2.25", + "petgraph 0.6.5", + "prettyplease 0.2.30", "prost 0.12.6", "prost-types 0.12.6", "regex", - "syn 2.0.89", + "syn 2.0.99", + "tempfile", +] + +[[package]] +name = "prost-build" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +dependencies = [ + "heck 0.5.0", + "itertools 0.14.0", + "log", + "multimap", + "once_cell", + "petgraph 0.7.1", + "prettyplease 0.2.30", + "prost 0.13.5", + "prost-types 0.13.5", + "regex", + "syn 2.0.99", "tempfile", ] @@ -3937,7 +4682,20 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.99", ] [[package]] @@ -3958,9 +4716,27 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "prost-types" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +dependencies = [ + "prost 0.13.5", +] + +[[package]] +name = "psm" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +dependencies = [ + "cc", +] + [[package]] name = "pylance" -version = "0.20.0" +version = "0.26.2" dependencies = [ "arrow", "arrow-array", @@ -3986,7 +4762,7 @@ dependencies = [ "lazy_static", "log", "object_store", - "prost 0.12.6", + "prost 0.13.5", "prost-build 0.11.9", "pyo3", "serde", @@ -4003,15 +4779,15 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -4021,9 +4797,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" dependencies = [ "once_cell", "target-lexicon", @@ -4031,9 +4807,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" dependencies = [ "libc", "pyo3-build-config", @@ -4041,49 +4817,34 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "pyo3-macros-backend" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.89", -] - -[[package]] -name = "quanta" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", + "syn 2.0.99", ] [[package]] name = "quick-xml" -version = "0.36.2" +version = "0.37.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" +checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" dependencies = [ "memchr", "serde", @@ -4099,10 +4860,10 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", - "rustls 0.23.17", + "rustc-hash 2.1.1", + "rustls 0.23.23", "socket2", - "thiserror 2.0.3", + "thiserror 2.0.12", "tokio", "tracing", ] @@ -4114,14 +4875,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", - "getrandom", + "getrandom 0.2.15", "rand", "ring", - "rustc-hash 2.0.0", - "rustls 0.23.17", + "rustc-hash 2.1.1", + "rustls 0.23.23", "rustls-pki-types", "slab", - "thiserror 2.0.3", + "thiserror 2.0.12", "tinyvec", "tracing", "web-time", @@ -4129,9 +4890,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.7" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" +checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" dependencies = [ "cfg_aliases", "libc", @@ -4143,9 +4904,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.37" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" dependencies = [ "proc-macro2", ] @@ -4183,7 +4944,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -4211,15 +4972,6 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f60fcc7d6849342eff22c4350c8b9a989ee8ceabc4b481253e8946b9fe83d684" -[[package]] -name = "raw-cpuid" -version = "11.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" -dependencies = [ - "bitflags 2.6.0", -] - [[package]] name = "rayon" version = "1.10.0" @@ -4240,13 +4992,33 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.99", +] + [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", ] [[package]] @@ -4255,7 +5027,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] @@ -4268,8 +5040,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -4280,7 +5061,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -4289,6 +5070,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -4297,40 +5084,43 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "regress" -version = "0.9.1" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eae2a1ebfecc58aff952ef8ccd364329abe627762f5bf09ff42eb9d98522479" +checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.15.2", "memchr", ] [[package]] name = "reqwest" -version = "0.12.9" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", - "h2 0.4.7", - "http 1.1.0", + "h2 0.4.8", + "http 1.2.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", - "hyper-rustls 0.27.3", + "hyper 1.6.0", + "hyper-rustls 0.27.5", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.17", + "rustls 0.23.23", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -4338,9 +5128,12 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper", + "system-configuration", "tokio", - "tokio-rustls 0.26.0", + "tokio-native-tls", + "tokio-rustls 0.26.2", "tokio-util", + "tower", "tower-service", "url", "wasm-bindgen", @@ -4352,24 +5145,29 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.8" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", - "spin", "untrusted", "windows-sys 0.52.0", ] +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + [[package]] name = "roaring" -version = "0.10.6" +version = "0.10.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +checksum = "a652edd001c53df0b3f96a36a8dc93fce6866988efc16808235653c6bcac8bf2" dependencies = [ "bytemuck", "byteorder", @@ -4399,9 +5197,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustc_version" @@ -4414,15 +5212,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4439,9 +5237,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.17" +version = "0.23.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" +checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "log", "once_cell", @@ -4473,7 +5271,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.0.1", + "security-framework 3.2.0", ] [[package]] @@ -4496,9 +5294,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" dependencies = [ "web-time", ] @@ -4526,15 +5324,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -4556,9 +5354,9 @@ dependencies = [ [[package]] name = "schemars" -version = "0.8.21" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ "dyn-clone", "schemars_derive", @@ -4568,16 +5366,22 @@ dependencies = [ [[package]] name = "schemars_derive" -version = "0.8.21" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.89", + "syn 2.0.99", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -4600,7 +5404,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", "core-foundation 0.9.4", "core-foundation-sys", "libc", @@ -4609,11 +5413,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "3.0.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.9.0", "core-foundation 0.10.0", "core-foundation-sys", "libc", @@ -4622,9 +5426,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.1" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -4632,37 +5436,37 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" dependencies = [ "serde", ] [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.215" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -4673,14 +5477,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -4697,7 +5501,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -4769,11 +5573,17 @@ dependencies = [ "libc", ] +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "sketches-ddsketch" @@ -4795,30 +5605,29 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "snafu" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ - "doc-comment", "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.99", ] [[package]] @@ -4829,39 +5638,34 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "c66e3b7374ad4a6af849b08b3e7a6eda0edbd82f0fd59b57e22671bf16979899" dependencies = [ "log", + "recursive", "sqlparser_derive", ] [[package]] name = "sqlparser_derive" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -4870,6 +5674,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601f9201feb9b09c00266478bf459952b9ef9a6b94edb2f21eba14ab681a60a9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -4888,6 +5705,12 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.26.3" @@ -4907,29 +5730,30 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "substrait" -version = "0.36.0" +version = "0.53.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1ee6e584c8bf37104b7eb51c25eae07a9321b0e01379bec3b7c462d2f42afbf" +checksum = "6fac3d70185423235f37b889764e184b81a5af4bb7c95833396ee9bd92577e1b" dependencies = [ "heck 0.5.0", "pbjson", "pbjson-build", "pbjson-types", - "prettyplease 0.2.25", - "prost 0.12.6", - "prost-build 0.12.6", - "prost-types 0.12.6", + "prettyplease 0.2.30", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", + "regress", "schemars", "semver", "serde", "serde_json", "serde_yaml", - "syn 2.0.89", + "syn 2.0.99", "typify", "walkdir", ] @@ -4953,9 +5777,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.89" +version = "2.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" +checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" dependencies = [ "proc-macro2", "quote", @@ -4979,7 +5803,28 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.9.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -5084,7 +5929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d60769b80ad7953d8a7b2c70cdfe722bbcdcac6bccc8ac934c40c034d866fc18" dependencies = [ "byteorder", - "regex-syntax", + "regex-syntax 0.8.5", "utf8-ranges", ] @@ -5137,9 +5982,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tar" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c65998313f8e17d0d553d28f91a0df93e4dbbbf770279c7bc21ca0f09ea1a1f6" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" dependencies = [ "filetime", "libc", @@ -5154,26 +5999,18 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.14.0" +version = "3.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" dependencies = [ "cfg-if", "fastrand", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "tfrecord" version = "0.15.0" @@ -5214,11 +6051,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.12", ] [[package]] @@ -5229,18 +6066,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -5266,9 +6103,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "itoa", @@ -5287,9 +6124,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -5316,9 +6153,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] @@ -5331,9 +6168,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.41.1" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -5348,13 +6185,23 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", ] [[package]] @@ -5369,20 +6216,19 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls 0.23.17", - "rustls-pki-types", + "rustls 0.23.23", "tokio", ] [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -5391,9 +6237,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -5402,6 +6248,27 @@ dependencies = [ "tokio", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -5410,9 +6277,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -5421,13 +6288,13 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] @@ -5443,9 +6310,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -5464,24 +6331,22 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] -[[package]] -name = "triomphe" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" - [[package]] name = "try-lock" version = "0.2.5" @@ -5500,15 +6365,15 @@ dependencies = [ [[package]] name = "typenum" -version = "1.17.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb6beec125971dda80a086f90b4a70f60f222990ce4d63ad0fc140492f53444" +checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" dependencies = [ "typify-impl", "typify-macro", @@ -5516,9 +6381,9 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93bbb24e990654aff858d80fee8114f4322f7d7a1b1ecb45129e2fcb0d0ad5ae" +checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" dependencies = [ "heck 0.5.0", "log", @@ -5529,16 +6394,16 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.89", - "thiserror 1.0.69", + "syn 2.0.99", + "thiserror 2.0.12", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8e6491896e955692d68361c68db2b263e3bec317ec0b684e0e2fa882fb6e31e" +checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" dependencies = [ "proc-macro2", "quote", @@ -5547,15 +6412,30 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.89", + "syn 2.0.99", "typify-impl", ] +[[package]] +name = "unicode-blocks" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b12e05d9e06373163a9bb6bb8c263c261b396643a99445fe6b9811fd376581b" + [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unicode-normalization" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] [[package]] name = "unicode-segmentation" @@ -5571,9 +6451,9 @@ checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "unsafe-libyaml" @@ -5589,15 +6469,15 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.10.1" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" dependencies = [ "base64 0.22.1", "flate2", "log", "once_cell", - "rustls 0.23.17", + "rustls 0.23.23", "rustls-pki-types", "url", "webpki-roots", @@ -5605,9 +6485,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.3" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -5638,21 +6518,29 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" -version = "1.11.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" dependencies = [ - "getrandom", + "getrandom 0.3.1", + "js-sys", "serde", + "wasm-bindgen", ] [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "value-bag" @@ -5660,6 +6548,12 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ef4c4aa54d5d05a279399bfa921ec387b7aba77caf7a682ae8d86785b8fdad2" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -5697,49 +6591,59 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.45" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5747,22 +6651,25 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-streams" @@ -5779,9 +6686,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -5799,9 +6706,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.7" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -5849,6 +6756,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -5858,6 +6775,41 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + [[package]] name = "windows-registry" version = "0.2.0" @@ -6036,6 +6988,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.9.0", +] + [[package]] name = "write16" version = "1.0.0" @@ -6059,9 +7020,9 @@ dependencies = [ [[package]] name = "xattr" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" dependencies = [ "libc", "linux-raw-sys", @@ -6075,19 +7036,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] -name = "xz2" -version = "0.1.7" +name = "yada" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] +checksum = "aed111bd9e48a802518765906cbdadf0b45afb72b9c81ab049a3b86252adffdd" [[package]] name = "yoke" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", @@ -6097,13 +7055,13 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", "synstructure", ] @@ -6125,27 +7083,27 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "zerofrom" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", "synstructure", ] @@ -6174,14 +7132,14 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.89", + "syn 2.0.99", ] [[package]] name = "zstd" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" dependencies = [ "zstd-safe", ] @@ -6197,9 +7155,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.12+zstd.1.5.6" +version = "2.0.13+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" dependencies = [ "cc", "pkg-config", diff --git a/python/Cargo.toml b/python/Cargo.toml index f19fafab571..604b3a7035c 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.20.0" +version = "0.26.2" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" @@ -12,17 +12,17 @@ name = "lance" crate-type = ["cdylib"] [dependencies] -arrow = { version = "52.2", features = ["pyarrow"] } -arrow-array = "52.2" -arrow-data = "52.2" -arrow-schema = "52.2" -arrow-select = "52.2" -object_store = "0.10.1" +arrow = { version = "54.1", features = ["pyarrow"] } +arrow-array = "54.1" +arrow-data = "54.1" +arrow-schema = "54.1" +arrow-select = "54.1" +object_store = "0.11.2" async-trait = "0.1" chrono = "0.4.31" -env_logger = "0.10" +env_logger = "0.11.7" futures = "0.3" -half = { version = "2.3", default-features = false, features = [ +half = { version = "2.5", default-features = false, features = [ "num-traits", "std", ] } @@ -36,27 +36,30 @@ lance-core = { path = "../rust/lance-core" } lance-datagen = { path = "../rust/lance-datagen", optional = true } lance-encoding = { path = "../rust/lance-encoding" } lance-file = { path = "../rust/lance-file" } -lance-index = { path = "../rust/lance-index" } +lance-index = { path = "../rust/lance-index", features = [ + "tokenizer-lindera", + "tokenizer-jieba", +] } lance-io = { path = "../rust/lance-io" } lance-linalg = { path = "../rust/lance-linalg" } lance-table = { path = "../rust/lance-table" } lazy_static = "1" log = "0.4" -prost = "0.12.2" -pyo3 = { version = "0.21", features = [ +prost = "0.13.2" +pyo3 = { version = "0.23", features = [ "extension-module", "abi3-py39", - "gil-refs", + "py-clone", ] } tokio = { version = "1.23", features = ["rt-multi-thread"] } uuid = "1.3.0" serde_json = "1" serde = "1.0.197" serde_yaml = "0.9.34" -snafu = "0.7.4" +snafu = "0.8" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.17" -tracing = "0.1.37" +tracing = { version = "0.1" } url = "2.5.0" bytes = "1.4" diff --git a/python/DEVELOPMENT.md b/python/DEVELOPMENT.md index 5202701d6ed..04f84c06867 100644 --- a/python/DEVELOPMENT.md +++ b/python/DEVELOPMENT.md @@ -16,6 +16,14 @@ re-building. ## Running tests +To run the tests, first install the test packages: + +```shell +pip install '.[tests]' +``` + +then: + ```shell make test ``` diff --git a/python/Makefile b/python/Makefile index f51fe8c65cf..e566b9da3b6 100644 --- a/python/Makefile +++ b/python/Makefile @@ -31,6 +31,7 @@ lint: lint-python lint-rust lint-python: ruff format --check python ruff check python + pyright .PHONY: lint-python lint-rust: diff --git a/python/pyproject.toml b/python/pyproject.toml index c0d2f900dc6..6c9581b473a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,5 +1,6 @@ [project] name = "pylance" +dynamic = ["version"] dependencies = ["pyarrow>=14", "numpy>=1.22"] description = "python wrapper for Lance columnar format" authors = [{ name = "Lance Devs", email = "dev@lancedb.com" }] @@ -57,11 +58,9 @@ tests = [ "tensorflow", "tqdm", ] -dev = ["ruff==0.4.1"] +dev = ["ruff==0.4.1", "pyright"] benchmarks = ["pytest-benchmark"] torch = ["torch"] -cuvs-cu11 = ["cuvs-cu11", "pylibraft-cu11"] -cuvs-cu12 = ["cuvs-cu12", "pylibraft-cu12"] ray = ["ray[data]<2.38; python_version<'3.12'"] [tool.ruff] @@ -70,11 +69,25 @@ lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "B019"] [tool.ruff.lint.per-file-ignores] "*.pyi" = ["E301", "E302"] -[tool.mypy] -python_version = "3.12" -check_untyped_defs = true -warn_redundant_casts = true -warn_unused_ignores = true +[tool.pyright] +pythonVersion = "3.12" +# TODO: expand this list as we fix more files. +include = [ + "python/lance/util.py", + "python/lance/debug.py", + "python/lance/tracing.py", + "python/lance/dependencies.py", + "python/lance/schema.py", + "python/lance/file.py", + "python/lance/util.py", +] +# Dependencies like pyarrow make this difficult to enforce strictly. +reportMissingTypeStubs = "warning" +reportImportCycles = "error" +reportUnusedImport = "error" +reportPropertyTypeMismatch = "error" +reportUnnecessaryCast = "error" + [tool.pytest.ini_options] markers = [ diff --git a/python/python/benchmarks/test_packed_struct.py b/python/python/benchmarks/test_packed_struct.py index 037470f01ce..96c887174a1 100644 --- a/python/python/benchmarks/test_packed_struct.py +++ b/python/python/benchmarks/test_packed_struct.py @@ -14,8 +14,9 @@ NUM_ROWS = 10_000_000 RANDOM_ACCESS = "indices" -NUM_INDICES = 100 +NUM_INDICES = 1000 NUM_ROUNDS = 10 +BATCH_SIZE = 16 * 1024 # This file compares benchmarks for reading and writing a StructArray column using # (i) parquet @@ -31,15 +32,12 @@ def test_data(tmp_path_factory): { "struct_col": pa.StructArray.from_arrays( [ - pc.random(NUM_ROWS).cast(pa.float32()), - pa.array(range(NUM_ROWS), type=pa.int32()), - pa.FixedSizeListArray.from_arrays( - pc.random(NUM_ROWS * 5).cast(pa.float32()), 5 - ), - pa.array(range(NUM_ROWS), type=pa.int32()), - pa.array(range(NUM_ROWS), type=pa.int32()), + pc.random(NUM_ROWS).cast(pa.float32()), # f1 + pc.random(NUM_ROWS).cast(pa.float32()), # f2 + pc.random(NUM_ROWS).cast(pa.float32()), # f3 + pc.random(NUM_ROWS).cast(pa.float32()), # f4 ], - ["f", "i", "fsl", "i2", "i3"], + ["f1", "f2", "f3", "f4"], ) } ) @@ -51,6 +49,7 @@ def test_data(tmp_path_factory): @pytest.fixture(scope="module") def random_indices(): random_indices = [random.randint(0, NUM_ROWS) for _ in range(NUM_INDICES)] + random_indices.sort() return random_indices @@ -59,12 +58,18 @@ def test_parquet_read(tmp_path: Path, benchmark, test_data, random_indices): parquet_path = tmp_path / "data.parquet" pq.write_table(test_data, parquet_path) + def read_parquet(): + parquet_file = pq.ParquetFile(parquet_path) + batches = parquet_file.iter_batches(batch_size=BATCH_SIZE) + tab_parquet = pa.Table.from_batches(batches) + return tab_parquet + if RANDOM_ACCESS == "indices": benchmark.pedantic( lambda: pq.read_table(parquet_path).take(random_indices), rounds=5 ) elif RANDOM_ACCESS == "full": - benchmark.pedantic(lambda: pq.read_table(parquet_path), rounds=5) + benchmark.pedantic(lambda: read_parquet(), rounds=5) def read_lance_file_random(lance_path, random_indices): @@ -75,7 +80,9 @@ def read_lance_file_random(lance_path, random_indices): def read_lance_file_full(lance_path): - for batch in LanceFileReader(lance_path).read_all(batch_size=1000).to_batches(): + for batch in ( + LanceFileReader(lance_path).read_all(batch_size=BATCH_SIZE).to_batches() + ): pass @@ -127,7 +134,7 @@ def test_parquet_write(tmp_path: Path, benchmark, test_data): def write_lance_file(lance_path, test_data): - with LanceFileWriter(lance_path, test_data.schema) as writer: + with LanceFileWriter(lance_path, test_data.schema, version="2.1") as writer: for batch in test_data.to_batches(): writer.write_batch(batch) diff --git a/python/python/ci_benchmarks/benchmarks/test_search.py b/python/python/ci_benchmarks/benchmarks/test_search.py index b2229d89b0b..2cf31dc32a9 100644 --- a/python/python/ci_benchmarks/benchmarks/test_search.py +++ b/python/python/ci_benchmarks/benchmarks/test_search.py @@ -34,3 +34,22 @@ def bench(): ) benchmark.pedantic(bench, rounds=1, iterations=1) + + +BTREE_FILTERS = ["image_widths = 3997", "image_widths >= 3990 AND image_widths <= 3997"] + + +@pytest.mark.parametrize("filt", BTREE_FILTERS) +def test_eda_btree_search(benchmark, filt): + dataset_uri = get_dataset_uri("image_eda") + ds = lance.dataset(dataset_uri) + + def bench(): + ds.to_table( + columns=[], + filter=filt, + with_row_id=True, + ) + + # We warmup so we can test hot index performance + benchmark.pedantic(bench, warmup_rounds=1, rounds=1, iterations=100) diff --git a/python/python/ci_benchmarks/datagen/gen_all.py b/python/python/ci_benchmarks/datagen/gen_all.py index 58281291940..31c9d71de6d 100644 --- a/python/python/ci_benchmarks/datagen/gen_all.py +++ b/python/python/ci_benchmarks/datagen/gen_all.py @@ -1,10 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -import logging - from ci_benchmarks.datagen.lineitems import gen_tcph if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) gen_tcph() diff --git a/python/python/ci_benchmarks/datagen/lineitems.py b/python/python/ci_benchmarks/datagen/lineitems.py index 8becec6b4d8..4e6d60c67b9 100644 --- a/python/python/ci_benchmarks/datagen/lineitems.py +++ b/python/python/ci_benchmarks/datagen/lineitems.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors # Creates a dataset containing the TPC-H lineitems table using a prebuilt Parquet file -import logging import duckdb import lance +from lance.log import LOGGER from ci_benchmarks.datasets import get_dataset_uri @@ -13,7 +13,7 @@ def _gen_data(): - logging.info("Using DuckDB to generate TPC-H dataset") + LOGGER.info("Using DuckDB to generate TPC-H dataset") con = duckdb.connect(database=":memory:") con.execute("INSTALL tpch; LOAD tpch") con.execute("CALL dbgen(sf=10)") diff --git a/python/python/ci_benchmarks/datasets.py b/python/python/ci_benchmarks/datasets.py index bc25bdb90c1..f71da448df5 100644 --- a/python/python/ci_benchmarks/datasets.py +++ b/python/python/ci_benchmarks/datasets.py @@ -1,36 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -import logging from functools import cache from pathlib import Path import requests +from lance.log import LOGGER def _is_on_google() -> bool: - logging.info("Testing if running on Google Cloud") + LOGGER.info("Testing if running on Google Cloud") try: rsp = requests.get("http://metadata.google.internal", timeout=5) - logging.info("Metadata-Flavor: %s", rsp.headers.get("Metadata-Flavor")) + LOGGER.info("Metadata-Flavor: %s", rsp.headers.get("Metadata-Flavor")) return rsp.headers["Metadata-Flavor"] == "Google" except requests.exceptions.RequestException as ex: - logging.info("Failed to connect to metadata server: %s", ex) + LOGGER.info("Failed to connect to metadata server: %s", ex) return False @cache def _get_base_uri() -> str: if _is_on_google(): - logging.info( - "Running on Google Cloud, using gs://lance-benchmarks-ci-datasets/" - ) + LOGGER.info("Running on Google Cloud, using gs://lance-benchmarks-ci-datasets/") return "gs://lance-benchmarks-ci-datasets/" else: data_path = Path.home() / "lance-benchmarks-ci-datasets" if not data_path.exists(): data_path.mkdir(parents=True, exist_ok=True) - logging.info("Running locally, using %s", data_path) + LOGGER.info("Running locally, using %s", data_path) return f"{data_path}/" diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py index f900a26f6c3..3487563cdf3 100644 --- a/python/python/lance/__init__.py +++ b/python/python/lance/__init__.py @@ -3,10 +3,16 @@ from __future__ import annotations +import logging +import os +import warnings from typing import TYPE_CHECKING, Dict, Optional, Union +from . import log from .blob import BlobColumn, BlobFile from .dataset import ( + DataStatistics, + FieldStatistics, LanceDataset, LanceOperation, LanceScanner, @@ -17,6 +23,7 @@ write_dataset, ) from .fragment import FragmentMetadata, LanceFragment +from .lance import ScanStatistics, bytes_read_counter, iops_counter from .schema import json_to_schema, schema_to_json from .util import sanitize_ts @@ -33,19 +40,25 @@ __all__ = [ "BlobColumn", "BlobFile", + "DataStatistics", + "FieldStatistics", "FragmentMetadata", "LanceDataset", "LanceFragment", "LanceOperation", "LanceScanner", "MergeInsertBuilder", + "ScanStatistics", "Transaction", "__version__", + "bytes_read_counter", + "iops_counter", "write_dataset", "schema_to_json", "json_to_schema", "dataset", "batch_udf", + "set_logger", ] @@ -65,7 +78,8 @@ def dataset( Parameters ---------- uri : str - Address to the Lance dataset. + Address to the Lance dataset. It can be a local file path `/tmp/data.lance`, + or a cloud object store URI, i.e., `s3://bucket/data.lance`. version : optional, int | str If specified, load a specific version of the Lance dataset. Else, loads the latest version. A version number (`int`) or a tag (`str`) can be provided. @@ -74,10 +88,9 @@ def dataset( argument value. If a version is already specified, this arg is ignored. block_size : optional, int Block size in bytes. Provide a hint for the size of the minimal I/O request. - commit_handler : optional, CommitLock - If specified, use the provided commit handler to lock the table while - committing a new version. Not necessary on object stores other than S3 - or when there are no concurrent writers. + commit_lock : optional, lance.commit.CommitLock + A custom commit lock. Only needed if your object store does not support + atomic commits. See the user guide for more details. index_cache_size : optional, int Index cache size. Index cache is a LRU cache with TTL. This number specifies the number of index pages, for example, IVF partitions, to be cached in @@ -133,3 +146,23 @@ def dataset( ) else: return ds + + +def set_logger( + file_path="pylance.log", + name="pylance", + level=logging.INFO, + format_string=None, + log_handler=None, +): + log.set_logger(file_path, name, level, format_string, log_handler) + + +def __warn_on_fork(): + warnings.warn( + "lance is not fork-safe. If you are using multiprocessing, use spawn instead." + ) + + +if hasattr(os, "register_at_fork"): + os.register_at_fork(before=__warn_on_fork) diff --git a/python/python/lance/_arrow/bf16.py b/python/python/lance/_arrow/bf16.py index 9ecd361183e..870ec370cc5 100644 --- a/python/python/lance/_arrow/bf16.py +++ b/python/python/lance/_arrow/bf16.py @@ -81,11 +81,11 @@ def from_numpy(cls, array: np.ndarray): class BFloat16Scalar(pa.ExtensionScalar): - def as_py(self) -> Optional[BFloat16]: + def as_py(self, **kwargs) -> Optional[BFloat16]: if self.value is None: return None else: - return BFloat16.from_bytes(self.value.as_py()) + return BFloat16.from_bytes(self.value.as_py(**kwargs)) def __eq__(self, other: Any): from ml_dtypes import bfloat16 diff --git a/python/python/lance/_datagen.py b/python/python/lance/_datagen.py index c592f6f5843..9c0e203cb77 100644 --- a/python/python/lance/_datagen.py +++ b/python/python/lance/_datagen.py @@ -5,6 +5,8 @@ An internal module for generating Arrow data for use in testing and benchmarking. """ +from typing import Optional + import pyarrow as pa from .lance import datagen @@ -15,7 +17,10 @@ def is_datagen_supported(): def rand_batches( - schema: pa.Schema, *, num_batches: int = None, batch_size_bytes: int = None + schema: pa.Schema, + *, + num_batches: Optional[int] = None, + batch_size_bytes: Optional[int] = None, ): if not datagen.is_datagen_supported(): raise NotImplementedError( diff --git a/python/python/lance/arrow.py b/python/python/lance/arrow.py index 69ea309f9a2..54da15705f8 100644 --- a/python/python/lance/arrow.py +++ b/python/python/lance/arrow.py @@ -517,8 +517,8 @@ def tensorflow_encoder(x): class ImageScalar(pa.ExtensionScalar): - def as_py(self): - return self.value.as_py() + def as_py(self, **kwargs): + return self.value.as_py(**kwargs) class ImageURIScalar(ImageScalar): diff --git a/python/python/lance/blob.py b/python/python/lance/blob.py index 05071224886..cf2c9ef3118 100644 --- a/python/python/lance/blob.py +++ b/python/python/lance/blob.py @@ -6,7 +6,7 @@ import pyarrow as pa -from lance.lance import LanceBlobFile +from .lance import LanceBlobFile class BlobIterator: @@ -28,7 +28,7 @@ class BlobColumn: This can be useful for working with medium-to-small binary objects that need to interface with APIs that expect file-like objects. For very large binary objects (4-8MB or more per value) you might be better off creating a blob column - and using :ref:`lance.Dataset.take_blobs` to access the blob data. + and using :py:meth:`lance.Dataset.take_blobs` to access the blob data. """ def __init__(self, blob_column: Union[pa.Array, pa.ChunkedArray]): @@ -50,13 +50,12 @@ def __iter__(self) -> Iterator[IO[bytes]]: class BlobFile(io.RawIOBase): - """ - Represents a blob in a Lance dataset as a file-like object. - """ + """Represents a blob in a Lance dataset as a file-like object.""" def __init__(self, inner: LanceBlobFile): """ - Internal only: To obtain a BlobFile use :ref:`lance.Dataset.take_blobs`. + Internal only: To obtain a BlobFile use + :py:meth:`lance.dataset.Dataset.take_blobs`. """ self.inner = inner diff --git a/python/python/lance/cuvs/__init__.py b/python/python/lance/cuvs/__init__.py deleted file mode 100644 index c41ad8c80c3..00000000000 --- a/python/python/lance/cuvs/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors diff --git a/python/python/lance/cuvs/kmeans.py b/python/python/lance/cuvs/kmeans.py deleted file mode 100644 index be835c2c0a2..00000000000 --- a/python/python/lance/cuvs/kmeans.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors - - -import logging -import time -from typing import Literal, Optional, Tuple, Union - -import pyarrow as pa - -from lance.dependencies import cagra, raft_common, torch -from lance.dependencies import numpy as np -from lance.torch.kmeans import KMeans as KMeansTorch - -__all__ = ["KMeans"] - - -class KMeans(KMeansTorch): - """K-Means trains over vectors and divide into K clusters, - using cuVS as accelerator. - - This implement is built on PyTorch+cuVS, supporting Nvidia GPU only. - - Parameters - ---------- - k: int - The number of clusters - metric : str - Metric type, support "l2", "cosine" or "dot" - init: str - Initialization method. Only support "random" now. - max_iters: int - Max number of iterations to train the kmean model. - tolerance: float - Relative tolerance in regard to Frobenius norm of the difference in - the cluster centers of two consecutive iterations to declare convergence. - centroids : torch.Tensor, optional. - Provide existing centroids. - seed: int, optional - Random seed - device: str, optional - The device to run the PyTorch algorithms. Default we will pick - the most performant device on the host. See `lance.torch.preferred_device()` - For the cuVS implementation, it will be verified this is a cuda device. - """ - - def __init__( - self, - k: int, - *, - metric: Literal["l2", "euclidean", "cosine", "dot"] = "l2", - init: Literal["random"] = "random", - max_iters: int = 50, - tolerance: float = 1e-4, - centroids: Optional[torch.Tensor] = None, - seed: Optional[int] = None, - device: Optional[str] = None, - itopk_size: int = 10, - ): - if metric == "dot": - raise ValueError( - 'Kmeans::__init__: metric == "dot" is incompatible' " with cuVS" - ) - super().__init__( - k, - metric=metric, - init=init, - max_iters=max_iters, - tolerance=tolerance, - centroids=centroids, - seed=seed, - device=device, - ) - - if self.device.type != "cuda" or not torch.cuda.is_available(): - raise ValueError("KMeans::__init__: cuda is not enabled/available") - - self.itopk_size = itopk_size - self.time_rebuild = 0.0 - self.time_search = 0.0 - - def fit( - self, - data: Union[ - torch.utils.data.IterableDataset, - np.ndarray, - torch.Tensor, - pa.FixedSizeListArray, - ], - ) -> None: - self.time_rebuild = 0.0 - self.time_search = 0.0 - super().fit(data) - logging.info("Total search time: %s", self.time_search) - logging.info("Total rebuild time: %s", self.time_rebuild) - - def rebuild_index(self): - rebuild_time_start = time.time() - cagra_metric = "sqeuclidean" - dim = self.centroids.shape[1] - graph_degree = max(dim // 4, 32) - nn_descent_degree = graph_degree * 2 - index_params = cagra.IndexParams( - metric=cagra_metric, - intermediate_graph_degree=nn_descent_degree, - graph_degree=graph_degree, - build_algo="nn_descent", - compression=None, - ) - self.index = cagra.build(index_params, self.centroids) - rebuild_time_end = time.time() - self.time_rebuild += rebuild_time_end - rebuild_time_start - - self.y2 = None - - def _transform( - self, - data: torch.Tensor, - y2: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.metric == "cosine": - data = torch.nn.functional.normalize(data) - - search_time_start = time.time() - device = torch.device("cuda") - out_idx = raft_common.device_ndarray.empty((data.shape[0], 1), dtype="uint32") - out_dist = raft_common.device_ndarray.empty((data.shape[0], 1), dtype="float32") - search_params = cagra.SearchParams(itopk_size=self.itopk_size) - cagra.search( - search_params, - self.index, - data, - 1, - neighbors=out_idx, - distances=out_dist, - ) - ret = ( - torch.as_tensor(out_idx, device=device).squeeze(dim=1).view(torch.int32), - torch.as_tensor(out_dist, device=device), - ) - search_time_end = time.time() - self.time_search += search_time_end - search_time_start - return ret diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 836477c2ce3..284567022b5 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -6,19 +6,19 @@ import copy import dataclasses import json -import logging import os import random import time import uuid import warnings -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path from typing import ( TYPE_CHECKING, Any, + Callable, Dict, Iterable, Iterator, @@ -27,6 +27,7 @@ Optional, Sequence, Set, + Tuple, TypedDict, Union, ) @@ -35,6 +36,8 @@ import pyarrow.dataset from pyarrow import RecordBatch, Schema +from lance.log import LOGGER + from .blob import BlobFile from .dependencies import ( _check_for_hugging_face, @@ -43,22 +46,21 @@ ) from .dependencies import numpy as np from .dependencies import pandas as pd -from .fragment import FragmentMetadata, LanceFragment +from .fragment import DataFile, FragmentMetadata, LanceFragment from .lance import ( CleanupStats, + Compaction, + CompactionMetrics, + LanceSchema, + ScanStatistics, _Dataset, _MergeInsertBuilder, - _Operation, - _RewriteGroup, - _RewrittenIndex, _Scanner, _write_dataset, ) -from .lance import CompactionMetrics as CompactionMetrics from .lance import __version__ as __version__ from .lance import _Session as Session -from .optimize import Compaction -from .schema import LanceSchema +from .query import FullTextQuery from .types import _coerce_reader from .udf import BatchUDF, normalize_transform from .udf import BatchUDFCheckpoint as BatchUDFCheckpoint @@ -104,9 +106,35 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): return super(MergeInsertBuilder, self).execute(reader) + def execute_uncommitted( + self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None + ) -> Tuple[Transaction, Dict[str, Any]]: + """Executes the merge insert operation without committing + + This function updates the original dataset and returns a dictionary with + information about merge statistics - i.e. the number of inserted, updated, + and deleted rows. + + Parameters + ---------- + + data_obj: ReaderLike + The new data to use as the source table for the operation. This parameter + can be any source of data (e.g. table / dataset) that + :func:`~lance.write_dataset` accepts. + schema: Optional[pa.Schema] + The schema of the data. This only needs to be supplied whenever the data + source is some kind of generator. + """ + reader = _coerce_reader(data_obj, schema) + + return super(MergeInsertBuilder, self).execute_uncommitted(reader) + # These next three overrides exist only to document the methods - def when_matched_update_all(self, condition: Optional[str] = None): + def when_matched_update_all( + self, condition: Optional[str] = None + ) -> "MergeInsertBuilder": """ Configure the operation to update matched rows @@ -127,7 +155,7 @@ def when_matched_update_all(self, condition: Optional[str] = None): """ return super(MergeInsertBuilder, self).when_matched_update_all(condition) - def when_not_matched_insert_all(self): + def when_not_matched_insert_all(self) -> "MergeInsertBuilder": """ Configure the operation to insert not matched rows @@ -137,7 +165,9 @@ def when_not_matched_insert_all(self): """ return super(MergeInsertBuilder, self).when_not_matched_insert_all() - def when_not_matched_by_source_delete(self, expr: Optional[str] = None): + def when_not_matched_by_source_delete( + self, expr: Optional[str] = None + ) -> "MergeInsertBuilder": """ Configure the operation to delete source rows that do not match @@ -151,7 +181,7 @@ def when_not_matched_by_source_delete(self, expr: Optional[str] = None): class LanceDataset(pa.dataset.Dataset): - """A dataset in Lance format where the data is stored at the given uri.""" + """A Lance Dataset in Lance format where the data is stored at the given uri.""" def __init__( self, @@ -167,6 +197,7 @@ def __init__( ): uri = os.fspath(uri) if isinstance(uri, Path) else uri self._uri = uri + self._storage_options = storage_options self._ds = _Dataset( uri, version, @@ -183,6 +214,7 @@ def __init__( def __deserialize__( cls, uri: str, + storage_options: Optional[Dict[str, str]], version: int, manifest: bytes, default_scan_options: Optional[Dict[str, Any]], @@ -190,6 +222,7 @@ def __deserialize__( return cls( uri, version, + storage_options=storage_options, serialized_manifest=manifest, default_scan_options=default_scan_options, ) @@ -197,6 +230,7 @@ def __deserialize__( def __reduce__(self): return type(self).__deserialize__, ( self.uri, + self._storage_options, self._ds.version(), self._ds.serialized_manifest(), self._default_scan_options, @@ -205,16 +239,24 @@ def __reduce__(self): def __getstate__(self): return ( self.uri, + self._storage_options, self._ds.version(), self._ds.serialized_manifest(), self._default_scan_options, ) def __setstate__(self, state): - self._uri, version, manifest, default_scan_options = state + ( + self._uri, + self._storage_options, + version, + manifest, + default_scan_options, + ) = state self._ds = _Dataset( self._uri, version, + storage_options=self._storage_options, manifest=manifest, default_scan_options=default_scan_options, ) @@ -222,6 +264,7 @@ def __setstate__(self, state): def __copy__(self): ds = LanceDataset.__new__(LanceDataset) ds._uri = self._uri + ds._storage_options = self._storage_options ds._ds = copy.copy(self._ds) ds._default_scan_options = self._default_scan_options return ds @@ -238,9 +281,33 @@ def uri(self) -> str: @property def tags(self) -> Tags: + """Tag management for the dataset. + + Similar to Git, tags are a way to add metadata to a specific version of the + dataset. + + .. warning:: + + Tagged versions are exempted from the :py:meth:`cleanup_old_versions()` + process. + + To remove a version that has been tagged, you must first + :py:meth:`~Tags.delete` the associated tag. + + Examples + -------- + + .. code-block:: python + + ds = lance.open("dataset.lance") + ds.tags.create("v2-prod-20250203", 10) + + tags = ds.tags.list() + + """ return Tags(self._ds) - def list_indices(self) -> List[Dict[str, Any]]: + def list_indices(self) -> List[Index]: return self._ds.load_indices() def index_statistics(self, index_name: str) -> Dict[str, Any]: @@ -270,18 +337,20 @@ def scanner( batch_size: Optional[int] = None, batch_readahead: Optional[int] = None, fragment_readahead: Optional[int] = None, - scan_in_order: bool = None, + scan_in_order: Optional[bool] = None, fragments: Optional[Iterable[LanceFragment]] = None, - full_text_query: Optional[Union[str, dict]] = None, + full_text_query: Optional[Union[str, dict, FullTextQuery]] = None, *, - prefilter: bool = None, - with_row_id: bool = None, - with_row_address: bool = None, - use_stats: bool = None, - fast_search: bool = None, + prefilter: Optional[bool] = None, + with_row_id: Optional[bool] = None, + with_row_address: Optional[bool] = None, + use_stats: Optional[bool] = None, + fast_search: Optional[bool] = None, io_buffer_size: Optional[int] = None, late_materialization: Optional[bool | List[str]] = None, use_scalar_index: Optional[bool] = None, + include_deleted_rows: Optional[bool] = None, + scan_stats_callback: Optional[Callable[[ScanStatistics], None]] = None, ) -> LanceScanner: """Return a Scanner that can support various pushdowns. @@ -293,7 +362,7 @@ def scanner( All columns are fetched if None or unspecified. filter: pa.compute.Expression or str Expression or str that is a valid SQL where clause. See - `Lance filter pushdown `_ + `Lance filter pushdown `_ for valid SQL expressions. limit: int, default None Fetch up to this many rows. All rows if None or unspecified. @@ -311,8 +380,11 @@ def scanner( "nprobes": 1, "refine_factor": 1 } + batch_size: int, default None - The max size of batches returned. + The target size of batches returned. In some cases batches can be up to + twice this size (but never larger than this). In some cases batches can + be smaller than this size. io_buffer_size: int, default None The size of the IO buffer. See ``ScannerBuilder.io_buffer_size`` for more information. @@ -352,7 +424,7 @@ def scanner( If True, then all columns are late materialized. If False, then all columns are early materialized. If a list of strings, then only the columns in the list are - late materialized. + late materialized. The default uses a heuristic that assumes filters will select about 0.1% of the rows. If your filter is more selective (e.g. find by id) you may @@ -362,6 +434,7 @@ def scanner( query string to search for, the results will be ranked by BM25. e.g. "hello world", would match documents containing "hello" or "world". or a dictionary with the following keys: + - columns: list[str] The columns to search, currently only supports a single column in the columns list. @@ -370,18 +443,33 @@ def scanner( fast_search: bool, default False If True, then the search will only be performed on the indexed data, which yields faster search time. + scan_stats_callback: Callable[[ScanStatistics], None], default None + A callback function that will be called with the scan statistics after the + scan is complete. Errors raised by the callback will be logged but not + re-raised. + include_deleted_rows: bool, default False + If True, then rows that have been deleted, but are still present in the + fragment, will be returned. These rows will have the _rowid column set + to null. All other columns will reflect the value stored on disk and may + not be null. - Notes - ----- + Note: if this is a search operation, or a take operation (including scalar + indexed scans) then deleted rows cannot be returned. + + + .. note:: + + For now, if BOTH filter and nearest is specified, then: + + 1. nearest is executed first. + 2. The results are filtered afterwards. - For now, if BOTH filter and nearest is specified, then: - 1. nearest is executed first. - 2. The results are filtered afterwards. For debugging ANN results, you can choose to not use the index even if present by specifying ``use_index=False``. For example, the following will always return exact KNN results: + .. code-block:: python dataset.to_table(nearest={ @@ -418,7 +506,8 @@ def setopt(opt, val): setopt(builder.use_stats, use_stats) setopt(builder.use_scalar_index, use_scalar_index) setopt(builder.fast_search, fast_search) - + setopt(builder.include_deleted_rows, include_deleted_rows) + setopt(builder.scan_stats_callback, scan_stats_callback) # columns=None has a special meaning. we can't treat it as "user didn't specify" if self._default_scan_options is None: # No defaults, use user-provided, if any @@ -438,9 +527,9 @@ def setopt(opt, val): builder = builder.columns(default_columns) if full_text_query is not None: - if isinstance(full_text_query, str): + if isinstance(full_text_query, (str, FullTextQuery)): builder = builder.full_text_search(full_text_query) - else: + elif isinstance(full_text_query, dict): builder = builder.full_text_search(**full_text_query) if nearest is not None: builder = builder.nearest(**nearest) @@ -470,6 +559,13 @@ def data_storage_version(self) -> str: """ return self._ds.data_storage_version + @property + def max_field_id(self) -> int: + """ + The max_field_id in manifest + """ + return self._ds.max_field_id + def to_table( self, columns: Optional[Union[List[str], Dict[str, str]]] = None, @@ -480,19 +576,20 @@ def to_table( batch_size: Optional[int] = None, batch_readahead: Optional[int] = None, fragment_readahead: Optional[int] = None, - scan_in_order: bool = True, + scan_in_order: Optional[bool] = None, *, - prefilter: bool = False, - with_row_id: bool = False, - with_row_address: bool = False, - use_stats: bool = True, - fast_search: bool = False, - full_text_query: Optional[Union[str, dict]] = None, + prefilter: Optional[bool] = None, + with_row_id: Optional[bool] = None, + with_row_address: Optional[bool] = None, + use_stats: Optional[bool] = None, + fast_search: Optional[bool] = None, + full_text_query: Optional[Union[str, dict, FullTextQuery]] = None, io_buffer_size: Optional[int] = None, late_materialization: Optional[bool | List[str]] = None, use_scalar_index: Optional[bool] = None, + include_deleted_rows: Optional[bool] = None, ) -> pa.Table: - """Read the data into memory as a pyarrow Table. + """Read the data into memory as a :py:class:`pyarrow.Table` Parameters ---------- @@ -502,7 +599,7 @@ def to_table( All columns are fetched if None or unspecified. filter : pa.compute.Expression or str Expression or str that is a valid SQL where clause. See - `Lance filter pushdown `_ + `Lance filter pushdown `_ for valid SQL expressions. limit: int, default None Fetch up to this many rows. All rows if None or unspecified. @@ -531,11 +628,11 @@ def to_table( The number of batches to read ahead. fragment_readahead: int, optional The number of fragments to read ahead. - scan_in_order: bool, default True + scan_in_order: bool, optional, default True Whether to read the fragments and batches in order. If false, throughput may be higher, but batches will be returned out of order and memory use might increase. - prefilter: bool, default False + prefilter: bool, optional, default False Run filter before the vector search. late_materialization: bool or List[str], default None Allows custom control over late materialization. See @@ -543,25 +640,36 @@ def to_table( use_scalar_index: bool, default True Allows custom control over scalar index usage. See ``ScannerBuilder.use_scalar_index`` for more information. - with_row_id: bool, default False + with_row_id: bool, optional, default False Return row ID. - with_row_address: bool, default False + with_row_address: bool, optional, default False Return row address - use_stats: bool, default True + use_stats: bool, optional, default True Use stats pushdown during filters. + fast_search: bool, optional, default False full_text_query: str or dict, optional query string to search for, the results will be ranked by BM25. e.g. "hello world", would match documents contains "hello" or "world". or a dictionary with the following keys: + - columns: list[str] The columns to search, currently only supports a single column in the columns list. - query: str The query string to search for. + include_deleted_rows: bool, optional, default False + If True, then rows that have been deleted, but are still present in the + fragment, will be returned. These rows will have the _rowid column set + to null. All other columns will reflect the value stored on disk and may + not be null. + + Note: if this is a search operation, or a take operation (including scalar + indexed scans) then deleted rows cannot be returned. Notes ----- If BOTH filter and nearest is specified, then: + 1. nearest is executed first. 2. The results are filtered afterward, unless pre-filter sets to True. """ @@ -584,6 +692,7 @@ def to_table( use_stats=use_stats, fast_search=fast_search, full_text_query=full_text_query, + include_deleted_rows=include_deleted_rows, ).to_table() @property @@ -596,8 +705,38 @@ def partition_expression(self): def replace_schema(self, schema: Schema): """ Not implemented (just override pyarrow dataset to prevent segfault) + + See :py:method:`replace_schema_metadata` or :py:method:`replace_field_metadata` + """ + raise NotImplementedError( + "Cannot replace the schema of a dataset. This method exists for backwards" + " compatibility with pyarrow. Use replace_schema_metadata or " + "replace_field_metadata to change the metadata" + ) + + def replace_schema_metadata(self, new_metadata: Dict[str, str]): + """ + Replace the schema metadata of the dataset + + Parameters + ---------- + new_metadata: dict + The new metadata to set + """ + self._ds.replace_schema_metadata(new_metadata) + + def replace_field_metadata(self, field_name: str, new_metadata: Dict[str, str]): """ - raise NotImplementedError("not changing schemas yet") + Replace the metadata of a field in the schema + + Parameters + ---------- + field_name: str + The name of the field to replace the metadata for + new_metadata: dict + The new metadata to set + """ + self._ds.replace_field_metadata(field_name, new_metadata) def get_fragments(self, filter: Optional[Expression] = None) -> List[LanceFragment]: """Get all fragments from the dataset. @@ -628,12 +767,12 @@ def to_batches( batch_size: Optional[int] = None, batch_readahead: Optional[int] = None, fragment_readahead: Optional[int] = None, - scan_in_order: bool = True, + scan_in_order: Optional[bool] = None, *, - prefilter: bool = False, - with_row_id: bool = False, - with_row_address: bool = False, - use_stats: bool = True, + prefilter: Optional[bool] = None, + with_row_id: Optional[bool] = None, + with_row_address: Optional[bool] = None, + use_stats: Optional[bool] = None, full_text_query: Optional[Union[str, dict]] = None, io_buffer_size: Optional[int] = None, late_materialization: Optional[bool | List[str]] = None, @@ -645,11 +784,11 @@ def to_batches( Parameters ---------- **kwargs : dict, optional - Arguments for ``Scanner.from_dataset``. + Arguments for :py:meth:`~LanceDataset.scanner`. Returns ------- - record_batches : Iterator of RecordBatch + record_batches : Iterator of :py:class:`~pyarrow.RecordBatch` """ return self.scanner( columns=columns, @@ -707,7 +846,6 @@ def take( self, indices: Union[List[int], pa.Array], columns: Optional[Union[List[str], Dict[str, str]]] = None, - **kwargs, ) -> pa.Table: """Select rows of data by index. @@ -719,12 +857,10 @@ def take( List of column names to be fetched. Or a dictionary of column names to SQL expressions. All columns are fetched if None or unspecified. - **kwargs : dict, optional - See scanner() method for full parameter description. Returns ------- - table : Table + table : pyarrow.Table """ columns_with_transform = None if isinstance(columns, dict): @@ -773,7 +909,11 @@ def take_blobs( blob_column: str, ) -> List[BlobFile]: """ - Select blobs by row_ids. + Select blobs by row IDs. + + Instead of loading large binary blob data into memory before processing it, + this API allows you to open binary blob data as a regular Python file-like + object. For more details, see :py:class:`lance.BlobFile`. Parameters ---------- @@ -825,7 +965,9 @@ def count_rows( """ if isinstance(filter, pa.compute.Expression): # TODO: consolidate all to use scanner - return self.scanner(filter=filter).count_rows() + return self.scanner( + columns=[], with_row_id=True, filter=filter + ).count_rows() return self._ds.count_rows(filter) @@ -845,7 +987,7 @@ def join( """ raise NotImplementedError("Versioning not yet supported in Rust") - def alter_columns(self, *alterations: Iterable[Dict[str, Any]]): + def alter_columns(self, *alterations: Iterable[AlterColumn]): """Alter column name, data type, and nullability. Columns that are renamed can keep any indices that are on them. If a @@ -967,7 +1109,12 @@ def merge( def add_columns( self, - transforms: Dict[str, str] | BatchUDF | ReaderLike, + transforms: Dict[str, str] + | BatchUDF + | ReaderLike + | pyarrow.Field + | List[pyarrow.Field] + | pyarrow.Schema, read_columns: List[str] | None = None, reader_schema: Optional[pa.Schema] = None, batch_size: Optional[int] = None, @@ -995,6 +1142,8 @@ def add_columns( reference existing columns in the dataset. If this is a AddColumnsUDF, then it is a UDF that takes a batch of existing data and returns a new batch with the new columns. + If this is :class:`pyarrow.Field` or :class:`pyarrow.Schema`, it adds + all NULL columns with the given schema, in a metadata-only operation. read_columns : list of str, optional The names of the columns that the UDF will read. If None, then the UDF will read all columns. This is only used when transforms is a @@ -1034,6 +1183,18 @@ def add_columns( LanceDataset.merge : Merge a pre-computed set of columns into the dataset. """ + if isinstance(transforms, pa.Field): + transforms = [transforms] + if ( + isinstance(transforms, list) + and len(transforms) > 0 + and isinstance(transforms[0], pa.Field) + ): + transforms = pa.schema(transforms) + if isinstance(transforms, pa.Schema): + self._ds.add_columns_with_schema(transforms) + return + transforms = normalize_transform(transforms, self, read_columns, reader_schema) if isinstance(transforms, pa.RecordBatchReader): self._ds.add_columns_from_reader(transforms, batch_size) @@ -1177,6 +1338,11 @@ def merge_insert( Examples -------- + + Use `when_matched_update_all()` and `when_not_matched_insert_all()` to + perform an "upsert" operation. This will update rows that already exist + in the dataset and insert rows that do not exist. + >>> import lance >>> import pyarrow as pa >>> table = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]}) @@ -1194,6 +1360,51 @@ def merge_insert( 1 2 x 2 3 y 3 4 z + + Use `when_not_matched_insert_all()` to perform an "insert if not exists" + operation. This will only insert rows that do not already exist in the + dataset. + + >>> import lance + >>> import pyarrow as pa + >>> table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + >>> dataset = lance.write_dataset(table, "example2") + >>> new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + >>> # Perform an "insert if not exists" operation + >>> dataset.merge_insert("a") \\ + ... .when_not_matched_insert_all() \\ + ... .execute(new_table) + {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0} + >>> dataset.to_table().sort_by("a").to_pandas() + a b + 0 1 a + 1 2 b + 2 3 c + 3 4 z + + You are not required to provide all the columns. If you only want to + update a subset of columns, you can omit columns you don't want to + update. Omitted columns will keep their existing values if they are + updated, or will be null if they are inserted. + + >>> import lance + >>> import pyarrow as pa + >>> table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"], \\ + ... "c": ["x", "y", "z"]}) + >>> dataset = lance.write_dataset(table, "example3") + >>> new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + >>> # Perform an "upsert" operation, only updating column "a" + >>> dataset.merge_insert("a") \\ + ... .when_matched_update_all() \\ + ... .when_not_matched_insert_all() \\ + ... .execute(new_table) + {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} + >>> dataset.to_table().sort_by("a").to_pandas() + a b c + 0 1 a x + 1 2 x y + 2 3 y z + 3 4 z None """ return MergeInsertBuilder(self._ds, on) @@ -1201,7 +1412,7 @@ def update( self, updates: Dict[str, str], where: Optional[str] = None, - ) -> Dict[str, Any]: + ) -> UpdateResult: """ Update column values for rows matching where. @@ -1353,6 +1564,7 @@ def create_scalar_index( Literal["LABEL_LIST"], Literal["INVERTED"], Literal["FTS"], + Literal["NGRAM"], ], name: Optional[str] = None, *, @@ -1391,7 +1603,7 @@ def create_scalar_index( ) - There are 4 types of scalar indices available today. + There are 5 types of scalar indices available today. * ``BTREE``. The most common type is ``BTREE``. This index is inspired by the btree data structure although only the first few layers of the btree @@ -1406,6 +1618,10 @@ def create_scalar_index( contains lists of tags (e.g. ``["tag1", "tag2", "tag3"]``) can be indexed with a ``LABEL_LIST`` index. This index can only speedup queries with ``array_has_any`` or ``array_has_all`` filters. + * ``NGRAM``. A special index that is used to index string columns. This index + creates a bitmap for each ngram in the string. By default we use trigrams. + This index can currently speed up queries using the ``contains`` function + in filters. * ``FTS/INVERTED``. It is used to index document columns. This index can conduct full-text searches. For example, a column that contains any word of query string "hello world". The results will be ranked by BM25. @@ -1423,15 +1639,13 @@ def create_scalar_index( or string column. index_type : str The type of the index. One of ``"BTREE"``, ``"BITMAP"``, - ``"LABEL_LIST"``, "FTS" or ``"INVERTED"``. + ``"LABEL_LIST"``, ``"NGRAM"``, ``"FTS"`` or ``"INVERTED"``. name : str, optional The index name. If not provided, it will be generated from the column name. replace : bool, default True Replace the existing index if it exists. - Optional Parameters - ------------------- with_position: bool, default True This is for the ``INVERTED`` index. If True, the index will store the positions of the words in the document, so that you can conduct phrase @@ -1510,40 +1724,53 @@ def create_scalar_index( raise KeyError(f"{column} not found in schema") index_type = index_type.upper() - if index_type not in ["BTREE", "BITMAP", "LABEL_LIST", "INVERTED"]: + if index_type not in ["BTREE", "BITMAP", "NGRAM", "LABEL_LIST", "INVERTED"]: raise NotImplementedError( ( - 'Only "BTREE", "LABEL_LIST", "INVERTED", ' + 'Only "BTREE", "LABEL_LIST", "INVERTED", "NGRAM", ' 'or "BITMAP" are supported for ' f"scalar columns. Received {index_type}", ) ) field = self.schema.field(column) + + field_type = field.type + if hasattr(field_type, "storage_type"): + field_type = field_type.storage_type + if index_type in ["BTREE", "BITMAP"]: if ( - not pa.types.is_integer(field.type) - and not pa.types.is_floating(field.type) - and not pa.types.is_boolean(field.type) - and not pa.types.is_string(field.type) - and not pa.types.is_temporal(field.type) + not pa.types.is_integer(field_type) + and not pa.types.is_floating(field_type) + and not pa.types.is_boolean(field_type) + and not pa.types.is_string(field_type) + and not pa.types.is_temporal(field_type) + and not pa.types.is_fixed_size_binary(field_type) ): raise TypeError( f"BTREE/BITMAP index column {column} must be int", - ", float, bool, str, or temporal", + ", float, bool, str, fixed-size-binary, or temporal ", ) elif index_type == "LABEL_LIST": - if not pa.types.is_list(field.type): + if not pa.types.is_list(field_type): raise TypeError(f"LABEL_LIST index column {column} must be a list") + elif index_type == "NGRAM": + if not pa.types.is_string(field_type): + raise TypeError(f"NGRAM index column {column} must be a string") elif index_type in ["INVERTED", "FTS"]: - if not pa.types.is_string(field.type) and not pa.types.is_large_string( - field.type + value_type = field_type + if pa.types.is_list(field_type) or pa.types.is_large_list(field_type): + value_type = field_type.value_type + if not pa.types.is_string(value_type) and not pa.types.is_large_string( + value_type ): raise TypeError( - f"INVERTED index column {column} must be string or large string" + f"INVERTED index column {column} must be string, large string" + " or list of strings, but got {value_type}" ) - if pa.types.is_duration(field.type): + if pa.types.is_duration(field_type): raise TypeError( f"Scalar index column {column} cannot currently be a duration" ) @@ -1598,15 +1825,19 @@ def create_index( Replace the existing index if it exists. num_partitions : int, optional The number of partitions of IVF (Inverted File Index). - ivf_centroids : ``np.ndarray``, ``pyarrow.FixedSizeListArray`` - or ``pyarrow.FixedShapeTensorArray``. Optional. - A ``num_partitions x dimension`` array of K-mean centroids for IVF - clustering. If not provided, a new Kmean model will be trained. - pq_codebook : ``np.ndarray``, ``pyarrow.FixedSizeListArray`` - or ``pyarrow.FixedShapeTensorArray``. Optional. + ivf_centroids : optional + It can be either :py:class:`np.ndarray`, + :py:class:`pyarrow.FixedSizeListArray` or + :py:class:`pyarrow.FixedShapeTensorArray`. + A ``num_partitions x dimension`` array of existing K-mean centroids + for IVF clustering. If not provided, a new KMeans model will be trained. + pq_codebook : optional, + It can be :py:class:`np.ndarray`, :py:class:`pyarrow.FixedSizeListArray`, + or :py:class:`pyarrow.FixedShapeTensorArray`. A ``num_sub_vectors x (2 ^ nbits * dimensions // num_sub_vectors)`` array of K-mean centroids for PQ codebook. - Note: nbits is always 8 for now. + + Note: ``nbits`` is always 8 for now. If not provided, a new PQ model will be trained. num_sub_vectors : int, optional The number of sub-vectors for PQ (Product Quantization). @@ -1640,7 +1871,9 @@ def create_index( kwargs : Parameters passed to the index building process. - The SQ (Scalar Quantization) is available for only "IVF_HNSW_SQ" index type, + + + The SQ (Scalar Quantization) is available for only ``IVF_HNSW_SQ`` index type, this quantization method is used to reduce the memory usage of the index, it maps the float vectors to integer vectors, each integer is of ``num_bits``, now only 8 bits are supported. @@ -1651,17 +1884,21 @@ def create_index( If ``index_type`` is with "PQ", then the following parameters are required: num_sub_vectors - Optional parameters for "IVF_PQ": - ivf_centroids : - K-mean centroids for IVF clustering. + Optional parameters for `IVF_PQ`: + + - ivf_centroids + Existing K-mean centroids for IVF clustering. + - num_bits + The number of bits for PQ (Product Quantization). Default is 8. + Only 4, 8 are supported. - Optional parameters for "IVF_HNSW_*": - max_level : int - the maximum number of levels in the graph. - m : int - the number of edges per node in the graph. - ef_construction : int - the number of nodes to examine during the construction. + Optional parameters for `IVF_HNSW_*`: + max_level + Int, the maximum number of levels in the graph. + m + Int, the number of edges per node in the graph. + ef_construction + Int, the number of nodes to examine during the construction. Examples -------- @@ -1727,8 +1964,14 @@ def create_index( if c not in self.schema.names: raise KeyError(f"{c} not found in schema") field = self.schema.field(c) + is_multivec = False if pa.types.is_fixed_size_list(field.type): dimension = field.type.list_size + elif pa.types.is_list(field.type) and pa.types.is_fixed_size_list( + field.type.value_type + ): + dimension = field.type.value_type.list_size + is_multivec = True elif ( isinstance(field.type, pa.FixedShapeTensorType) and len(field.type.shape) == 1 @@ -1746,7 +1989,12 @@ def create_index( f" ({num_sub_vectors})" ) - if not pa.types.is_floating(field.type.value_type): + element_type = field.type.value_type + if is_multivec: + element_type = field.type.value_type.value_type + if not ( + pa.types.is_floating(element_type) or pa.types.is_uint8(element_type) + ): raise TypeError( f"Vector column {c} must have floating value type, " f"got {field.type.value_type}" @@ -1757,16 +2005,17 @@ def create_index( "cosine", "euclidean", "dot", + "hamming", ]: raise ValueError(f"Metric {metric} not supported.") kwargs["metric_type"] = metric index_type = index_type.upper() - valid_index_types = ["IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"] + valid_index_types = ["IVF_FLAT", "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"] if index_type not in valid_index_types: raise NotImplementedError( - f"Only {valid_index_types} index types supported. " f"Got {index_type}" + f"Only {valid_index_types} index types supported. Got {index_type}" ) if index_type != "IVF_PQ" and one_pass_ivfpq: raise ValueError( @@ -1781,7 +2030,7 @@ def create_index( one_pass_train_ivf_pq_on_accelerator, ) - logging.info("Doing one-pass ivfpq accelerated computations") + LOGGER.info("Doing one-pass ivfpq accelerated computations") timers["ivf+pq_train:start"] = time.time() ( @@ -1801,7 +2050,7 @@ def create_index( ) timers["ivf+pq_train:end"] = time.time() ivfpq_train_time = timers["ivf+pq_train:end"] - timers["ivf+pq_train:start"] - logging.info("ivf+pq training time: %ss", ivfpq_train_time) + LOGGER.info("ivf+pq training time: %ss", ivfpq_train_time) timers["ivf+pq_assign:start"] = time.time() shuffle_output_dir, shuffle_buffers = one_pass_assign_ivf_pq_on_accelerator( self, @@ -1817,7 +2066,7 @@ def create_index( ivfpq_assign_time = ( timers["ivf+pq_assign:end"] - timers["ivf+pq_assign:start"] ) - logging.info("ivf+pq transform time: %ss", ivfpq_assign_time) + LOGGER.info("ivf+pq transform time: %ss", ivfpq_assign_time) kwargs["precomputed_shuffle_buffers"] = shuffle_buffers kwargs["precomputed_shuffle_buffers_path"] = os.path.join( @@ -1857,7 +2106,7 @@ def create_index( " precomputed_partition_dataset is provided" ) if precomputed_partition_dataset is not None: - logging.info("Using provided precomputed partition dataset") + LOGGER.info("Using provided precomputed partition dataset") precomputed_ds = LanceDataset( precomputed_partition_dataset, storage_options=storage_options ) @@ -1882,7 +2131,7 @@ def create_index( kwargs["precomputed_partitions_file"] = precomputed_partition_dataset if accelerator is not None and ivf_centroids is None and not one_pass_ivfpq: - logging.info("Computing new precomputed partition dataset") + LOGGER.info("Computing new precomputed partition dataset") # Use accelerator to train ivf centroids from .vector import ( compute_partitions, @@ -1900,7 +2149,7 @@ def create_index( ) timers["ivf_train:end"] = time.time() ivf_train_time = timers["ivf_train:end"] - timers["ivf_train:start"] - logging.info("ivf training time: %ss", ivf_train_time) + LOGGER.info("ivf training time: %ss", ivf_train_time) timers["ivf_assign:start"] = time.time() num_sub_vectors_cur = None if "PQ" in index_type and pq_codebook is None: @@ -1916,7 +2165,7 @@ def create_index( ) timers["ivf_assign:end"] = time.time() ivf_assign_time = timers["ivf_assign:end"] - timers["ivf_assign:start"] - logging.info("ivf transform time: %ss", ivf_assign_time) + LOGGER.info("ivf transform time: %ss", ivf_assign_time) kwargs["precomputed_partitions_file"] = partitions_file if (ivf_centroids is None) and (pq_codebook is not None): @@ -1964,7 +2213,7 @@ def create_index( and "precomputed_partitions_file" in kwargs and not one_pass_ivfpq ): - logging.info("Computing new precomputed shuffle buffers for PQ.") + LOGGER.info("Computing new precomputed shuffle buffers for PQ.") partitions_file = kwargs["precomputed_partitions_file"] del kwargs["precomputed_partitions_file"] @@ -1981,10 +2230,11 @@ def create_index( metric, accelerator=accelerator, num_sub_vectors=num_sub_vectors, + dtype=element_type.to_pandas_dtype(), ) timers["pq_train:end"] = time.time() pq_train_time = timers["pq_train:end"] - timers["pq_train:start"] - logging.info("pq training time: %ss", pq_train_time) + LOGGER.info("pq training time: %ss", pq_train_time) timers["pq_assign:start"] = time.time() shuffle_output_dir, shuffle_buffers = compute_pq_codes( partitions_ds, @@ -1993,12 +2243,12 @@ def create_index( ) timers["pq_assign:end"] = time.time() pq_assign_time = timers["pq_assign:end"] - timers["pq_assign:start"] - logging.info("pq transform time: %ss", pq_assign_time) + LOGGER.info("pq transform time: %ss", pq_assign_time) # Save disk space if precomputed_partition_dataset is not None and os.path.exists( partitions_file ): - logging.info( + LOGGER.info( "Temporary partitions file stored at %s," "you may want to delete it.", partitions_file, @@ -2050,17 +2300,43 @@ def create_index( final_create_index_time = ( timers["final_create_index:end"] - timers["final_create_index:start"] ) - logging.info("Final create_index rust time: %ss", final_create_index_time) + LOGGER.info("Final create_index rust time: %ss", final_create_index_time) # Save disk space if "precomputed_shuffle_buffers_path" in kwargs.keys() and os.path.exists( kwargs["precomputed_shuffle_buffers_path"] ): - logging.info( + LOGGER.info( "Temporary shuffle buffers stored at %s, you may want to delete it.", kwargs["precomputed_shuffle_buffers_path"], ) return self + def drop_index(self, name: str): + """ + Drops an index from the dataset + + Note: Indices are dropped by "index name". This is not the same as the field + name. If you did not specify a name when you created the index then a name was + generated for you. You can use the `list_indices` method to get the names of + the indices. + """ + return self._ds.drop_index(name) + + def prewarm_index(self, name: str): + """ + Prewarm an index + + This will load the entire index into memory. This can help avoid cold start + issues with index queries. If the index does not fit in the index cache, then + this will result in wasted I/O. + + Parameters + ---------- + name: str + The name of the index to prewarm. + """ + return self._ds.prewarm_index(name) + def session(self) -> Session: """ Return the dataset session, which holds the dataset's state. @@ -2075,8 +2351,7 @@ def _commit( commit_lock: Optional[CommitLock] = None, ) -> LanceDataset: warnings.warn( - "LanceDataset._commit() is deprecated, use LanceDataset.commit()" - " instead", + "LanceDataset._commit() is deprecated, use LanceDataset.commit() instead", DeprecationWarning, ) return LanceDataset.commit(base_uri, operation, read_version, commit_lock) @@ -2084,7 +2359,8 @@ def _commit( @staticmethod def commit( base_uri: Union[str, Path, LanceDataset], - operation: LanceOperation.BaseOperation, + operation: Union[LanceOperation.BaseOperation, Transaction], + blobs_op: Optional[LanceOperation.BaseOperation] = None, read_version: Optional[int] = None, commit_lock: Optional[CommitLock] = None, storage_options: Optional[Dict[str, str]] = None, @@ -2189,25 +2465,47 @@ def commit( f"commit_lock must be a function, got {type(commit_lock)}" ) - if read_version is None and not isinstance( - operation, (LanceOperation.Overwrite, LanceOperation.Restore) + if ( + isinstance(operation, LanceOperation.BaseOperation) + and read_version is None + and not isinstance( + operation, (LanceOperation.Overwrite, LanceOperation.Restore) + ) ): raise ValueError( "read_version is required for all operations except " "Overwrite and Restore" ) + if isinstance(operation, Transaction): + new_ds = _Dataset.commit_transaction( + base_uri, + operation, + commit_lock, + storage_options=storage_options, + enable_v2_manifest_paths=enable_v2_manifest_paths, + detached=detached, + max_retries=max_retries, + ) + elif isinstance(operation, LanceOperation.BaseOperation): + new_ds = _Dataset.commit( + base_uri, + operation, + blobs_op, + read_version, + commit_lock, + storage_options=storage_options, + enable_v2_manifest_paths=enable_v2_manifest_paths, + detached=detached, + max_retries=max_retries, + ) + else: + raise TypeError( + "operation must be a LanceOperation.BaseOperation or Transaction, " + f"got {type(operation)}" + ) - new_ds = _Dataset.commit( - base_uri, - operation._to_inner(), - read_version, - commit_lock, - storage_options=storage_options, - enable_v2_manifest_paths=enable_v2_manifest_paths, - detached=detached, - max_retries=max_retries, - ) ds = LanceDataset.__new__(LanceDataset) + ds._storage_options = storage_options ds._ds = new_ds ds._uri = new_ds.uri ds._default_scan_options = None @@ -2297,24 +2595,11 @@ def commit_batch( detached=detached, max_retries=max_retries, ) - merged = Transaction(**merged) - # This logic is specific to append, which is all that should - # be returned here. - # TODO: generalize this to all other transaction types. - merged.operation["fragments"] = [ - FragmentMetadata.from_metadata(f) for f in merged.operation["fragments"] - ] - merged.operation = LanceOperation.Append(**merged.operation) - if merged.blobs_op: - merged.blobs_op["fragments"] = [ - FragmentMetadata.from_metadata(f) for f in merged.blobs_op["fragments"] - ] - merged.blobs_op = LanceOperation.Append(**merged.blobs_op) ds = LanceDataset.__new__(LanceDataset) ds._ds = new_ds ds._uri = new_ds.uri ds._default_scan_options = None - return dict( + return BulkCommitResult( dataset=ds, merged=merged, ) @@ -2353,6 +2638,12 @@ def stats(self) -> "LanceStats": """ return LanceStats(self._ds) + @staticmethod + def drop( + base_uri: Union[str, Path], storage_options: Optional[Dict[str, str]] = None + ) -> None: + _Dataset.drop(str(base_uri), storage_options) + class BulkCommitResult(TypedDict): dataset: LanceDataset @@ -2367,6 +2658,43 @@ class Transaction: blobs_op: Optional[LanceOperation.BaseOperation] = None +class Tag(TypedDict): + version: int + manifest_size: int + + +class Version(TypedDict): + version: int + timestamp: int | datetime + metadata: Dict[str, str] + + +class UpdateResult(TypedDict): + num_rows_updated: int + + +class AlterColumn(TypedDict): + path: str + name: Optional[str] + nullable: Optional[bool] + data_type: Optional[pa.DataType] + + +class ExecuteResult(TypedDict): + num_inserted_rows: int + num_updated_rows: int + num_deleted_rows: int + + +class Index(TypedDict): + name: str + type: str + uuid: str + fields: List[str] + version: int + fragment_ids: Set[int] + + # LanceOperation is a namespace for operations that can be applied to a dataset. class LanceOperation: @staticmethod @@ -2389,10 +2717,6 @@ class BaseOperation(ABC): See available operations under :class:`LanceOperation`. """ - @abstractmethod - def _to_inner(self): - raise NotImplementedError() - @dataclass class Overwrite(BaseOperation): """ @@ -2436,20 +2760,14 @@ class Overwrite(BaseOperation): 3 4 d """ - new_schema: pa.Schema + new_schema: LanceSchema | pa.Schema fragments: Iterable[FragmentMetadata] def __post_init__(self): - if not isinstance(self.new_schema, pa.Schema): - raise TypeError( - f"schema must be pyarrow.Schema, got {type(self.new_schema)}" - ) + if isinstance(self.new_schema, pa.Schema): + self.new_schema = LanceSchema.from_pyarrow(self.new_schema) LanceOperation._validate_fragments(self.fragments) - def _to_inner(self): - raw_fragments = [f._metadata for f in self.fragments] - return _Operation.overwrite(self.new_schema, raw_fragments) - @dataclass class Append(BaseOperation): """ @@ -2496,10 +2814,6 @@ class Append(BaseOperation): def __post_init__(self): LanceOperation._validate_fragments(self.fragments) - def _to_inner(self): - raw_fragments = [f._metadata for f in self.fragments] - return _Operation.append(raw_fragments) - @dataclass class Delete(BaseOperation): """ @@ -2568,11 +2882,28 @@ class Delete(BaseOperation): def __post_init__(self): LanceOperation._validate_fragments(self.updated_fragments) - def _to_inner(self): - raw_updated_fragments = [f._metadata for f in self.updated_fragments] - return _Operation.delete( - raw_updated_fragments, self.deleted_fragment_ids, self.predicate - ) + @dataclass + class Update(BaseOperation): + """ + Operation that updates rows in the dataset. + + Attributes + ---------- + removed_fragment_ids: list[int] + The ids of the fragments that have been removed entirely. + updated_fragments: list[FragmentMetadata] + The fragments that have been updated with new deletion vectors. + new_fragments: list[FragmentMetadata] + The fragments that contain the new rows. + """ + + removed_fragment_ids: List[int] + updated_fragments: List[FragmentMetadata] + new_fragments: List[FragmentMetadata] + + def __post_init__(self): + LanceOperation._validate_fragments(self.updated_fragments) + LanceOperation._validate_fragments(self.new_fragments) @dataclass class Merge(BaseOperation): @@ -2634,10 +2965,6 @@ class Merge(BaseOperation): schema: LanceSchema | pa.Schema def __post_init__(self): - LanceOperation._validate_fragments(self.fragments) - - def _to_inner(self): - raw_fragments = [f._metadata for f in self.fragments] if isinstance(self.schema, pa.Schema): warnings.warn( "Passing a pyarrow.Schema to Merge is deprecated. " @@ -2645,7 +2972,7 @@ def _to_inner(self): DeprecationWarning, ) self.schema = LanceSchema.from_pyarrow(self.schema) - return _Operation.merge(raw_fragments, self.schema) + LanceOperation._validate_fragments(self.fragments) @dataclass class Restore(BaseOperation): @@ -2655,9 +2982,6 @@ class Restore(BaseOperation): version: int - def _to_inner(self): - return _Operation.restore(self.version) - @dataclass class RewriteGroup: """ @@ -2667,11 +2991,6 @@ class RewriteGroup: old_fragments: Iterable[FragmentMetadata] new_fragments: Iterable[FragmentMetadata] - def _to_inner(self): - old_fragments = [f._metadata for f in self.old_fragments] - new_fragments = [f._metadata for f in self.new_fragments] - return _RewriteGroup(old_fragments, new_fragments) - @dataclass class RewrittenIndex: """ @@ -2681,9 +3000,6 @@ class RewrittenIndex: old_id: str new_id: str - def _to_inner(self): - return _RewrittenIndex(self.old_id, self.new_id) - @dataclass class Rewrite(BaseOperation): """ @@ -2710,11 +3026,6 @@ def __post_init__(self): all_frags += [new for group in self.groups for new in group.new_fragments] LanceOperation._validate_fragments(all_frags) - def _to_inner(self): - groups = [group._to_inner() for group in self.groups] - rewritten_indices = [index._to_inner() for index in self.rewritten_indices] - return _Operation.rewrite(groups, rewritten_indices) - @dataclass class CreateIndex(BaseOperation): """ @@ -2727,14 +3038,61 @@ class CreateIndex(BaseOperation): dataset_version: int fragment_ids: Set[int] - def _to_inner(self): - return _Operation.create_index( - self.uuid, - self.name, - self.fields, - self.dataset_version, - self.fragment_ids, - ) + @dataclass + class DataReplacementGroup: + """ + Group of data replacements + """ + + fragment_id: int + new_file: DataFile + + @dataclass + class DataReplacement(BaseOperation): + """ + Operation that replaces existing datafiles in the dataset. + """ + + replacements: List[LanceOperation.DataReplacementGroup] + + @dataclass + class Project(BaseOperation): + """ + Operation that project columns. + Use this operator for drop column or rename/swap column. + + Attributes + ---------- + schema: LanceSchema + The lance schema of the new dataset. + + Examples + -------- + Use the projece operator to swap column: + + >>> import lance + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> from lance.schema import LanceSchema + >>> table = pa.table({"a": [1, 2], "b": ["a", "b"], "b1": ["c", "d"]}) + >>> dataset = lance.write_dataset(table, "example") + >>> dataset.to_table().to_pandas() + a b b1 + 0 1 a c + 1 2 b d + >>> + >>> ## rename column `b` into `b0` and rename b1 into `b` + >>> table = pa.table({"a": [3, 4], "b0": ["a", "b"], "b": ["c", "d"]}) + >>> lance_schema = LanceSchema.from_pyarrow(table.schema) + >>> operation = lance.LanceOperation.Project(lance_schema) + >>> dataset = lance.LanceDataset.commit("example", operation, read_version=1) + >>> dataset.to_table().to_pandas() + a b0 b + 0 1 a c + 1 2 b d + """ + + schema: LanceSchema class ScannerBuilder: @@ -2761,6 +3119,8 @@ def __init__(self, ds: LanceDataset): self._fast_search = False self._full_text_query = None self._use_scalar_index = None + self._include_deleted_rows = None + self._scan_stats_callback: Optional[Callable[[ScanStatistics], None]] = None def apply_defaults(self, default_opts: Dict[str, Any]) -> ScannerBuilder: for key, value in default_opts.items(): @@ -2768,6 +3128,7 @@ def apply_defaults(self, default_opts: Dict[str, Any]) -> ScannerBuilder: if setter is None: raise ValueError(f"Unknown option {key}") setter(value) + return self def batch_size(self, batch_size: int) -> ScannerBuilder: """Set batch size for Scanner""" @@ -2882,9 +3243,19 @@ def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder: # Serialize the pyarrow compute expression toSubstrait and use # that as a filter. scalar_schema = pa.schema(fields_without_lists) - self._substrait_filter = serialize_expressions( + substrait_filter = serialize_expressions( [filter], ["my_filter"], scalar_schema ) + if isinstance(substrait_filter, memoryview): + self._substrait_filter = substrait_filter.tobytes() + else: + try: + self._substrait_filter = substrait_filter.to_pybytes() + except AttributeError: + raise TypeError( + "serialize_expressions returned unexpected" + f"type {type(substrait_filter)}" + ) except ImportError: # serialize_expressions was introduced in pyarrow 14. Fallback to # stringifying the expression if pyarrow is too old @@ -2971,7 +3342,7 @@ def nearest( use_index: bool = True, ef: Optional[int] = None, ) -> ScannerBuilder: - q = _coerce_query_vector(q) + q, q_dim = _coerce_query_vector(q) if self.ds.schema.get_field_index(column) < 0: raise ValueError(f"Embedding column {column} is not in the dataset") @@ -2980,14 +3351,20 @@ def nearest( column_type = column_field.type if hasattr(column_type, "storage_type"): column_type = column_type.storage_type - if not pa.types.is_fixed_size_list(column_type): + if pa.types.is_fixed_size_list(column_type): + dim = column_type.list_size + elif pa.types.is_list(column_type) and pa.types.is_fixed_size_list( + column_type.value_type + ): + dim = column_type.value_type.list_size + else: raise TypeError( f"Query column {column} must be a vector. Got {column_field.type}." ) - if len(q) != column_type.list_size: + + if q_dim != dim: raise ValueError( - f"Query vector size {len(q)} does not match index column size" - f" {column_type.list_size}" + f"Query vector size {len(q)} does not match index column size {dim}" ) if k is not None and int(k) <= 0: @@ -3021,9 +3398,18 @@ def fast_search(self, flag: bool) -> ScannerBuilder: self._fast_search = flag return self + def include_deleted_rows(self, flag: bool) -> ScannerBuilder: + """Include deleted rows + + Rows which have been deleted, but are still present in the fragment, will be + returned. These rows will have all columns (except _rowaddr) set to null + """ + self._include_deleted_rows = flag + return self + def full_text_search( self, - query: str, + query: str | FullTextQuery, columns: Optional[List[str]] = None, ) -> ScannerBuilder: """ @@ -3031,8 +3417,34 @@ def full_text_search( may remove it after we support to do this within `filter` SQL-like expression Must create inverted index on the given column before searching, + + Parameters + ---------- + query : str | Query + If str, the query string to search for, a match query would be performed. + If Query, the query object to search for, + and the `columns` parameter will be ignored. + columns : list of str, optional + The columns to search in. If None, search in all indexed columns. + """ + if isinstance(query, FullTextQuery): + self._full_text_query = query.inner + else: + self._full_text_query = { + "query": query, + "columns": columns, + } + return self + + def scan_stats_callback( + self, callback: Callable[[ScanStatistics], None] + ) -> ScannerBuilder: """ - self._full_text_query = {"query": query, "columns": columns} + Set a callback function that will be called with the scan statistics after the + scan is complete. Errors raised by the callback will be logged but not + re-raised. + """ + self._scan_stats_callback = callback return self def to_scanner(self) -> LanceScanner: @@ -3058,6 +3470,8 @@ def to_scanner(self) -> LanceScanner: self._full_text_query, self._late_materialization, self._use_scalar_index, + self._include_deleted_rows, + self._scan_stats_callback, ) return LanceScanner(scanner, self.ds) @@ -3173,6 +3587,21 @@ def explain_plan(self, verbose=False) -> str: return self._scanner.explain_plan(verbose=verbose) + def analyze_plan(self) -> str: + """Execute the plan for this scanner and display with runtime metrics. + + Parameters + ---------- + verbose : bool, default False + Use a verbose output format. + + Returns + ------- + plan : str + """ + + return self._scanner.analyze_plan() + class DatasetOptimizer: def __init__(self, dataset: LanceDataset): @@ -3277,6 +3706,15 @@ def optimize_indices(self, **kwargs): index_names: List[str], default None The names of the indices to optimize. If None, all indices will be optimized. + retrain: bool, default False + Whether to retrain the whole index. + If true, the index will be retrained based on the current data, + `num_indices_to_merge` will be ignored, + and all indices will be merged into one. + + This is useful when the data distribution has changed significantly, + and we want to retrain the index to improve the search quality. + This would be faster than re-create the index from scratch. """ self._dataset._ds.optimize_indices(**kwargs) @@ -3289,13 +3727,13 @@ class Tags: def __init__(self, dataset: _Dataset): self._ds = dataset - def list(self) -> dict[str, int]: + def list(self) -> dict[str, Tag]: """ List all dataset tags. Returns ------- - dict[str, int] + dict[str, Tag] A dictionary mapping tag names to version numbers. """ return self._ds.tags() @@ -3340,6 +3778,21 @@ def update(self, tag: str, version: int) -> None: self._ds.update_tag(tag, version) +@dataclass +class FieldStatistics: + """Statistics about a field in the dataset""" + + id: int #: id of the field + bytes_on_disk: int #: (possibly compressed) bytes on disk used to store the field + + +@dataclass +class DataStatistics: + """Statistics about the data in the dataset""" + + fields: FieldStatistics #: Statistics about the fields in the dataset + + class DatasetStats(TypedDict): num_deleted_rows: int num_fragments: int @@ -3376,6 +3829,12 @@ def index_stats(self, index_name: str) -> Dict[str, Any]: index_stats = json.loads(self._ds.index_statistics(index_name)) return index_stats + def data_stats(self) -> DataStatistics: + """ + Statistics about the data in the dataset. + """ + return self._ds.data_stats() + def write_dataset( data_obj: ReaderLike, @@ -3392,6 +3851,7 @@ def write_dataset( data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, enable_v2_manifest_paths: bool = False, + enable_move_stable_row_ids: bool = False, ) -> LanceDataset: """Write a given data_obj to the given uri @@ -3445,6 +3905,11 @@ def write_dataset( versions on object stores. This parameter has no effect if the dataset already exists. To migrate an existing dataset, instead use the :meth:`LanceDataset.migrate_manifest_paths_v2` method. Default is False. + enable_move_stable_row_ids : bool, optional + Experimental parameter: if set to true, the writer will use move-stable row ids. + These row ids are stable after compaction operations, but not after updates. + This makes compaction more efficient, since with stable row ids no + secondary indices need to be updated to point to new row ids. """ if use_legacy_format is not None: warnings.warn( @@ -3478,6 +3943,7 @@ def write_dataset( "storage_options": storage_options, "data_storage_version": data_storage_version, "enable_v2_manifest_paths": enable_v2_manifest_paths, + "enable_move_stable_row_ids": enable_move_stable_row_ids, } if commit_lock: @@ -3495,13 +3961,30 @@ def write_dataset( inner_ds = _write_dataset(reader, uri, params) ds = LanceDataset.__new__(LanceDataset) + ds._storage_options = storage_options ds._ds = inner_ds ds._uri = inner_ds.uri ds._default_scan_options = None return ds -def _coerce_query_vector(query: QueryVectorLike): +def _coerce_query_vector(query: QueryVectorLike) -> tuple[pa.Array, int]: + # if the query is a multivector, convert it to pa.ListArray + if hasattr(query, "__getitem__") and isinstance( + query[0], (list, tuple, np.ndarray, pa.Array) + ): + dim = len(query[0]) + multivector_query = [] + for q in query: + if len(q) != dim: + raise ValueError( + "All query vectors must have the same length, " + f"but got {dim} and {len(q)}" + ) + multivector_query.append(_coerce_query_vector(q)[0]) + query = pa.array(multivector_query, type=pa.list_(pa.float32())) + return (query, dim) + if isinstance(query, pa.Scalar): if isinstance(query, pa.ExtensionScalar): # If it's an extension scalar then convert to storage @@ -3534,7 +4017,7 @@ def _coerce_query_vector(query: QueryVectorLike): f"but received {query.type}" ) - return query + return (query, len(query)) def _validate_schema(schema: pa.Schema): @@ -3562,3 +4045,112 @@ def _validate_metadata(metadata: dict): ) elif isinstance(v, dict): _validate_metadata(v) + + +class VectorIndexReader: + """ + This class allows you to initialize a reader for a specific vector index, + retrieve the number of partitions, + access the centroids of the index, + and read specific partitions of the index. + + Parameters + ---------- + dataset: LanceDataset + The dataset containing the index. + index_name: str + The name of the vector index to read. + + Examples + -------- + .. code-block:: python + + import lance + from lance.dataset import VectorIndexReader + import numpy as np + import pyarrow as pa + vectors = np.random.rand(256, 2) + data = pa.table({"vector": pa.array(vectors.tolist(), + type=pa.list_(pa.float32(), 2))}) + dataset = lance.write_dataset(data, "/tmp/index_reader_demo") + dataset.create_index("vector", index_type="IVF_PQ", + num_partitions=4, num_sub_vectors=2) + reader = VectorIndexReader(dataset, "vector_idx") + assert reader.num_partitions() == 4 + partition = reader.read_partition(0) + assert "_rowid" in partition.column_names + + Exceptions + ---------- + ValueError + If the specified index is not a vector index. + """ + + def __init__(self, dataset: LanceDataset, index_name: str): + stats = dataset.stats.index_stats(index_name) + self.dataset = dataset + self.index_name = index_name + self.stats = stats + try: + self.num_partitions() + except KeyError: + raise ValueError(f"Index {index_name} is not vector index") + + def num_partitions(self) -> int: + """ + Returns the number of partitions in the dataset. + + Returns + ------- + int + The number of partitions. + """ + + return self.stats["indices"][0]["num_partitions"] + + def centroids(self) -> np.ndarray: + """ + Returns the centroids of the index + + Returns + ------- + np.ndarray + The centroids of IVF + with shape (num_partitions, dim) + """ + # when we have more delta indices, + # they are with the same centroids + return np.array( + self.dataset._ds.get_index_centroids(self.stats["indices"][0]["centroids"]) + ) + + def read_partition( + self, partition_id: int, *, with_vector: bool = False + ) -> pa.Table: + """ + Returns a pyarrow table for the given IVF partition + + Parameters + ---------- + partition_id: int + The id of the partition to read + with_vector: bool, default False + Whether to include the vector column in the reader, + for IVF_PQ, the vector column is PQ codes + + Returns + ------- + pa.Table + A pyarrow table for the given partition, + containing the row IDs, and quantized vectors (if with_vector is True). + """ + + if partition_id < 0 or partition_id >= self.num_partitions(): + raise IndexError( + f"Partition id {partition_id} is out of range, " + f"expected 0 <= partition_id < {self.num_partitions()}" + ) + + return self.dataset._ds.read_index_partition( + self.index_name, partition_id, with_vector + ).read_all() diff --git a/python/python/lance/dependencies.py b/python/python/lance/dependencies.py index dd3859c7aef..19855a990d0 100644 --- a/python/python/lance/dependencies.py +++ b/python/python/lance/dependencies.py @@ -50,8 +50,6 @@ class _LazyModule(ModuleType): "pandas": "pd.", "polars": "pl.", "torch": "torch.", - "cagra": "cagra.", - "common": "raft_common.", "tensorflow": "tf.", "ray": "ray.", } @@ -176,8 +174,6 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: pandas, _PANDAS_AVAILABLE = _lazy_import("pandas") polars, _POLARS_AVAILABLE = _lazy_import("polars") torch, _TORCH_AVAILABLE = _lazy_import("torch") - cagra, _CAGRA_AVAILABLE = _lazy_import("cuvs.neighbors.cagra") - raft_common, _RAFT_COMMON_AVAILABLE = _lazy_import("pylibraft.common") datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets") tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow") ray, _RAY_AVAILABLE = _lazy_import("ray") @@ -195,43 +191,43 @@ def _might_be(cls: type, type_: str) -> bool: def _check_for_numpy(obj: Any, *, check_type: bool = True) -> bool: return _NUMPY_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "numpy" + cast("Hashable", type(obj) if check_type else obj), "numpy" ) def _check_for_pandas(obj: Any, *, check_type: bool = True) -> bool: return _PANDAS_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "pandas" + cast("Hashable", type(obj) if check_type else obj), "pandas" ) def _check_for_polars(obj: Any, *, check_type: bool = True) -> bool: return _POLARS_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "polars" + cast("Hashable", type(obj) if check_type else obj), "polars" ) def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool: return _TORCH_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "torch" + cast("Hashable", type(obj) if check_type else obj), "torch" ) def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool: return _HUGGING_FACE_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "datasets" + cast("Hashable", type(obj) if check_type else obj), "datasets" ) def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool: return _TENSORFLOW_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "tensorflow" + cast("Hashable", type(obj) if check_type else obj), "tensorflow" ) def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool: return _RAY_AVAILABLE and _might_be( - cast(Hashable, type(obj) if check_type else obj), "ray" + cast("Hashable", type(obj) if check_type else obj), "ray" ) @@ -244,8 +240,6 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool: "ray", "tensorflow", "torch", - "cagra", - "raft_common", # lazy utilities "_check_for_hugging_face", "_check_for_numpy", @@ -260,8 +254,6 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool: "_PANDAS_AVAILABLE", "_POLARS_AVAILABLE", "_TORCH_AVAILABLE", - "_CAGRA_AVAILABLE", - "_RAFT_COMMON_AVAILABLE", "_HUGGING_FACE_AVAILABLE", "_TENSORFLOW_AVAILABLE", "_RAY_AVAILABLE", diff --git a/python/python/lance/download.py b/python/python/lance/download.py new file mode 100644 index 00000000000..cff42520e4e --- /dev/null +++ b/python/python/lance/download.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import os +import shutil +import subprocess +import tarfile +import traceback +from io import BytesIO + +from .lance import language_model_home + +LANGUAGE_MODEL_HOME = language_model_home() + + +def check_lindera(): + if not shutil.which("lindera"): + raise Exception( + "lindera is not installed. Please install it by following https://github.com/lindera/lindera/tree/main/lindera-cli" + ) + + +def import_requests(): + try: + import requests + except Exception: + raise Exception("requests is not installed, Please pip install requests") + return requests + + +def download_jieba(): + dirname = os.path.join(LANGUAGE_MODEL_HOME, "jieba", "default") + os.makedirs(dirname, exist_ok=True) + try: + requests = import_requests() + resp = requests.get( + "https://github.com/messense/jieba-rs/raw/refs/heads/main/src/data/dict.txt" + ) + content = resp.content + with open(os.path.join(dirname, "dict.txt"), "wb") as out: + out.write(content) + except Exception as _: + traceback.print_exc() + print( + "Download jieba language model failed. Please download dict.txt from " + "https://github.com/messense/jieba-rs/tree/main/src/data " + f"and put it in {dirname}" + ) + + +def download_lindera(lm: str): + requests = import_requests() + dirname = os.path.join(LANGUAGE_MODEL_HOME, "lindera", lm) + src_dirname = os.path.join(dirname, "src") + if lm == "ipadic": + url = "https://dlwqk3ibdg1xh.cloudfront.net/mecab-ipadic-2.7.0-20070801.tar.gz" + elif lm == "ko-dic": + url = "https://dlwqk3ibdg1xh.cloudfront.net/mecab-ko-dic-2.1.1-20180720.tar.gz" + elif lm == "unidic": + url = "https://dlwqk3ibdg1xh.cloudfront.net/unidic-mecab-2.1.2.tar.gz" + else: + raise Exception(f"language model {lm} is not supported") + os.makedirs(src_dirname, exist_ok=True) + print(f"downloading language model: {url}") + data = requests.get(url).content + print(f"unzip language model: {url}") + + cwd = os.getcwd() + try: + os.chdir(src_dirname) + with tarfile.open(fileobj=BytesIO(data)) as tar: + tar.extractall() + name = tar.getnames()[0] + cmd = [ + "lindera", + "build", + f"--dictionary-kind={lm}", + os.path.join(src_dirname, name), + os.path.join(dirname, "main"), + ] + print(f"compiling language model: {' '.join(cmd)}") + subprocess.run(cmd) + finally: + os.chdir(cwd) + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Lance tokenizer language model downloader" + ) + parser.add_argument("tokenizer", choices=["jieba", "lindera"]) + parser.add_argument("-l", "--languagemodel") + args = parser.parse_args() + print(f"LANCE_LANGUAGE_MODEL_HOME={LANGUAGE_MODEL_HOME}") + if args.tokenizer == "jieba": + download_jieba() + elif args.tokenizer == "lindera": + download_lindera(args.languagemodel) + + +if __name__ == "__main__": + main() diff --git a/python/python/lance/file.py b/python/python/lance/file.py index 895d09f3c90..e81b61d7b5a 100644 --- a/python/python/lance/file.py +++ b/python/python/lance/file.py @@ -10,6 +10,7 @@ LanceBufferDescriptor, LanceColumnMetadata, LanceFileMetadata, + LanceFileStatistics, LancePageMetadata, ) from .lance import ( @@ -133,7 +134,7 @@ def take_rows( if indices[i] > indices[i + 1]: raise ValueError( f"Indices must be sorted in ascending order for \ - file API, got {indices[i]} > {indices[i+1]}" + file API, got {indices[i]} > {indices[i + 1]}" ) return ReaderResults( @@ -146,6 +147,12 @@ def metadata(self) -> LanceFileMetadata: """ return self._reader.metadata() + def file_statistics(self) -> LanceFileStatistics: + """ + Return file statistics of the file + """ + return self._reader.file_statistics() + def read_global_buffer(self, index: int) -> bytes: """ Read a global buffer from the file at a given index @@ -232,7 +239,7 @@ def write_batch(self, batch: Union[pa.RecordBatch, pa.Table]) -> None: else: self._writer.write_batch(batch) - def close(self) -> int: + def close(self) -> Optional[int]: """ Write the file metadata and close the file @@ -289,4 +296,5 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: "LanceColumnMetadata", "LancePageMetadata", "LanceBufferDescriptor", + "LanceFileStatistics", ] diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index 88c9d69c796..dd17fbcf427 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -7,83 +7,198 @@ import json import warnings +from dataclasses import asdict, dataclass, field from pathlib import Path from typing import ( TYPE_CHECKING, Callable, Dict, - Iterable, Iterator, List, + Literal, Optional, Tuple, Union, + overload, ) import pyarrow as pa -from .dependencies import _check_for_pandas -from .dependencies import pandas as pd -from .lance import _Fragment, _write_fragments -from .lance import _FragmentMetadata as _FragmentMetadata +from .lance import ( + DeletionFile as DeletionFile, +) +from .lance import ( + RowIdMeta as RowIdMeta, +) +from .lance import _Fragment, _write_fragments, _write_fragments_transaction from .progress import FragmentWriteProgress, NoopFragmentWriteProgress +from .types import _coerce_reader from .udf import BatchUDF, normalize_transform if TYPE_CHECKING: - from .dataset import LanceDataset, LanceScanner, ReaderLike - from .schema import LanceSchema + from .dataset import LanceDataset, LanceScanner, ReaderLike, Transaction + from .lance import LanceSchema DEFAULT_MAX_BYTES_PER_FILE = 90 * 1024 * 1024 * 1024 +@dataclass class FragmentMetadata: - """Metadata of a Fragment in the dataset.""" - - def __init__(self, metadata: str): - """Construct a FragmentMetadata from a JSON representation of the metadata. + """Metadata for a fragment. - Internal use only. - """ - self._metadata = _FragmentMetadata.from_json(metadata) - - @classmethod - def from_metadata(cls, metadata: _FragmentMetadata): - instance = cls.__new__(cls) - instance._metadata = metadata - return instance + Attributes + ---------- + id : int + The ID of the fragment. + files : List[DataFile] + The data files of the fragment. Each data file must have the same number + of rows. Each file stores a different subset of the columns. + physical_rows : int + The number of rows originally in this fragment. This is the number of rows + in the data files before deletions. + deletion_file : Optional[DeletionFile] + The deletion file, if any. + row_id_meta : Optional[RowIdMeta] + The row id metadata, if any. + """ - def __repr__(self): - return self._metadata.__repr__() + id: int + files: List[DataFile] + physical_rows: int + deletion_file: Optional[DeletionFile] = None + row_id_meta: Optional[RowIdMeta] = None - def __reduce__(self): - return (FragmentMetadata, (self._metadata.json(),)) + @property + def num_deletions(self) -> int: + """The number of rows that have been deleted from this fragment.""" + if self.deletion_file is None: + return 0 + else: + return self.deletion_file.num_deleted_rows - def __eq__(self, other: object) -> bool: - if not isinstance(other, FragmentMetadata): - return False - return self._metadata.__eq__(other._metadata) + @property + def num_rows(self) -> int: + """The number of rows in this fragment after deletions.""" + return self.physical_rows - self.num_deletions - def to_json(self) -> str: - """Serialize :class:`FragmentMetadata` to a JSON blob""" - return json.loads(self._metadata.json()) + def data_files(self) -> List[DataFile]: + warnings.warn( + "FragmentMetadata.data_files is deprecated. Use .files instead.", + DeprecationWarning, + ) + return self.files + + def to_json(self) -> dict: + """Get this as a simple JSON-serializable dictionary.""" + files = [asdict(f) for f in self.files] + for f in files: + f["path"] = f.pop("_path") + return dict( + id=self.id, + files=files, + physical_rows=self.physical_rows, + deletion_file=( + self.deletion_file.asdict() if self.deletion_file is not None else None + ), + row_id_meta=( + self.row_id_meta.asdict() if self.row_id_meta is not None else None + ), + ) @staticmethod def from_json(json_data: str) -> FragmentMetadata: - """Reconstruct :class:`FragmentMetadata` from a JSON blob""" - return FragmentMetadata(json_data) + json_data = json.loads(json_data) + + deletion_file = json_data.get("deletion_file") + if deletion_file is not None: + deletion_file = DeletionFile(**deletion_file) + + row_id_meta = json_data.get("row_id_meta") + if row_id_meta is not None: + row_id_meta = RowIdMeta(**row_id_meta) + + return FragmentMetadata( + id=json_data["id"], + files=[DataFile(**f) for f in json_data["files"]], + physical_rows=json_data["physical_rows"], + deletion_file=deletion_file, + row_id_meta=row_id_meta, + ) - def data_files(self) -> Iterable[str]: - """Return the data files of the fragment""" - return self._metadata.data_files() - def deletion_file(self): - """Return the deletion file, if any""" - return self._metadata.deletion_file() +@dataclass +class DataFile: + """ + A data file in a fragment. + + Attributes + ---------- + path : str + The path to the data file. + fields : List[int] + The field ids of the columns in this file. + column_indices : List[int] + The column indices where the fields are stored in the file. Will have + the same length as `fields`. + file_major_version : int + The major version of the data storage format. + file_minor_version : int + The minor version of the data storage format. + """ + + _path: str + fields: List[int] + column_indices: List[int] = field(default_factory=list) + file_major_version: int = 0 + file_minor_version: int = 0 + + def __init__( + self, + path: str, + fields: List[int], + column_indices: List[int] = None, + file_major_version: int = 0, + file_minor_version: int = 0, + ): + # TODO: only we eliminate the path method, we can remove this + self._path = path + self.fields = fields + self.column_indices = column_indices or [] + self.file_major_version = file_major_version + self.file_minor_version = file_minor_version + + def __repr__(self): + # pretend we have a 'path' attribute + return ( + f"DataFile(path='{self._path}', fields={self.fields}, " + f"column_indices={self.column_indices}, " + f"file_major_version={self.file_major_version}, " + f"file_minor_version={self.file_minor_version})" + ) @property - def id(self) -> int: - return self._metadata.id + def path(self) -> str: + # path used to be a method. This is for backwards compatibility. + class CallableStr(str): + def __call__(self): + warnings.warn( + "DataFile.path() is deprecated, use DataFile.path instead", + DeprecationWarning, + ) + return self + + def __reduce__(self): + return (str, (str(self),)) + + return CallableStr(self._path) + + def field_ids(self) -> List[int]: + warnings.warn( + "DataFile.field_ids is deprecated, use DataFile.fields instead", + DeprecationWarning, + ) + return self.fields class LanceFragment(pa.dataset.Fragment): @@ -99,7 +214,7 @@ def __init__( if fragment_id is None: raise ValueError("Either fragment or fragment_id must be specified") fragment = dataset.get_fragment(fragment_id)._fragment - self._fragment = fragment + self._fragment: _Fragment = fragment if self._fragment is None: raise ValueError(f"Fragment id does not exist: {fragment_id}") @@ -114,7 +229,7 @@ def __reduce__(self): @staticmethod def create_from_file( - filename: Union[str, Path], + filename: str, dataset: LanceDataset, fragment_id: int, ) -> FragmentMetadata: @@ -135,13 +250,12 @@ def create_from_file( fragment_id: int The ID of the fragment. """ - fragment = _Fragment.create_from_file(filename, dataset._ds, fragment_id) - return FragmentMetadata(fragment.json()) + return _Fragment.create_from_file(filename, dataset._ds, fragment_id) @staticmethod def create( dataset_uri: Union[str, Path], - data: Union[pa.Table, pa.RecordBatchReader], + data: ReaderLike, fragment_id: Optional[int] = None, schema: Optional[pa.Schema] = None, max_rows_per_group: int = 1024, @@ -215,23 +329,14 @@ def create( else: data_storage_version = "stable" - if _check_for_pandas(data) and isinstance(data, pd.DataFrame): - reader = pa.Table.from_pandas(data, schema=schema).to_reader() - elif isinstance(data, pa.Table): - reader = data.to_reader() - elif isinstance(data, pa.dataset.Scanner): - reader = data.to_reader() - elif isinstance(data, pa.RecordBatchReader): - reader = data - else: - raise TypeError(f"Unknown data_obj type {type(data)}") + reader = _coerce_reader(data, schema) if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) if progress is None: progress = NoopFragmentWriteProgress() - inner_meta = _Fragment.create( + return _Fragment.create( dataset_uri, fragment_id, reader, @@ -241,7 +346,6 @@ def create( data_storage_version=data_storage_version, storage_options=storage_options, ) - return FragmentMetadata(inner_meta.json()) @property def fragment_id(self): @@ -250,9 +354,11 @@ def fragment_id(self): def count_rows( self, filter: Optional[Union[pa.compute.Expression, str]] = None ) -> int: - if filter is not None: - raise ValueError("Does not support filter at the moment") - return self._fragment.count_rows() + if isinstance(filter, pa.compute.Expression): + return self.scanner( + with_row_id=True, columns=[], filter=filter + ).count_rows() + return self._fragment.count_rows(filter) @property def num_deletions(self) -> int: @@ -291,6 +397,7 @@ def scanner( limit: Optional[int] = None, offset: Optional[int] = None, with_row_id: bool = False, + with_row_address: bool = False, batch_readahead: int = 16, ) -> "LanceScanner": """See Dataset::scanner for details""" @@ -309,6 +416,7 @@ def scanner( limit=limit, offset=offset, with_row_id=with_row_id, + with_row_address=with_row_address, batch_readahead=batch_readahead, **columns_arg, ) @@ -360,12 +468,86 @@ def to_table( with_row_id=with_row_id, ).to_table() + def merge( + self, + data_obj: ReaderLike, + left_on: str, + right_on: Optional[str] = None, + schema=None, + ) -> Tuple[FragmentMetadata, LanceSchema]: + """ + Merge another dataset into this fragment. + + Performs a left join, where the fragment is the left side and data_obj + is the right side. Rows existing in the dataset but not on the left will + be filled with null values, unless Lance doesn't support null values for + some types, in which case an error will be raised. + + Parameters + ---------- + data_obj: Reader-like + The data to be merged. Acceptable types are: + - Pandas DataFrame, Pyarrow Table, Dataset, Scanner, + Iterator[RecordBatch], or RecordBatchReader + left_on: str + The name of the column in the dataset to join on. + right_on: str or None + The name of the column in data_obj to join on. If None, defaults to + left_on. + + Examples + -------- + + >>> import lance + >>> import pyarrow as pa + >>> df = pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']}) + >>> dataset = lance.write_dataset(df, "dataset") + >>> dataset.to_table().to_pandas() + x y + 0 1 a + 1 2 b + 2 3 c + >>> fragments = dataset.get_fragments() + >>> new_df = pa.table({'x': [1, 2, 3], 'z': ['d', 'e', 'f']}) + >>> merged = [] + >>> schema = None + >>> for f in fragments: + ... f, schema = f.merge(new_df, 'x') + ... merged.append(f) + >>> merge = lance.LanceOperation.Merge(merged, schema) + >>> dataset = lance.LanceDataset.commit("dataset", merge, read_version=1) + >>> dataset.to_table().to_pandas() + x y z + 0 1 a d + 1 2 b e + 2 3 c f + + See Also + -------- + LanceDataset.merge_columns : + Add columns to this Fragment. + + Returns + ------- + Tuple[FragmentMetadata, LanceSchema] + A new fragment with the merged column(s) and the final schema. + """ + if right_on is None: + right_on = left_on + + reader = _coerce_reader(data_obj, schema) + max_field_id = self._ds.max_field_id + metadata, schema = self._fragment.merge(reader, left_on, right_on, max_field_id) + return metadata, schema + def merge_columns( self, - value_func: Dict[str, str] - | BatchUDF - | ReaderLike - | Callable[[pa.RecordBatch], pa.RecordBatch], + value_func: ( + Dict[str, str] + | BatchUDF + | ReaderLike + | Callable[[pa.RecordBatch], pa.RecordBatch] + ), columns: Optional[list[str]] = None, batch_size: Optional[int] = None, reader_schema: Optional[pa.Schema] = None, @@ -413,7 +595,7 @@ def merge_columns( transforms, columns, batch_size ) - return FragmentMetadata.from_metadata(metadata), schema + return metadata, schema def delete(self, predicate: str) -> FragmentMetadata | None: """Delete rows from this Fragment. @@ -444,7 +626,7 @@ def delete(self, predicate: str) -> FragmentMetadata | None: >>> dataset = lance.write_dataset(tab, "dataset") >>> frag = dataset.get_fragment(0) >>> frag.delete("a > 1") - Fragment { id: 0, files: ..., deletion_file: Some(...), ...} + FragmentMetadata(id=0, files=[DataFile(path='...', fields=[0, 1], ...), ...) >>> frag.delete("a > 0") is None True @@ -457,7 +639,7 @@ def delete(self, predicate: str) -> FragmentMetadata | None: raw_fragment = self._fragment.delete(predicate) if raw_fragment is None: return None - return FragmentMetadata.from_metadata(raw_fragment.metadata()) + return raw_fragment.metadata() @property def schema(self) -> pa.Schema: @@ -482,7 +664,46 @@ def metadata(self) -> FragmentMetadata: ------- FragmentMetadata """ - return FragmentMetadata.from_metadata(self._fragment.metadata()) + return self._fragment.metadata() + + +if TYPE_CHECKING: + + @overload + def write_fragments( + data: ReaderLike, + dataset_uri: Union[str, Path, LanceDataset], + schema: Optional[pa.Schema] = None, + *, + return_transaction: Literal[True], + mode: str = "append", + max_rows_per_file: int = 1024 * 1024, + max_rows_per_group: int = 1024, + max_bytes_per_file: int = DEFAULT_MAX_BYTES_PER_FILE, + progress: Optional[FragmentWriteProgress] = None, + data_storage_version: Optional[str] = None, + use_legacy_format: Optional[bool] = None, + storage_options: Optional[Dict[str, str]] = None, + enable_move_stable_row_ids: bool = False, + ) -> Transaction: ... + + @overload + def write_fragments( + data: ReaderLike, + dataset_uri: Union[str, Path, LanceDataset], + schema: Optional[pa.Schema] = None, + *, + return_transaction: Literal[False] = False, + mode: str = "append", + max_rows_per_file: int = 1024 * 1024, + max_rows_per_group: int = 1024, + max_bytes_per_file: int = DEFAULT_MAX_BYTES_PER_FILE, + progress: Optional[FragmentWriteProgress] = None, + data_storage_version: Optional[str] = None, + use_legacy_format: Optional[bool] = None, + storage_options: Optional[Dict[str, str]] = None, + enable_move_stable_row_ids: bool = False, + ) -> List[FragmentMetadata]: ... def write_fragments( @@ -490,6 +711,7 @@ def write_fragments( dataset_uri: Union[str, Path, LanceDataset], schema: Optional[pa.Schema] = None, *, + return_transaction: bool = False, mode: str = "append", max_rows_per_file: int = 1024 * 1024, max_rows_per_group: int = 1024, @@ -498,7 +720,8 @@ def write_fragments( data_storage_version: Optional[str] = None, use_legacy_format: Optional[bool] = None, storage_options: Optional[Dict[str, str]] = None, -) -> List[FragmentMetadata]: + enable_move_stable_row_ids: bool = False, +) -> List[FragmentMetadata] | Transaction: """ Write data into one or more fragments. @@ -516,6 +739,8 @@ def write_fragments( schema : pa.Schema, optional The schema of the data. If not specified, the schema will be inferred from the data. + return_transaction: bool, default False + If it's true, the transaction will be returned. mode : str, default "append" The write mode. If "append" is specified, the data will be checked against the existing dataset's schema. Otherwise, pass "create" or @@ -544,26 +769,28 @@ def write_fragments( storage_options : Optional[Dict[str, str]] Extra options that make sense for a particular storage connection. This is used to store connection parameters like credentials, endpoint, etc. - + enable_move_stable_row_ids: bool + Experimental: if set to true, the writer will use move-stable row ids. + These row ids are stable after compaction operations, but not after updates. + This makes compaction more efficient, since with stable row ids no + secondary indices need to be updated to point to new row ids. Returns ------- - List[FragmentMetadata] - A list of :class:`FragmentMetadata` for the fragments written. The - fragment ids are left as zero meaning they are not yet specified. They - will be assigned when the fragments are committed to a dataset. + List[FragmentMetadata] | Transaction + If return_transaction is False: + a list of :class:`FragmentMetadata` for the fragments written. The + fragment ids are left as zero meaning they are not yet specified. They + will be assigned when the fragments are committed to a dataset. + + If return_transaction is True: + The write transaction. The type of transaction will correspond to + the mode parameter specified. This transaction can be passed to + :meth:`LanceDataset.commit`. + """ from .dataset import LanceDataset - if _check_for_pandas(data) and isinstance(data, pd.DataFrame): - reader = pa.Table.from_pandas(data, schema=schema).to_reader() - elif isinstance(data, pa.Table): - reader = data.to_reader() - elif isinstance(data, pa.dataset.Scanner): - reader = data.to_reader() - elif isinstance(data, pa.RecordBatchReader): - reader = data - else: - raise TypeError(f"Unknown data_obj type {type(data)}") + reader = _coerce_reader(data, schema) if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) @@ -582,7 +809,9 @@ def write_fragments( else: data_storage_version = "stable" - fragments = _write_fragments( + function = _write_fragments_transaction if return_transaction else _write_fragments + + return function( dataset_uri, reader, mode=mode, @@ -592,5 +821,5 @@ def write_fragments( progress=progress, data_storage_version=data_storage_version, storage_options=storage_options, + enable_move_stable_row_ids=enable_move_stable_row_ids, ) - return [FragmentMetadata.from_metadata(frag) for frag in fragments] diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 97d2cb602de..b81f682c0ac 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -12,10 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Self, + Sequence, + Tuple, + Union, +) import pyarrow as pa +from .._arrow.bf16 import BFloat16Array +from ..commit import CommitLock +from ..dataset import ( + AlterColumn, + ExecuteResult, + Index, + LanceOperation, + Tag, + Transaction, + UpdateResult, + Version, +) +from ..fragment import ( + DataFile, + FragmentMetadata, +) +from ..progress import FragmentWriteProgress as FragmentWriteProgress +from ..types import ReaderLike as ReaderLike +from ..udf import BatchUDF as BatchUDF +from .debug import format_fragment as format_fragment +from .debug import format_manifest as format_manifest +from .debug import format_schema as format_schema +from .debug import list_transactions as list_transactions +from .fragment import ( + DeletionFile as DeletionFile, +) +from .fragment import ( + RowIdMeta as RowIdMeta, +) +from .optimize import ( + Compaction as Compaction, +) +from .optimize import ( + CompactionMetrics as CompactionMetrics, +) +from .optimize import ( + CompactionPlan as CompactionPlan, +) +from .optimize import ( + CompactionTask as CompactionTask, +) +from .optimize import ( + RewriteResult as RewriteResult, +) +from .schema import LanceSchema as LanceSchema +from .trace import trace_to_chrome as trace_to_chrome + def infer_tfrecord_schema( uri: str, tensor_features: Optional[List[str]] = None, @@ -27,12 +87,6 @@ class CleanupStats: bytes_removed: int old_versions: int -class CompactionMetrics: - fragments_removed: int - fragments_added: int - files_removed: int - files_added: int - class LanceFileWriter: def __init__( self, @@ -60,6 +114,8 @@ class LanceFileReader: self, indices: List[int], batch_size: int, batch_readahead: int ) -> pa.RecordBatchReader: ... def read_global_buffer(self, index: int) -> bytes: ... + def metadata(self) -> LanceFileMetadata: ... + def file_statistics(self) -> LanceFileStatistics: ... class LanceBufferDescriptor: position: int @@ -82,6 +138,13 @@ class LanceFileMetadata: global_buffers: List[LanceBufferDescriptor] columns: List[LanceColumnMetadata] +class LanceFileStatistics: + columns: List[LanceColumnStatistics] + +class LanceColumnStatistics: + num_pages: int + size_bytes: int + class _Session: def size_bytes(self) -> int: ... @@ -92,4 +155,339 @@ class LanceBlobFile: def tell(self) -> int: ... def size(self) -> int: ... def readall(self) -> bytes: ... - def readinto(self, b: bytearray) -> int: ... + def read_into(self, b: bytearray) -> int: ... + +class _Dataset: + @property + def uri(self) -> str: ... + def __init__( + self, + uri: str, + version: Optional[int | str] = None, + block_size: Optional[int] = None, + index_cache_size: Optional[int] = None, + metadata_cache_size: Optional[int] = None, + commit_handler: Optional[CommitLock] = None, + storage_options: Optional[Dict[str, str]] = None, + manifest: Optional[bytes] = None, + **kwargs, + ): ... + @property + def schema(self) -> pa.Schema: ... + @property + def lance_schema(self) -> LanceSchema: ... + def replace_schema_metadata(self, metadata: Dict[str, str]): ... + def replace_field_metadata(self, field_name: str, metadata: Dict[str, str]): ... + @property + def data_storage_version(self) -> str: ... + def index_statistics(self, index_name: str) -> str: ... + def serialized_manifest(self) -> bytes: ... + def load_indices(self) -> List[Index]: ... + def scanner( + self, + columns: Optional[List[str]] = None, + columns_with_transform: Optional[List[Tuple[str, str]]] = None, + filter: Optional[str] = None, + prefilter: Optional[bool] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + nearest: Optional[Dict] = None, + batch_size: Optional[int] = None, + io_buffer_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + scan_in_order: Optional[bool] = None, + fragments: Optional[List[_Fragment]] = None, + with_row_id: Optional[bool] = None, + with_row_address: Optional[bool] = None, + use_stats: Optional[bool] = None, + substrait_filter: Optional[bytes] = None, + fast_search: Optional[bool] = None, + full_text_query: Optional[dict] = None, + late_materialization: Optional[bool | List[str]] = None, + use_scalar_index: Optional[bool] = None, + include_deleted_rows: Optional[bool] = None, + ) -> _Scanner: ... + def count_rows(self, filter: Optional[str] = None) -> int: ... + def take( + self, + row_indices: List[int], + columns: Optional[List[str]] = None, + columns_with_transform: Optional[List[Tuple[str, str]]] = None, + ) -> pa.RecordBatch: ... + def take_rows( + self, + row_indices: List[int], + columns: Optional[List[str]] = None, + columns_with_transform: Optional[List[Tuple[str, str]]] = None, + ) -> pa.RecordBatch: ... + def take_blobs( + self, + row_indices: List[int], + blob_column: str, + ) -> List[LanceBlobFile]: ... + def take_scan( + self, + row_slices: Iterable[Tuple[int, int]], + columns: Optional[List[str]] = None, + batch_readahead: int = 10, + ) -> pa.RecordBatchReader: ... + def alter_columns(self, alterations: List[AlterColumn]): ... + def merge(self, reader: pa.RecordBatchReader, left_on: str, right_on: str): ... + def delete(self, predicate: str): ... + def update( + self, + updates: Dict[str, str], + predicate: Optional[str] = None, + ) -> UpdateResult: ... + def count_deleted_rows(self) -> int: ... + def versions(self) -> List[Version]: ... + def version(self) -> int: ... + def latest_version(self) -> int: ... + def checkout_version(self, version: int | str) -> _Dataset: ... + def restore(self): ... + def cleanup_old_versions( + self, + older_than_micros: int, + delete_unverified: Optional[bool] = None, + error_if_tagged_old_versions: Optional[bool] = None, + ) -> CleanupStats: ... + def tags(self) -> Dict[str, Tag]: ... + def create_tag(self, tag: str, version: int): ... + def delete_tag(self, tag: str): ... + def update_tag(self, tag: str, version: int): ... + def optimize_indices(self, **kwargs): ... + def create_index( + self, + columns: List[str], + index_type: str, + name: Optional[str] = None, + replace: Optional[bool] = None, + storage_options: Optional[Dict[str, str]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ): ... + def drop_index(self, name: str): ... + def prewarm_index(self, name: str): ... + def count_fragments(self) -> int: ... + def num_small_files(self, max_rows_per_group: int) -> int: ... + def get_fragments(self) -> List[_Fragment]: ... + def get_fragment(self, fragment_id: int) -> Optional[_Fragment]: ... + def index_cache_entry_count(self) -> int: ... + def index_cache_hit_rate(self) -> float: ... + def session(self) -> _Session: ... + @staticmethod + def drop(dest: str, storage_options: Optional[Dict[str, str]] = None): ... + @staticmethod + def commit( + dest: str | _Dataset, + operation: LanceOperation.BaseOperation, + blobs_op: Optional[LanceOperation.BaseOperation] = None, + read_version: Optional[int] = None, + commit_lock: Optional[CommitLock] = None, + storage_options: Optional[Dict[str, str]] = None, + enable_v2_manifest_paths: Optional[bool] = None, + detached: Optional[bool] = None, + max_retries: Optional[int] = None, + **kwargs, + ) -> _Dataset: ... + @staticmethod + def commit_batch( + dest: str | _Dataset, + transactions: Sequence[Transaction], + commit_lock: Optional[CommitLock] = None, + storage_options: Optional[Dict[str, str]] = None, + enable_v2_manifest_paths: Optional[bool] = None, + detached: Optional[bool] = None, + max_retries: Optional[int] = None, + ) -> Tuple[_Dataset, Transaction]: ... + def validate(self): ... + def migrate_manifest_paths_v2(self): ... + def drop_columns(self, columns: List[str]): ... + def add_columns_from_reader( + self, reader: pa.RecordBatchReader, batch_size: Optional[int] = None + ): ... + def add_columns( + self, + transforms: Dict[str, str] | BatchUDF | ReaderLike, + read_columns: Optional[List[str]] = None, + batch_size: Optional[int] = None, + ): ... + def add_columns_with_schema(self, schema: pa.Schema): ... + +class _MergeInsertBuilder: + def __init__(self, dataset: _Dataset, on: str | Iterable[str]): ... + def when_matched_update_all(self, condition: Optional[str] = None) -> Self: ... + def when_not_matched_insert_all(self) -> Self: ... + def when_not_matched_by_source_delete(self, expr: Optional[str] = None) -> Self: ... + def execute(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ... + +class _Scanner: + @property + def schema(self) -> pa.Schema: ... + def explain_plan(self, verbose: bool) -> str: ... + def analyze_plan(self) -> str: ... + def count_rows(self) -> int: ... + def to_pyarrow(self) -> pa.RecordBatchReader: ... + +class _Fragment: + @staticmethod + def create_from_file( + filename: str, + dataset: _Dataset, + fragment_id: int, + ) -> FragmentMetadata: ... + @staticmethod + def create( + dataset_uri: str, + fragment_id: Optional[int], + reader: ReaderLike, + **kwargs, + ): ... + def id(self) -> int: ... + def metadata(self) -> FragmentMetadata: ... + def count_rows(self, _filter: Optional[str] = None) -> int: ... + def take( + self, + row_indices: List[int], + columns: Optional[Union[List[str], Dict[str, str]]], + ) -> pa.RecordBatch: ... + def scanner( + self, + columns: Optional[List[str]], + columns_with_transform: Optional[List[Tuple[str, str]]], + batch_size: Optional[int], + filter: Optional[str], + limit: Optional[int], + offset: Optional[int], + with_row_id: Optional[bool], + batch_readahead: Optional[int], + **kwargs, + ) -> _Scanner: ... + def add_columns_from_reader( + self, + reader: ReaderLike, + batch_size: Optional[int], + ) -> Tuple[FragmentMetadata, LanceSchema]: ... + def add_columns( + self, + transforms: Dict[str, str] | BatchUDF | ReaderLike, + read_columns: Optional[List[str]], + batch_size: Optional[int], + ) -> Tuple[FragmentMetadata, LanceSchema]: ... + def delete(self, predicate: str) -> Optional[_Fragment]: ... + def schema(self) -> pa.Schema: ... + def data_files(self) -> List[DataFile]: ... + def deletion_file(self) -> Optional[str]: ... + @property + def physical_rows(self) -> int: ... + @property + def num_deletions(self) -> int: ... + +def iops_counter() -> int: ... +def bytes_read_counter() -> int: ... +def _write_dataset( + reader: pa.RecordBatchReader, uri: str | Path | _Dataset, params: Dict[str, Any] +): ... +def _write_fragments( + dataset_uri: str | Path | _Dataset, + reader: ReaderLike, + mode: str, + max_rows_per_file: int, + max_rows_per_group: int, + max_bytes_per_file: int, + progress: Optional[FragmentWriteProgress], + data_storage_version: Optional[str], + storage_options: Optional[Dict[str, str]], + enable_move_stable_row_ids: bool, +): ... +def _write_fragments_transaction( + dataset_uri: str | Path | _Dataset, + reader: ReaderLike, + mode: str, + max_rows_per_file: int, + max_rows_per_group: int, + max_bytes_per_file: int, + progress: Optional[FragmentWriteProgress], + data_storage_version: Optional[str], + storage_options: Optional[Dict[str, str]], + enable_move_stable_row_ids: bool, +) -> Transaction: ... +def _json_to_schema(schema_json: str) -> pa.Schema: ... +def _schema_to_json(schema: pa.Schema) -> str: ... + +class _Hnsw: + @staticmethod + def build( + vectors_array: Iterator[pa.Array], + max_level: int, + m: int, + ef_construction: int, + ): ... + def to_lance_file(self, file_path: str): ... + def vectors(self) -> pa.Array: ... + +class _KMeans: + def __init__( + self, + k: int, + metric_type: str, + max_iters: int, + centroids_arr: Optional[pa.FixedSizeListArray] = None, + ): ... + def fit(self, data: pa.FixedSizeListArray): ... + def predict(self, data: pa.FixedSizeListArray) -> pa.UInt32Array: ... + def centroids( + self, + ) -> Union[pa.FixedShapeTensorType, pa.FixedSizeListType | None]: ... + +class BFloat16: + def __init__(self, value: float) -> None: ... + @classmethod + def from_bytes(cls, bytes: bytes) -> BFloat16: ... + def as_float(self) -> float: ... + def __lt__(self, other: BFloat16) -> bool: ... + def __le__(self, other: BFloat16) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... + def __gt__(self, other: BFloat16) -> bool: ... + def __ge__(self, other: BFloat16) -> bool: ... + +def bfloat16_array(values: List[str | None]) -> BFloat16Array: ... + +class PyFullTextQuery: + @staticmethod + def match_query( + column: str, + query: str, + boost: float = 1.0, + fuzziness: Optional[int] = 0, + max_expansions: int = 50, + operator: str = "OR", + ) -> PyFullTextQuery: ... + @staticmethod + def phrase_query( + query: str, + column: str, + ) -> PyFullTextQuery: ... + @staticmethod + def boost_query( + positive: PyFullTextQuery, + negative: PyFullTextQuery, + negative_boost: Optional[float], + ) -> PyFullTextQuery: ... + @staticmethod + def multi_match_query( + query: str, + columns: List[str], + boosts: Optional[List[float]] = None, + operator: str = "OR", + ) -> PyFullTextQuery: ... + +class ScanStatistics: + iops: int + bytes_read: int + indices_loaded: int + parts_loaded: int + +__version__: str +language_model_home: Callable[[], str] diff --git a/python/python/lance/lance/datagen/__init__.pyi b/python/python/lance/lance/datagen/__init__.pyi index c1d2ae43b4a..b3a2b61921a 100644 --- a/python/python/lance/lance/datagen/__init__.pyi +++ b/python/python/lance/lance/datagen/__init__.pyi @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import pyarrow as pa def rand_batches( - schema: pa.Schema, num_batches: int = None, batch_size_bytes: int = None + schema: pa.Schema, + num_batches: Optional[int] = None, + batch_size_bytes: Optional[int] = None, ): ... def is_datagen_supported() -> bool: ... diff --git a/python/python/lance/debug.pyi b/python/python/lance/lance/debug.pyi similarity index 100% rename from python/python/lance/debug.pyi rename to python/python/lance/lance/debug.pyi diff --git a/python/python/lance/lance/fragment.pyi b/python/python/lance/lance/fragment.pyi new file mode 100644 index 00000000000..dd3463e45a6 --- /dev/null +++ b/python/python/lance/lance/fragment.pyi @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from typing import Literal, Optional + +class DeletionFile: + """ + Metadata for a deletion file. + + The deletion file contains the row ids that are tombstoned. + + Attributes + ---------- + read_version : int + The read version of the deletion file. + id : int + A unique identifier for the deletion file, used to prevent collisions. + num_deleted_rows : int + The number of rows that are deleted. + file_type : str + The type of deletion file. "array" is used for small sets, while + "bitmap" is used for large sets. + """ + + read_version: int + id: int + num_deleted_rows: int + file_type: Literal["array", "bitmap"] + + def __init__( + self, + read_version: int, + id: int, + file_type: Literal["array", "bitmap"], + num_deleted_rows: int, + ): ... + def asdict(self) -> dict: + """Get a dictionary representation of the deletion file.""" + ... + def path(self, fragment_id: int, base_uri: Optional[str] = None) -> str: + """ + Get the path to the deletion file. + + Parameters + ---------- + fragment_id : int + The fragment id. + base_uri : str, optional + The base URI to use for the path. If not provided, a relative path + is returned. + + Returns + ------- + str + The path to the deletion file. + """ + ... + + def json(self) -> str: + """Get a JSON representation of the deletion file. + + Returns + ------- + str + + Warning + ------- + The JSON representation is not guaranteed to be stable across versions. + """ + ... + + @classmethod + def from_json(json: str) -> DeletionFile: + """ + Load a deletion file from a JSON representation. + + Parameters + ---------- + json : str + The JSON representation of the deletion file. + + Returns + ------- + DeletionFile + """ + ... + + def __reduce__(self) -> tuple: ... + +class RowIdMeta: + def json(self) -> str: + """Get a JSON representation of the row id metadata. + + Returns + ------- + str + + Warning + ------- + The JSON representation is not guaranteed to be stable across versions. + """ + ... + + @classmethod + def from_json(json: str) -> RowIdMeta: + """ + Load row id metadata from a JSON representation. + + Parameters + ---------- + json : str + The JSON representation of the row id metadata. + + Returns + ------- + RowIdMeta + """ + ... + + def __reduce__(self) -> tuple: ... diff --git a/python/python/lance/optimize.pyi b/python/python/lance/lance/optimize.pyi similarity index 81% rename from python/python/lance/optimize.pyi rename to python/python/lance/lance/optimize.pyi index fde9093e5df..9a26d23c003 100644 --- a/python/python/lance/optimize.pyi +++ b/python/python/lance/lance/optimize.pyi @@ -14,7 +14,7 @@ from typing import List -from lance import Dataset +from lance import LanceDataset from lance.fragment import FragmentMetadata from lance.optimize import CompactionOptions @@ -34,7 +34,7 @@ class CompactionTask: read_version: int fragments: List["FragmentMetadata"] - def execute(self, dataset: "Dataset") -> RewriteResult: ... + def execute(self, dataset: "LanceDataset") -> RewriteResult: ... class CompactionPlan: read_version: int @@ -45,11 +45,11 @@ class CompactionPlan: class Compaction: @staticmethod def execute( - dataset: "Dataset", options: CompactionOptions + dataset: "LanceDataset", options: CompactionOptions ) -> CompactionMetrics: ... @staticmethod - def plan(dataset: "Dataset", options: CompactionOptions) -> CompactionPlan: ... + def plan(dataset: "LanceDataset", options: CompactionOptions) -> CompactionPlan: ... @staticmethod def commit( - dataset: "Dataset", rewrites: List[RewriteResult] + dataset: "LanceDataset", rewrites: List[RewriteResult] ) -> CompactionMetrics: ... diff --git a/python/python/lance/lance/schema.pyi b/python/python/lance/lance/schema.pyi new file mode 100644 index 00000000000..6bbb54a4b4d --- /dev/null +++ b/python/python/lance/lance/schema.pyi @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from typing import Any, Dict, List + +import pyarrow as pa + +class LanceField: + def name(self) -> str: ... + def id(self) -> int: ... + def children(self) -> List[LanceField]: ... + +class LanceSchema: + def fields(self) -> List[LanceField]: ... + def to_pyarrow(self) -> pa.Schema: ... + @staticmethod + def from_pyarrow(schema: pa.Schema) -> "LanceSchema": ... + +def schema_to_json(schema: pa.Schema) -> Dict[str, Any]: ... +def json_to_schema(schema_json: Dict[str, Any]) -> pa.Schema: ... diff --git a/python/python/lance/lance/trace.pyi b/python/python/lance/lance/trace.pyi new file mode 100644 index 00000000000..15b6cb260af --- /dev/null +++ b/python/python/lance/lance/trace.pyi @@ -0,0 +1,3 @@ +from typing import Optional + +def trace_to_chrome(file: Optional[str] = None, level: Optional[str] = None): ... diff --git a/python/python/lance/log.py b/python/python/lance/log.py new file mode 100644 index 00000000000..b1f0a1c591f --- /dev/null +++ b/python/python/lance/log.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import logging +import os +from typing import Optional + +ENV_NAME_PYLANCE_LOGGING_LEVEL = "LANCE_LOG" + + +# Rust has 'trace' and Python does not so we map it to 'debug' +def get_python_log_level(rust_log_level: str) -> str: + if rust_log_level.lower() == "trace": + return "DEBUG" + return rust_log_level + + +def get_log_level(): + lance_log_level = os.environ.get(ENV_NAME_PYLANCE_LOGGING_LEVEL, "INFO").upper() + if lance_log_level == "": + return "INFO" + + lance_log_level = [ + entry for entry in lance_log_level.split(",") if "=" not in entry + ] + if len(lance_log_level) > 0: + return get_python_log_level(lance_log_level[0]) + else: + return "INFO" + + +LOGGER = logging.getLogger("pylance") +LOGGER.setLevel(get_log_level()) + + +def set_logger( + file_path: Optional[str] = "pylance.log", + name="pylance", + level=logging.INFO, + format_string=None, + log_handler=None, +): + global LOGGER + if not format_string: + format_string = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s" # noqa E501 + LOGGER = logging.getLogger(name) + LOGGER.setLevel(level) + lh = log_handler + if lh is None: + lh = logging.FileHandler(file_path) + lh.setLevel(level) + formatter = logging.Formatter(format_string) + lh.setFormatter(formatter) + LOGGER.addHandler(lh) diff --git a/python/python/lance/query.py b/python/python/lance/query.py new file mode 100644 index 00000000000..21aa6f3ccf0 --- /dev/null +++ b/python/python/lance/query.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + + +import abc +from enum import Enum +from typing import Optional + +from .lance import PyFullTextQuery + + +class FullTextQueryType(Enum): + MATCH = "match" + MATCH_PHRASE = "match_phrase" + BOOST = "boost" + MULTI_MATCH = "multi_match" + + +class FullTextOperator(Enum): + AND = "AND" + OR = "OR" + + +class FullTextQuery(abc.ABC): + _inner: PyFullTextQuery + + @property + def inner(self) -> PyFullTextQuery: + """ + Get the inner query object. + + Returns + ------- + PyFullTextQuery + The inner query object. + """ + return self._inner + + @abc.abstractmethod + def query_type(self) -> FullTextQueryType: + """ + Get the query type of the query. + + Returns + ------- + str + The type of the query. + """ + + +class MatchQuery(FullTextQuery): + def __init__( + self, + query: str, + column: str, + *, + boost: float = 1.0, + fuzziness: int = 0, + max_expansions: int = 50, + operator: FullTextOperator = FullTextOperator.OR, + ): + """ + Match query for full-text search. + + Parameters + ---------- + query : str + The query string to match against. + column : str + The name of the column to match against. + boost : float, default 1.0 + The boost factor for the query. + The score of each matching document is multiplied by this value. + fuzziness : int, optional + The maximum edit distance for each term in the match query. + Defaults to 0 (exact match). + If None, fuzziness is applied automatically by the rules: + - 0 for terms with length <= 2 + - 1 for terms with length <= 5 + - 2 for terms with length > 5 + max_expansions : int, optional + The maximum number of terms to consider for fuzzy matching. + Defaults to 50. + """ + self._inner = PyFullTextQuery.match_query( + query, + column, + boost=boost, + fuzziness=fuzziness, + max_expansions=max_expansions, + operator=operator.value, + ) + + def query_type(self) -> FullTextQueryType: + return FullTextQueryType.MATCH + + +class PhraseQuery(FullTextQuery): + def __init__(self, query: str, column: str): + """ + Phrase query for full-text search. + + Parameters + ---------- + query : str + The query string to match against. + column : str + The name of the column to match against. + """ + self._inner = PyFullTextQuery.phrase_query(query, column) + + def query_type(self) -> FullTextQueryType: + return FullTextQueryType.MATCH_PHRASE + + +class BoostQuery(FullTextQuery): + def __init__( + self, + positive: FullTextQuery, + negative: FullTextQuery, + *, + negative_boost: float = 0.5, + ): + """ + Boost query for full-text search. + + Parameters + ---------- + positive : dict + The positive query object. + negative : dict + The negative query object. + negative_boost : float, default 0.5 + The boost factor for the negative query. + """ + self._inner = PyFullTextQuery.boost_query( + positive.inner, negative.inner, negative_boost + ) + + def query_type(self) -> FullTextQueryType: + return FullTextQueryType.BOOST + + +class MultiMatchQuery(FullTextQuery): + def __init__( + self, + query: str, + columns: list[str], + *, + boosts: Optional[list[float]] = None, + operator: FullTextOperator = FullTextOperator.OR, + ): + """ + Multi-match query for full-text search. + + Parameters + ---------- + query : str | list[Query] + If a string, the query string to match against. + + columns : list[str] + The list of columns to match against. + + boosts : list[float], optional + The list of boost factors for each column. If not provided, + all columns will have the same boost factor. + operator : FullTextOperator, default OR + The operator to use for combining the query results. + Can be either `AND` or `OR`. + It would be applied to all columns individually. + For example, if the operator is `AND`, + then the query "hello world" is equal to + `match("hello AND world", column1) OR match("hello AND world", column2)`. + """ + self._inner = PyFullTextQuery.multi_match_query( + query, columns, boosts=boosts, operator=operator.value + ) + + def query_type(self) -> FullTextQueryType: + return FullTextQueryType.MULTI_MATCH diff --git a/python/python/lance/ray/distribute_task.py b/python/python/lance/ray/distribute_task.py new file mode 100644 index 00000000000..3296c10907d --- /dev/null +++ b/python/python/lance/ray/distribute_task.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import lance + +# ============================================================================== +# Message Structure Constants +# ============================================================================== +TASK_ID_KEY = "task_id" +PARTITION_KEY = "partition" + +# ============================================================================== +# Data Component Keys +# ============================================================================== +FRAGMENT_KEY = "fragment" +SCHEMA_KEY = "schema" + +# ============================================================================== +# Operation Parameters +# ============================================================================== +PARAMS_KEY = "params" +ACTION_KEY = "action" +READ_COLUMNS_KEY = "read_columns" + +# ============================================================================== +# Execution Metadata +# ============================================================================== +OPERATION_TYPE_KEY = "operation_type" +VERSION_KEY = "version" + + +@dataclass +class TaskInput: + """Container for task execution parameters and metadata.""" + + task_id: str + fn: Callable + fragment: Any + params: Dict[str, Any] = field(default_factory=dict) + + +class FragmentTask: + """Base class for distributed data processing tasks.""" + + def __init__(self, task_input: TaskInput): + self.task_input = task_input + + def __call__(self) -> Dict[str, Any]: + output = self._fn() + return { + TASK_ID_KEY: self.task_input.task_id, + PARTITION_KEY: {FRAGMENT_KEY: self.task_input.fragment, "output": output}, + } + + +class AddColumnTask(FragmentTask): + """Task for adding new columns to dataset fragments.""" + + def __init__(self, task_input: TaskInput, read_columns): + super().__init__(task_input) + self._read_columns = read_columns + self._validate_input_params() + + def _validate_input_params(self) -> None: + """Ensure required parameters are present and valid.""" + if self.task_input.fragment is None: + raise ValueError("Fragment must be provided for column addition") + + def __call__(self) -> Dict[str, Any]: + """Execute column addition and return updated fragment metadata.""" + new_fragment, new_schema = self.task_input.fragment.merge_columns( + value_func=self.task_input.fn, columns=self._read_columns + ) + return { + TASK_ID_KEY: self.task_input.task_id, + PARTITION_KEY: {FRAGMENT_KEY: new_fragment, SCHEMA_KEY: new_schema}, + } + + +class DispatchFragmentTasks: + """Orchestrates distributed execution of fragment operations.""" + + def __init__(self, dataset: lance.LanceDataset): + self.dataset = dataset + + def get_tasks( + self, transform_fn: Callable, operation_params: Optional[Dict[str, Any]] = None + ) -> List[FragmentTask]: + """Generate tasks for processing all dataset fragments.""" + operation_params = operation_params or {} + return [ + self._create_task(fragment, transform_fn, operation_params) + for fragment in self.dataset.get_fragments() + ] + + def _create_task( + self, fragment: Any, transform_fn: Callable, params: Dict[str, Any] + ) -> FragmentTask: + """Factory method for creating appropriate task type.""" + task_input = TaskInput( + task_id=fragment.fragment_id, + fn=transform_fn, + fragment=fragment, + params=params, + ) + + if params[ACTION_KEY] == "add_column": + return AddColumnTask(task_input, params[READ_COLUMNS_KEY]) + + raise ValueError(f"Unsupported operation: {params[ACTION_KEY]}") + + def commit_results(self, partitions: List[Dict[str, Any]]) -> bool: + """Commit processed results to the dataset.""" + if not partitions: + return False + + fragments = [part[FRAGMENT_KEY] for part in partitions] + unified_schema = partitions[0][SCHEMA_KEY] + + operation = lance.LanceOperation.Merge(fragments, unified_schema) + self.dataset.commit( + base_uri=self.dataset.uri, + operation=operation, + read_version=self.dataset.version, + ) + return True diff --git a/python/python/lance/ray/fragment_api.py b/python/python/lance/ray/fragment_api.py new file mode 100644 index 00000000000..3eb5dc02b2b --- /dev/null +++ b/python/python/lance/ray/fragment_api.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from typing import Any, Callable, Dict, List, Union + +import pyarrow as pa +from ray.data import from_items + +import lance + +from .distribute_task import PARTITION_KEY, DispatchFragmentTasks + +# ============================================================================== +# Component Keys +# ============================================================================== +ITEM_KEY = "item" +READ_COLUMNS_KEY = "read_columns" +ACTION_KEY = "action" +ADD_COLUMN_ACTION = "add_column" + +# ============================================================================== +# Type Aliases +# ============================================================================== +RecordBatchTransformer = Callable[[pa.RecordBatch], pa.RecordBatch] + + +def execute_fragment_operation( + task_dispatcher: "DispatchFragmentTasks", + value_function: Union[Dict[str, str], RecordBatchTransformer], + operation_parameters: Dict[str, Any] = None, +) -> None: + """ + Execute distributed fragment operations and commit results. + + Args: + task_dispatcher: Coordinator for fragment tasks + value_function: Data transformation logic + operation_parameters: Contextual parameters for the operation + """ + operation_parameters = operation_parameters or {} + + # Generate and execute distributed tasks + processing_tasks = task_dispatcher.get_tasks(value_function, operation_parameters) + task_dataset = from_items(processing_tasks).map(lambda task: task[ITEM_KEY]()) + + # Collect and commit results + results = [item[PARTITION_KEY] for item in task_dataset.take_all()] + task_dispatcher.commit_results(results) + + +def add_columns( + dataset: lance.LanceDataset, + column_generator: RecordBatchTransformer, + source_columns: List[str], +) -> None: + """ + Add new columns to a Lance dataset through distributed processing. + + Args: + dataset: Target dataset for column addition + column_generator: Function generating new column values + source_columns: Existing columns required for generation + """ + dispatcher = DispatchFragmentTasks(dataset) + execute_fragment_operation( + dispatcher, + value_function=column_generator, + operation_parameters={ + READ_COLUMNS_KEY: source_columns, + ACTION_KEY: ADD_COLUMN_ACTION, + }, + ) diff --git a/python/python/lance/ray/sink.py b/python/python/lance/ray/sink.py index 7c3d8f9c4a3..cfcde00f462 100644 --- a/python/python/lance/ray/sink.py +++ b/python/python/lance/ray/sink.py @@ -29,6 +29,8 @@ __all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"] +NONE_ARROW_STR = "None" + def _pd_to_arrow( df: Union[pa.Table, "pd.DataFrame", Dict], schema: Optional[pa.Schema] @@ -39,16 +41,18 @@ def _pd_to_arrow( if isinstance(df, dict): return pa.Table.from_pydict(df, schema=schema) - if _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame): + elif _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame): tbl = pa.Table.from_pandas(df, schema=schema) - new_schema = tbl.schema.remove_metadata() - new_table = tbl.replace_schema_metadata(new_schema.metadata) - return new_table + tbl.schema = tbl.schema.remove_metadata() + return tbl + elif isinstance(df, pa.Table): + if schema is not None: + return df.cast(schema) return df def _write_fragment( - stream: Iterable[Union[pa.Table, "pd.Pandas"]], + stream: Iterable[Union[pa.Table, "pd.DataFrame"]], uri: str, *, schema: Optional[pa.Schema] = None, @@ -57,7 +61,7 @@ def _write_fragment( max_rows_per_group: int = 1024, # Only useful for v1 writer. data_storage_version: Optional[str] = None, storage_options: Optional[Dict[str, Any]] = None, -) -> Tuple[FragmentMetadata, pa.Schema]: +) -> List[Tuple[FragmentMetadata, pa.Schema]]: from ..dependencies import _PANDAS_AVAILABLE from ..dependencies import pandas as pd @@ -131,6 +135,37 @@ def on_write_complete( self, write_results: List[List[Tuple[str, str]]], ): + import warnings + + if not write_results: + warnings.warn( + "write_results is empty.", + DeprecationWarning, + ) + return + if ( + not isinstance(write_results, list) + or not isinstance(write_results[0], list) + ) and not hasattr(write_results, "write_returns"): + warnings.warn( + "write_results type is wrong. please check version, " + "upgrade or downgrade your ray version. ray versions >= 2.38 " + "and < 2.41 are unable to write Lance datasets, check ray PR " + "https://github.com/ray-project/ray/pull/49251 in your " + "ray version. ", + DeprecationWarning, + ) + return + if hasattr(write_results, "write_returns"): + write_results = write_results.write_returns + + if len(write_results) == 0: + warnings.warn( + "write results is empty. please check ray version or internal error", + DeprecationWarning, + ) + return + fragments = [] schema = None for batch in write_results: @@ -389,6 +424,7 @@ def write_lance( output_uri: str, *, schema: Optional[pa.Schema] = None, + mode: Literal["create", "append", "overwrite"] = "create", transform: Optional[ Callable[[pa.Table], Union[pa.Table, Generator[None, pa.Table, None]]] ] = None, @@ -435,7 +471,9 @@ def write_lance( ), batch_size=max_rows_per_file, ).write_datasink( - LanceCommitter(output_uri, schema=schema, storage_options=storage_options) + LanceCommitter( + output_uri, schema=schema, mode=mode, storage_options=storage_options + ) ) diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py index f283a7daa46..820ccbacbf6 100644 --- a/python/python/lance/sampler.py +++ b/python/python/lance/sampler.py @@ -5,7 +5,6 @@ from __future__ import annotations import gc -import logging import math import random import warnings @@ -13,13 +12,23 @@ from dataclasses import dataclass, field from heapq import heappush, heappushpop from pathlib import Path -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Optional, + TypeVar, + Union, +) import pyarrow as pa import pyarrow.compute as pc import lance from lance.dependencies import numpy as np +from lance.log import LOGGER if TYPE_CHECKING: from collections.abc import Generator @@ -94,7 +103,7 @@ def _efficient_sample( ).to_batches() ) if idx % 50 == 0: - logging.info("Sampled at offset=%s, len=%s", offset, chunk_sample_size) + LOGGER.info("Sampled at offset=%s, len=%s", offset, chunk_sample_size) if sum(len(b) for b in buf) >= batch_size: tbl = pa.Table.from_batches(buf) buf.clear() @@ -110,7 +119,7 @@ def _efficient_sample( def _filtered_efficient_sample( dataset: lance.LanceDataset, n: int, - columns: Optional[Union[List[str], Dict[str, str]]], + columns: List[str], batch_size: int, target_takes: int, filter: str, @@ -162,7 +171,7 @@ def _filtered_efficient_sample( def maybe_sample( dataset: Union[str, Path, lance.LanceDataset], n: int, - columns: Union[list[str], dict[str, str], str], + columns: Union[list[str], str], batch_size: int = 10240, max_takes: int = 2048, filt: Optional[str] = None, @@ -225,7 +234,7 @@ def maybe_sample( @dataclass(order=True) -class PrioritizedItem: +class PrioritizedItem(Generic[T]): priority: int item: T = field(compare=False) @@ -241,7 +250,7 @@ def reservoir_sampling(stream: Iterable[T], k: int) -> list[T]: vic = heappushpop(heap, entry) del vic if idx % 10240 == 0: - logging.info("Force Python GC") + LOGGER.info("Force Python GC") gc.collect() samples = [i.item for i in heap] del heap @@ -314,7 +323,8 @@ class FullScanSampler(FragmentSampler): def iter_fragments( self, dataset: lance.LanceDataset, **kwargs ) -> Generator[lance.LanceFragment, None, None]: - return dataset.get_fragments() + for fragment in dataset.get_fragments(): + yield fragment class ShardedFragmentSampler(FragmentSampler): @@ -420,7 +430,7 @@ def _shard_scan( columns: Optional[Union[List[str], Dict[str, str]]], batch_readahead: int, filter: str, - ) -> Generator[lance.RecordBatch, None, None]: + ) -> Generator[pa.RecordBatch, None, None]: accumulated_batches = [] rows_accumulated = 0 rows_to_skip = self._rank @@ -471,7 +481,7 @@ def _sample_filtered( columns: Optional[Union[List[str], Dict[str, str]]], batch_readahead: int, filter: str, - ) -> Generator[lance.RecordBatch, None, None]: + ) -> Generator[pa.RecordBatch, None, None]: shard_scan = self._shard_scan( dataset, batch_size, columns, batch_readahead, filter ) @@ -508,9 +518,9 @@ def _sample_all( self, dataset: lance.LanceDataset, batch_size: int, - columns: Optional[Union[List[str], Dict[str, str]]], + columns: Optional[List[str]], batch_readahead: int, - ) -> Generator[lance.RecordBatch, None, None]: + ) -> Generator[pa.RecordBatch, None, None]: total = dataset.count_rows() def _gen_ranges(): @@ -537,12 +547,12 @@ def __call__( dataset: lance.LanceDataset, *args, batch_size: int = 128, - columns: Optional[Union[List[str], Dict[str, str]]] = None, + columns: Optional[List[str]] = None, filter: Optional[str] = None, batch_readahead: int = 16, with_row_id: Optional[bool] = None, **kwargs, - ) -> Generator[lance.RecordBatch, None, None]: + ) -> Generator[pa.RecordBatch, None, None]: if filter is None: if with_row_id is not None: warnings.warn( diff --git a/python/python/lance/schema.pyi b/python/python/lance/schema.pyi deleted file mode 100644 index 256a066dddc..00000000000 --- a/python/python/lance/schema.pyi +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright The Lance Authors - -import pyarrow as pa - -class LanceSchema: - def to_pyarrow(self) -> pa.Schema: ... - def from_pyarrow(schema: pa.Schema) -> "LanceSchema": ... diff --git a/python/python/lance/tf/data.py b/python/python/lance/tf/data.py index 861d18c0cec..6efe2d3c837 100644 --- a/python/python/lance/tf/data.py +++ b/python/python/lance/tf/data.py @@ -12,7 +12,6 @@ from __future__ import annotations -import logging from functools import partial from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -25,6 +24,7 @@ from lance.dependencies import numpy as np from lance.dependencies import tensorflow as tf from lance.fragment import FragmentMetadata, LanceFragment +from lance.log import LOGGER if TYPE_CHECKING: from pathlib import Path @@ -215,7 +215,7 @@ def gen_fragments(fragments): ): yield LanceFragment(dataset, int(f)) elif isinstance(f, FragmentMetadata): - yield LanceFragment(dataset, f.fragment_id) + yield LanceFragment(dataset, f.id) elif isinstance(f, LanceFragment): yield f else: @@ -231,7 +231,7 @@ def gen_fragments(fragments): if output_signature is None: schema = scanner.projected_schema output_signature = schema_to_spec(schema) - logging.debug("Output signature: %s", output_signature) + LOGGER.debug("Output signature: %s", output_signature) def generator(): for batch in scanner.to_batches(): @@ -315,7 +315,7 @@ def lance_take_batches( dataset: Union[str, Path, LanceDataset], batch_ranges: Iterable[Tuple[int, int]], *, - columns: Optional[Union[List[str], Dict[str, str]]] = None, + columns: Optional[List[str]] = None, output_signature: Optional[Dict[str, tf.TypeSpec]] = None, batch_readahead: int = 10, ) -> tf.data.Dataset: @@ -356,7 +356,7 @@ def lance_take_batches( if output_signature is None: schema = dataset.scanner(columns=columns).projected_schema output_signature = schema_to_spec(schema) - logging.debug("Output signature: %s", output_signature) + LOGGER.debug("Output signature: %s", output_signature) def gen_ranges(): for start, end in batch_ranges: diff --git a/python/python/lance/torch/async_dataset.py b/python/python/lance/torch/async_dataset.py index 2e925817ab8..37081c0a4f0 100644 --- a/python/python/lance/torch/async_dataset.py +++ b/python/python/lance/torch/async_dataset.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors import contextlib -import logging from multiprocessing import Process, Queue, Value from typing import Callable, Iterable from torch.utils.data import IterableDataset +from lance.log import LOGGER + def _worker_ep( dataset_creator: Callable[[], IterableDataset], @@ -69,7 +70,7 @@ def close(self): for _ in self: pass except Exception as e: - logging.exception(e) + LOGGER.exception(e) pass self.queue.close() self.worker.join() diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py index 05b2b4d737f..744f617b904 100644 --- a/python/python/lance/torch/data.py +++ b/python/python/lance/torch/data.py @@ -7,10 +7,11 @@ from __future__ import annotations import json +import logging import math import warnings from pathlib import Path -from typing import Dict, Iterable, List, Literal, Optional, Union +from typing import Callable, Dict, Iterable, List, Literal, Optional, Union import pyarrow as pa @@ -28,7 +29,7 @@ ) from .dist import get_global_rank, get_global_world_size -__all__ = ["LanceDataset"] +__all__ = ["LanceDataset", "SafeLanceDataset", "get_safe_loader"] # Convert an Arrow FSL array into a 2D torch tensor @@ -40,22 +41,37 @@ def _fsl_to_tensor(arr: pa.FixedSizeListArray, dimension: int) -> torch.Tensor: num_vals = len(arr) * dimension values = values.slice(start, num_vals) # Convert to numpy - nparr = values.to_numpy(zero_copy_only=True).reshape(-1, dimension) + nparr = values.to_numpy(zero_copy_only=False).reshape(-1, dimension) return torch.from_numpy(nparr) def _to_tensor( - batch: pa.RecordBatch, + batch: Union[pa.RecordBatch, Dict[str, pa.Array]], *, uint64_as_int64: bool = True, hf_converter: Optional[dict] = None, + use_blob_api: bool = False, + **kwargs, ) -> Union[dict[str, torch.Tensor], torch.Tensor]: """Convert a pyarrow RecordBatch to torch Tensor.""" ret = {} - for col in batch.schema.names: + cols = ( + batch.column_names if isinstance(batch, pa.RecordBatch) else list(batch.keys()) + ) + for col in cols: arr: pa.Array = batch[col] + if ( + use_blob_api + and isinstance(arr, list) + and arr + and isinstance(arr[0], lance.BlobFile) + ): + raise NotImplementedError( + 'Need user-provided "to_tensor_fn" for Blob files' + ) + tensor: torch.Tensor = None if (isinstance(arr.type, pa.FixedShapeTensorType)) and ( pa.types.is_floating(arr.type.value_type) @@ -73,7 +89,7 @@ def _to_tensor( or pa.types.is_floating(arr.type) or pa.types.is_boolean(arr.type) ): - tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=True)) + tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=False)) if uint64_as_int64 and tensor.dtype == torch.uint64: tensor = tensor.to(torch.int64) @@ -176,8 +192,8 @@ def __init__( shard_granularity: Optional[Literal["fragment", "batch"]] = None, batch_readahead: int = 16, to_tensor_fn: Optional[ - callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]] - ] = None, + Callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]] + ] = _to_tensor, sampler: Optional[Sampler] = None, **kwargs, ): @@ -219,7 +235,7 @@ def __init__( to_tensor_fn : callable, optional A function that converts a pyarrow RecordBatch to torch.Tensor. """ - super().__init__(*args, **kwargs) + super().__init__() if isinstance(dataset, (str, Path)): dataset = lance.dataset(dataset) self.dataset = dataset @@ -229,11 +245,13 @@ def __init__( self.filter = filter self.with_row_id = with_row_id self.batch_readahead = batch_readahead - if to_tensor_fn is None: - to_tensor_fn = _to_tensor self._to_tensor_fn = to_tensor_fn self._hf_converter = None + self._blob_columns = self._blob_columns() + if self._blob_columns: + self.with_row_id = True + # As Shared Dataset self.shard_granularity = shard_granularity self.rank = rank @@ -258,6 +276,13 @@ def __init__( def __repr__(self) -> str: return f"LanceTorchDataset({self.dataset.uri}, size={self.samples})" + @property + def schema(self) -> pa.Schema: + if not self.columns: + return self.dataset.schema + fields = [self.dataset.schema.field(col) for col in self.columns] + return pa.schema(fields, metadata=self.dataset.schema.metadata) + def __iter__(self): if self.sampler is None: if self.rank is not None and self.world_size is not None: @@ -280,6 +305,12 @@ def __iter__(self): else: sampler = self.sampler + projected_columns = self.columns or self.dataset.schema.names + if self._blob_columns: + projected_columns = [ + c for c in projected_columns if c not in self._blob_columns + ] + stream: Iterable[pa.RecordBatch] if self.cached_ds: stream = self.cached_ds @@ -288,14 +319,14 @@ def __iter__(self): raw_stream = maybe_sample( self.dataset, n=self.samples, - columns=self.columns, + columns=projected_columns, batch_size=self.batch_size, filt=self.filter, ) else: raw_stream = sampler( self.dataset, - columns=self.columns, + columns=projected_columns, filter=self.filter, batch_size=self.batch_size, with_row_id=self.with_row_id, @@ -308,8 +339,112 @@ def __iter__(self): self.cached_ds = CachedDataset(stream, cache=self.cache) stream = self.cached_ds + use_blob_api = bool(self._blob_columns) for batch in stream: + if use_blob_api: + dict_batch = {} + assert "_rowid" in batch.column_names + row_ids = batch["_rowid"] + for col in batch.column_names: + dict_batch[col] = batch[col] + for col in self._blob_columns: + dict_batch[col] = self.dataset.take_blobs( + row_ids=row_ids.to_pylist(), blob_column=col + ) + batch = dict_batch if self._to_tensor_fn is not None: - batch = self._to_tensor_fn(batch, hf_converter=self._hf_converter) + batch = self._to_tensor_fn( + batch, hf_converter=self._hf_converter, use_blob_api=use_blob_api + ) yield batch del batch + + def _blob_columns(self) -> List[str]: + """Returns True if one of the projected column is Large Blob encoded.""" + cols = self.columns + if not cols: + cols = self.dataset.schema.names + blob_cols = [] + for col in cols: + field = self.dataset.schema.field(col) + if ( + field.type == pa.large_binary() + and field.metadata is not None + and field.metadata.get(b"lance-encoding:blob") == b"true" + ): + logging.debug("Column %s is a Large Blob column", col) + blob_cols.append(col) + return blob_cols + + +class SafeLanceDataset(torch.utils.data.Dataset): + def __init__(self, uri): + self.uri = uri + self._len = self._safe_preload() + self._ds = None # Deferred initialization + + def _safe_preload(self): + """Main-process safe metadata loading""" + ds = lance.dataset(self.uri) + length = ds.count_rows() + del ds # Critical: release before spawning + return length + + def __len__(self): + return self._len + + def __getitem__(self, idx): + return self.get_items([idx])[0] + + def get_items(self, indices): + """Batch data fetching with worker-safe initialization + + Args: + indices: List[int] - batch indices to retrieve + + Returns: + List[dict] - samples in original data format + """ + if self._ds is None: + # Worker-process initialization + import os + + self._ds = lance.dataset(self.uri) + print(f"Worker {os.getpid()} initialized dataset") + + # Leverage native batch reading + batch = self._ds.take(indices) + + # Convert to python-native format + return batch.to_pylist() + + +def get_safe_loader(dataset, batch_size=32, num_workers=4, **kwargs): + """Create a DataLoader with safe multiprocessing defaults + + Args: + dataset: Input dataset object + batch_size: Number of samples per batch (default=32) + num_workers: Number of parallel data workers (default=4) + **kwargs: Additional DataLoader arguments. Note: + - Forces 'spawn' context for Windows compatibility + - Sets persistent_workers=True by default + - User-provided args override defaults + + Returns: + Configured DataLoader instance with process-safe settings + """ + + # Force spawn context for Windows/multiprocessing compatibility + ctx = torch.multiprocessing.get_context("spawn") + + # Configure default parameters with process safety + loader_args = { + "batch_size": batch_size, + "num_workers": num_workers, + "persistent_workers": kwargs.pop("persistent_workers", True), + "multiprocessing_context": ctx, + **kwargs, # User-provided arguments take priority + } + + return torch.utils.data.DataLoader(dataset, **loader_args) diff --git a/python/python/lance/torch/distance.py b/python/python/lance/torch/distance.py index c31d637ed03..06388210544 100644 --- a/python/python/lance/torch/distance.py +++ b/python/python/lance/torch/distance.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors -import logging from typing import Optional, Tuple from lance.dependencies import torch +from lance.log import LOGGER __all__ = [ "pairwise_cosine", @@ -225,7 +225,7 @@ def l2_distance( return _l2_distance(vectors, centroids, split_size=split, y2=y2) except RuntimeError as e: # noqa: PERF203 if "CUDA out of memory" in str(e): - logging.warning( + LOGGER.warning( "L2: batch split=%s out of memory, attempt to use reduced split %s", split, split // 2, diff --git a/python/python/lance/torch/kmeans.py b/python/python/lance/torch/kmeans.py index 1881452c110..44fca9ae60a 100644 --- a/python/python/lance/torch/kmeans.py +++ b/python/python/lance/torch/kmeans.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -import logging import time from typing import List, Literal, Optional, Tuple, Union @@ -14,6 +13,8 @@ torch, ) from lance.dependencies import numpy as np +from lance.log import LOGGER +from lance.util import MetricType, _normalize_metric_type from . import preferred_device from .data import TensorDataset @@ -53,7 +54,7 @@ def __init__( self, k: int, *, - metric: Literal["l2", "euclidean", "cosine", "dot"] = "l2", + metric: MetricType = "l2", init: Literal["random"] = "random", max_iters: int = 50, tolerance: float = 1e-4, @@ -64,9 +65,8 @@ def __init__( self.k = k self.max_iters = max_iters - metric = metric.lower() - self.metric = metric - if metric in ["l2", "euclidean", "cosine"]: + self.metric = _normalize_metric_type(metric) + if metric in ["l2", "cosine"]: # Cosine uses normalized unit vector and calculate l2 distance self.dist_func = l2_distance elif metric == "dot": @@ -92,7 +92,7 @@ def _to_tensor( self, data: Union[pa.FixedSizeListArray, np.ndarray, torch.Tensor] ) -> torch.Tensor: if isinstance(data, pa.FixedSizeListArray): - np_tensor = data.values.to_numpy(zero_copy_only=True).reshape( + np_tensor = data.values.to_numpy(zero_copy_only=False).reshape( -1, data.type.list_size ) data = torch.from_numpy(np_tensor) @@ -113,7 +113,7 @@ def _to_tensor( def _random_init(self, data: Union[torch.Tensor, np.ndarray]): """Random centroid initialization.""" if self.centroids is not None: - logging.debug("KMeans centroids already initialized") + LOGGER.debug("KMeans centroids already initialized") return is_numpy = _check_for_numpy(data) and isinstance(data, np.ndarray) @@ -154,7 +154,7 @@ def fit( assert self.centroids is not None self.centroids = self.centroids.to(self.device) - logging.info( + LOGGER.info( "Start kmean training, metric: %s, iters: %s", self.metric, self.max_iters ) self.total_distance = 0 @@ -166,8 +166,8 @@ def fit( except StopIteration: break if i % 10 == 0: - logging.debug("Total distance: %s, iter: %s", self.total_distance, i) - logging.info("Finish KMean training in %s", time.time() - start) + LOGGER.debug("Total distance: %s, iter: %s", self.total_distance, i) + LOGGER.info("Finish KMean training in %s", time.time() - start) def _updated_centroids( self, centroids: torch.Tensor, counts: torch.Tensor @@ -234,7 +234,7 @@ def _fit_once( self.rebuild_index() for idx, chunk in enumerate(data): if idx % 50 == 0: - logging.info("Kmeans::train: epoch %s, chunk %s", epoch, idx) + LOGGER.info("Kmeans::train: epoch %s, chunk %s", epoch, idx) if column is not None: chunk = chunk[column] chunk: torch.Tensor = chunk @@ -264,7 +264,7 @@ def _fit_once( # vectors repeated over and over. Performance may be bad but we don't # want to crash. if total_dist == 0: - logging.warning( + LOGGER.warning( "Kmeans::train: total_dist is 0, this is unusual." " This could result in bad performance during search." ) diff --git a/python/python/lance/tracing.py b/python/python/lance/tracing.py index be35185b8e5..2605aabcf37 100644 --- a/python/python/lance/tracing.py +++ b/python/python/lance/tracing.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors import atexit +from typing import Optional from .lance import trace_to_chrome as lance_trace_to_chrome -def trace_to_chrome(*, file: str = None, level: str = None): +def trace_to_chrome(*, file: Optional[str] = None, level: Optional[str] = None): """ Begins tracing lance events to a chrome trace file. diff --git a/python/python/lance/types.py b/python/python/lance/types.py index b0559c5ff15..58a935c3222 100644 --- a/python/python/lance/types.py +++ b/python/python/lance/types.py @@ -18,6 +18,7 @@ pa.Table, pa.dataset.Dataset, pa.dataset.Scanner, + pa.RecordBatch, Iterable[RecordBatch], pa.RecordBatchReader, ] @@ -73,6 +74,17 @@ def _coerce_reader( and data_obj.__class__.__name__ == "DataFrame" ): return data_obj.to_arrow().to_reader() + elif isinstance(data_obj, dict): + batch = pa.RecordBatch.from_pydict(data_obj, schema=schema) + return pa.RecordBatchReader.from_batches(batch.schema, [batch]) + elif ( + isinstance(data_obj, list) + and len(data_obj) > 0 + and isinstance(data_obj[0], dict) + ): + # List of dictionaries + batch = pa.RecordBatch.from_pylist(data_obj, schema=schema) + return pa.RecordBatchReader.from_batches(batch.schema, [batch]) # for other iterables, assume they are of type Iterable[RecordBatch] elif isinstance(data_obj, Iterable): if schema is not None: diff --git a/python/python/lance/util.py b/python/python/lance/util.py index 1ddc6ffdcd9..62da80fa202 100644 --- a/python/python/lance/util.py +++ b/python/python/lance/util.py @@ -4,7 +4,7 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Iterator, Literal, Optional, Union +from typing import TYPE_CHECKING, Iterator, Literal, Optional, Union, cast import pyarrow as pa @@ -16,14 +16,16 @@ if TYPE_CHECKING: ts_types = Union[datetime, pd.Timestamp, str] -try: - from pyarrow import FixedShapeTensorType +MetricType = Literal["l2", "euclidean", "dot", "cosine"] - CENTROIDS_TYPE = FixedShapeTensorType - has_fixed_shape_tensor = True -except ImportError: - has_fixed_shape_tensor = False - CENTROIDS_TYPE = pa.FixedSizeListType + +def _normalize_metric_type(metric_type: str) -> MetricType: + normalized = metric_type.lower() + if normalized == "euclidean": + normalized = "l2" + if normalized not in {"l2", "dot", "cosine"}: + raise ValueError(f"Invalid metric_type: {metric_type}") + return cast("MetricType", normalized) def sanitize_ts(ts: ts_types) -> datetime: @@ -76,7 +78,7 @@ class KMeans: def __init__( self, k: int, - metric_type: Literal["l2", "dot", "cosine"] = "l2", + metric_type: MetricType = "l2", max_iters: int = 50, centroids: Optional[pa.FixedSizeListArray] = None, ): @@ -93,11 +95,7 @@ def __init__( The maximum number of iterations to run the KMeans algorithm. Default: 50. centroids (pyarrow.FixedSizeListArray, optional.) – Provide existing centroids. """ - metric_type = metric_type.lower() - if metric_type not in ["l2", "dot", "cosine"]: - raise ValueError( - f"metric_type must be one of 'l2', 'dot', 'cosine', got: {metric_type}" - ) + metric_type = _normalize_metric_type(metric_type) self.k = k self._metric_type = metric_type self._kmeans = _KMeans( @@ -108,7 +106,7 @@ def __repr__(self) -> str: return f"lance.KMeans(k={self.k}, metric_type={self._metric_type})" @property - def centroids(self) -> Optional[CENTROIDS_TYPE]: + def centroids(self) -> Optional[pa.FixedShapeTensorArray]: """Returns the centroids of the model, Returns None if the model is not trained. @@ -116,11 +114,10 @@ def centroids(self) -> Optional[CENTROIDS_TYPE]: ret = self._kmeans.centroids() if ret is None: return None - if has_fixed_shape_tensor: - # Pyarrow compatibility - shape = (ret.type.list_size,) - tensor_type = pa.fixed_shape_tensor(ret.type.value_type, shape) - ret = pa.FixedShapeTensorArray.from_storage(tensor_type, ret) + + shape = (ret.type.list_size,) + tensor_type = pa.fixed_shape_tensor(ret.type.value_type, shape) + ret = pa.FixedShapeTensorArray.from_storage(tensor_type, ret) return ret def _to_fixed_size_list(self, data: pa.Array) -> pa.FixedSizeListArray: @@ -130,7 +127,7 @@ def _to_fixed_size_list(self, data: pa.Array) -> pa.FixedSizeListArray: f"Array must be float32 type, got: {data.type.value_type}" ) return data - elif has_fixed_shape_tensor and isinstance(data, pa.FixedShapeTensorArray): + elif isinstance(data, pa.FixedShapeTensorArray): if len(data.type.shape) != 1: raise ValueError( f"Fixed shape tensor array must be a 1-D array, " @@ -224,7 +221,7 @@ def validate_vector_index( class HNSW: - _hnsw = None + _hnsw: _Hnsw def __init__(self, hnsw) -> None: self._hnsw = hnsw diff --git a/python/python/lance/vector.py b/python/python/lance/vector.py index 88b46eee6c7..b1a396a54e9 100644 --- a/python/python/lance/vector.py +++ b/python/python/lance/vector.py @@ -5,22 +5,21 @@ from __future__ import annotations -import logging import re import tempfile -from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union import pyarrow as pa from tqdm.auto import tqdm from . import write_dataset from .dependencies import ( - _CAGRA_AVAILABLE, - _RAFT_COMMON_AVAILABLE, _check_for_numpy, torch, ) from .dependencies import numpy as np +from .log import LOGGER +from .util import MetricType, _normalize_metric_type if TYPE_CHECKING: from pathlib import Path @@ -133,20 +132,19 @@ def vec_to_table( def train_pq_codebook_on_accelerator( - dataset: LanceDataset, - metric_type: Literal["l2", "cosine", "dot"], + dataset: LanceDataset | Path | str, + metric_type: MetricType, accelerator: Union[str, "torch.Device"], num_sub_vectors: int, batch_size: int = 1024 * 10 * 4, + dtype: np.dtype = np.float32, ) -> Tuple[np.ndarray, List[Any]]: """Use accelerator (GPU or MPS) to train pq codebook.""" from .torch.data import LanceDataset as TorchDataset from .torch.kmeans import KMeans - # cuvs not particularly useful for only 256 centroids without more work - if accelerator == "cuvs": - accelerator = "cuda" + metric_type = _normalize_metric_type(metric_type) centroids_list = [] kmeans_list = [] @@ -173,7 +171,7 @@ def train_pq_codebook_on_accelerator( ) for sub_vector in range(num_sub_vectors): - logging.info("Training IVF partitions using GPU(%s)", accelerator) + LOGGER.info("Training IVF partitions using GPU(%s)", accelerator) if num_sub_vectors == 1: # sampler has different behaviour with one column init_centroids_slice = init_centroids @@ -195,7 +193,7 @@ def train_pq_codebook_on_accelerator( centroids_list.append(ivf_centroids_local) kmeans_list.append(kmeans_local) - pq_codebook = np.stack(centroids_list) + pq_codebook = np.stack(centroids_list).astype(dtype) return pq_codebook, kmeans_list @@ -203,7 +201,7 @@ def train_ivf_centroids_on_accelerator( dataset: LanceDataset, column: str, k: int, - metric_type: Literal["l2", "cosine", "dot"], + metric_type: MetricType, accelerator: Union[str, "torch.Device"], batch_size: int = 1024 * 10 * 4, *, @@ -213,16 +211,14 @@ def train_ivf_centroids_on_accelerator( ) -> Tuple[np.ndarray, Any]: """Use accelerator (GPU or MPS) to train kmeans.""" - from .cuvs.kmeans import KMeans as KMeansCuVS from .torch.data import LanceDataset as TorchDataset from .torch.kmeans import KMeans + metric_type = _normalize_metric_type(metric_type) + vector_value_type = dataset.schema.field(column).type.value_type + if isinstance(accelerator, str) and ( - not ( - CUDA_REGEX.match(accelerator) - or accelerator == "mps" - or accelerator == "cuvs" - ) + not (CUDA_REGEX.match(accelerator) or accelerator == "mps") ): raise ValueError( "Train ivf centroids on accelerator: " @@ -238,7 +234,7 @@ def train_ivf_centroids_on_accelerator( else: filt = None - logging.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt) + LOGGER.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt) ds = TorchDataset( dataset, @@ -249,7 +245,7 @@ def train_ivf_centroids_on_accelerator( ) init_centroids = next(iter(ds)) - logging.info("Done sampling: centroids shape: %s", init_centroids.shape) + LOGGER.info("Done sampling: centroids shape: %s", init_centroids.shape) ds = TorchDataset( dataset, @@ -260,40 +256,23 @@ def train_ivf_centroids_on_accelerator( cache=True, ) - if accelerator == "cuvs": - logging.info("Training IVF partitions using cuVS+GPU") - print("Training IVF partitions using cuVS+GPU") - if not (_CAGRA_AVAILABLE and _RAFT_COMMON_AVAILABLE): - logging.error( - "Missing cuvs and pylibraft - " - "please install cuvs-cu11 and pylibraft-cu11 or " - "cuvs-cu12 and pylibraft-cu12 using --extra-index-url " - "https://pypi.nvidia.com/" - ) - raise Exception("Missing cuvs or pylibraft dependency.") - kmeans = KMeansCuVS( - k, - max_iters=max_iters, - metric=metric_type, - device="cuda", - centroids=init_centroids, - ) - else: - logging.info("Training IVF partitions using GPU(%s)", accelerator) - kmeans = KMeans( - k, - max_iters=max_iters, - metric=metric_type, - device=accelerator, - centroids=init_centroids, - ) + LOGGER.info("Training IVF partitions using GPU(%s)", accelerator) + kmeans = KMeans( + k, + max_iters=max_iters, + metric=metric_type, + device=accelerator, + centroids=init_centroids, + ) kmeans.fit(ds) - centroids = kmeans.centroids.cpu().numpy() + centroids = ( + kmeans.centroids.cpu().numpy().astype(vector_value_type.to_pandas_dtype()) + ) with tempfile.NamedTemporaryFile(delete=False) as f: np.save(f, centroids) - logging.info("Saved centroids to %s", f.name) + LOGGER.info("Saved centroids to %s", f.name) return centroids, kmeans @@ -304,7 +283,7 @@ def compute_pq_codes( batch_size: int = 1024 * 10 * 4, dst_dataset_uri: Optional[Union[str, Path]] = None, allow_cuda_tf32: bool = True, -) -> str: +) -> Tuple[Union[str, Path], List[str]]: """Compute pq codes for each row using GPU kmeans and spill to disk. Parameters @@ -323,8 +302,8 @@ def compute_pq_codes( Returns ------- - str - The absolute path of the pq codes dataset. + Tuple[Union[str, Path], List[str]] + The absolute path of the pq codes dataset and shuffle buffers """ from .torch.data import LanceDataset as TorchDataset @@ -409,12 +388,10 @@ def _pq_codes_assignment() -> Iterable[pa.RecordBatch]: progress.close() - logging.info("Saved precomputed pq_codes to %s", dst_dataset_uri) + LOGGER.info("Saved precomputed pq_codes to %s", dst_dataset_uri) shuffle_buffers = [ - data_file.path() - for frag in ds.get_fragments() - for data_file in frag.data_files() + data_file.path for frag in ds.get_fragments() for data_file in frag.data_files() ] return dst_dataset_uri, shuffle_buffers @@ -530,7 +507,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]: assert vecs.shape[0] == ids.shape[0] # Ignore any invalid vectors. - mask_gpu = partitions.isfinite() + mask_gpu = partitions.isfinite() & (partitions >= 0) mask = mask_gpu.cpu() ids = ids[mask] partitions = partitions[mask_gpu] @@ -539,7 +516,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]: split_columns = [] if num_sub_vectors is not None: - residual_vecs = vecs - kmeans.centroids[partitions] + residual_vecs = vecs[mask_gpu] - kmeans.centroids[partitions] for i in range(num_sub_vectors): subvector_tensor = residual_vecs[ :, i * subvector_size : (i + 1) * subvector_size @@ -561,7 +538,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]: schema=output_schema, ) if len(part_batch) < len(ids): - logging.warning( + LOGGER.warning( "%s vectors are ignored during partition assignment", len(part_batch) - len(ids), ) @@ -582,7 +559,7 @@ def _partition_assignment() -> Iterable[pa.RecordBatch]: progress.close() - logging.info("Saved precomputed partitions to %s", dst_dataset_uri) + LOGGER.info("Saved precomputed partitions to %s", dst_dataset_uri) return str(dst_dataset_uri) @@ -590,7 +567,7 @@ def one_pass_train_ivf_pq_on_accelerator( dataset: LanceDataset, column: str, k: int, - metric_type: Literal["l2", "cosine", "dot"], + metric_type: MetricType, accelerator: Union[str, "torch.Device"], num_sub_vectors: int, batch_size: int = 1024 * 10 * 4, @@ -599,6 +576,7 @@ def one_pass_train_ivf_pq_on_accelerator( max_iters: int = 50, filter_nan: bool = True, ): + metric_type = _normalize_metric_type(metric_type) centroids, kmeans = train_ivf_centroids_on_accelerator( dataset, column, @@ -629,7 +607,7 @@ def one_pass_train_ivf_pq_on_accelerator( def one_pass_assign_ivf_pq_on_accelerator( dataset: LanceDataset, column: str, - metric_type: Literal["l2", "cosine", "dot"], + metric_type: MetricType, accelerator: Union[str, "torch.Device"], ivf_kmeans: Any, # KMeans pq_kmeans_list: List[Any], # List[KMeans] @@ -707,7 +685,7 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]: assert vecs.shape[0] == ids.shape[0] # Ignore any invalid vectors. - mask_gpu = partitions.isfinite() + mask_gpu = partitions.isfinite() & (partitions >= 0) ids = ids.to(ivf_kmeans.device)[mask_gpu].cpu().reshape(-1) partitions = partitions[mask_gpu].cpu() vecs = vecs[mask_gpu] @@ -716,7 +694,7 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]: # cast centroids to the same dtype as vecs if first_iter: first_iter = False - logging.info("Residual shape: %s", residual_vecs.shape) + LOGGER.info("Residual shape: %s", residual_vecs.shape) for kmeans in pq_kmeans_list: cents: torch.Tensor = kmeans.centroids kmeans.centroids = cents.to( @@ -743,7 +721,7 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]: ) if len(part_batch) < len(ids): - logging.warning( + LOGGER.warning( "%s vectors are ignored during partition assignment", len(part_batch) - len(ids), ) @@ -765,11 +743,9 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]: progress.close() - logging.info("Saved precomputed pq_codes to %s", dst_dataset_uri) + LOGGER.info("Saved precomputed pq_codes to %s", dst_dataset_uri) shuffle_buffers = [ - data_file.path() - for frag in ds.get_fragments() - for data_file in frag.data_files() + data_file.path for frag in ds.get_fragments() for data_file in frag.data_files() ] return dst_dataset_uri, shuffle_buffers diff --git a/python/python/tests/models/jieba/default/dict.txt b/python/python/tests/models/jieba/default/dict.txt new file mode 100644 index 00000000000..237b47ca6a8 --- /dev/null +++ b/python/python/tests/models/jieba/default/dict.txt @@ -0,0 +1,8 @@ +我们 98740 r +都 202780 d +有 423765 v +光明 1219 n +çš„ 318825 uj +å‰é€” 1263 n +å‰ 62779 f +途 857 n diff --git a/python/python/tests/models/jieba/invalid_dict/config.json b/python/python/tests/models/jieba/invalid_dict/config.json new file mode 100644 index 00000000000..cf4301aa2b2 --- /dev/null +++ b/python/python/tests/models/jieba/invalid_dict/config.json @@ -0,0 +1,6 @@ +{ + "main": "../default/dict.txt", + "users": [ + "invalid_user.txt" + ] +} diff --git a/python/python/tests/models/jieba/invalid_dict2/config.json b/python/python/tests/models/jieba/invalid_dict2/config.json new file mode 100644 index 00000000000..d0216419a5f --- /dev/null +++ b/python/python/tests/models/jieba/invalid_dict2/config.json @@ -0,0 +1,3 @@ +{ + "main": "invalid_dict.txt" +} diff --git a/python/python/tests/models/jieba/user_dict/config.json b/python/python/tests/models/jieba/user_dict/config.json new file mode 100644 index 00000000000..0d65334ca28 --- /dev/null +++ b/python/python/tests/models/jieba/user_dict/config.json @@ -0,0 +1,6 @@ +{ + "main": "../default/dict.txt", + "users": [ + "user.txt" + ] +} diff --git a/python/python/tests/models/jieba/user_dict/user.txt b/python/python/tests/models/jieba/user_dict/user.txt new file mode 100644 index 00000000000..bb6ffa4d85f --- /dev/null +++ b/python/python/tests/models/jieba/user_dict/user.txt @@ -0,0 +1 @@ +光明的å‰é€” 318825 n diff --git a/python/python/tests/models/lindera/README.md b/python/python/tests/models/lindera/README.md new file mode 100644 index 00000000000..c4073b65d56 --- /dev/null +++ b/python/python/tests/models/lindera/README.md @@ -0,0 +1,28 @@ +# How to build this test language model + +Ipadic model is about 45M. so we created a tiny ipadic in zip. + +- Download language model + +```bash +curl -L -o mecab-ipadic-2.7.0-20070801.tar.gz "https://github.com/lindera-morphology/mecab-ipadic/archive/refs/tags/2.7.0-20070801.tar.gz" +tar xvf mecab-ipadic-2.7.0-20070801.tar.gz +``` + +- Remove csv files in folder + +- Put files in `ipadic/raw` into folder + +- Edit matrix.def, reset last column(weight) into zero, except first row. + +- build + +```bash +lindera build --dictionary-kind=ipadic mecab-ipadic-2.7.0-20070801 main +``` + +- build user dict + +```bash +lindera build --build-user-dictionary --dictionary-kind=ipadic user_dict/userdict.csv user_dict2 +``` diff --git a/python/python/tests/models/lindera/invalid_dict/config.json b/python/python/tests/models/lindera/invalid_dict/config.json new file mode 100644 index 00000000000..b486aeba24b --- /dev/null +++ b/python/python/tests/models/lindera/invalid_dict/config.json @@ -0,0 +1,4 @@ +{ + "main": "../main", + "user": "invalid.bin" +} diff --git a/python/python/tests/models/lindera/invalid_dict2/config.json b/python/python/tests/models/lindera/invalid_dict2/config.json new file mode 100644 index 00000000000..11c22e9f1ce --- /dev/null +++ b/python/python/tests/models/lindera/invalid_dict2/config.json @@ -0,0 +1,4 @@ +{ + "main": "../main", + "user": "ipadic_simple_userdic.csv" +} diff --git a/python/python/tests/models/lindera/ipadic/main.zip b/python/python/tests/models/lindera/ipadic/main.zip new file mode 100644 index 00000000000..25966ae2a1d Binary files /dev/null and b/python/python/tests/models/lindera/ipadic/main.zip differ diff --git a/python/python/tests/models/lindera/ipadic/raw/Noun.mock.csv b/python/python/tests/models/lindera/ipadic/raw/Noun.mock.csv new file mode 100644 index 00000000000..4201b57a543 --- /dev/null +++ b/python/python/tests/models/lindera/ipadic/raw/Noun.mock.csv @@ -0,0 +1,3 @@ +À®ÅÄ,1293,1293,5686,̾»ì,¸Çͭ̾»ì,Ãϰè,°ìÈÌ,*,*,À®ÅÄ,¥Ê¥ê¥¿,¥Ê¥ê¥¿ +¹ñºÝ,1285,1285,553,̾»ì,°ìÈÌ,*,*,*,*,¹ñºÝ,¥³¥¯¥µ¥¤,¥³¥¯¥µ¥¤ +¶õ¹Á,1285,1285,7778,̾»ì,°ìÈÌ,*,*,*,*,¶õ¹Á,¥¯¥¦¥³¥¦,¥¯¡¼¥³¡¼ \ No newline at end of file diff --git a/python/python/tests/models/lindera/user_dict/config.json b/python/python/tests/models/lindera/user_dict/config.json new file mode 100644 index 00000000000..e554849af24 --- /dev/null +++ b/python/python/tests/models/lindera/user_dict/config.json @@ -0,0 +1,5 @@ +{ + "main": "../ipadic/main", + "user": "userdic.csv", + "user_kind": "ipadic" +} diff --git a/python/python/tests/models/lindera/user_dict/userdic.csv b/python/python/tests/models/lindera/user_dict/userdic.csv new file mode 100644 index 00000000000..652c3f77910 --- /dev/null +++ b/python/python/tests/models/lindera/user_dict/userdic.csv @@ -0,0 +1 @@ +æˆç”°å›½éš›ç©ºæ¸¯,カスタムå詞,トウキョウスカイツリー diff --git a/python/python/tests/models/lindera/user_dict2/config.json b/python/python/tests/models/lindera/user_dict2/config.json new file mode 100644 index 00000000000..e06bd8c71be --- /dev/null +++ b/python/python/tests/models/lindera/user_dict2/config.json @@ -0,0 +1,4 @@ +{ + "main": "../ipadic/main", + "user": "userdic.bin" +} diff --git a/python/python/tests/models/lindera/user_dict2/userdic.bin b/python/python/tests/models/lindera/user_dict2/userdic.bin new file mode 100644 index 00000000000..a0410fa0798 Binary files /dev/null and b/python/python/tests/models/lindera/user_dict2/userdic.bin differ diff --git a/python/python/tests/test_balanced.py b/python/python/tests/test_balanced.py index a7d33bd3d09..769a6ca0a17 100644 --- a/python/python/tests/test_balanced.py +++ b/python/python/tests/test_balanced.py @@ -58,6 +58,32 @@ def balanced_dataset(tmp_path, big_val): ) +def test_write_fragments(balanced_dataset, tmp_path): + tbl = balanced_dataset._take_rows(range(10)) + transaction = lance.fragment.write_fragments( + tbl, + tmp_path / "ds", + enable_move_stable_row_ids=True, + return_transaction=True, + ) + operation = lance.LanceOperation.Overwrite( + transaction.operation.new_schema, transaction.operation.fragments + ) + blob_operation = lance.LanceOperation.Overwrite( + transaction.blobs_op.new_schema, transaction.blobs_op.fragments + ) + + lance.LanceDataset.commit( + tmp_path / "ds", + operation, + blobs_op=blob_operation, + enable_v2_manifest_paths=True, + ) + dataset = lance.LanceDataset(tmp_path / "ds") + + assert dataset._take_rows(range(10)) == balanced_dataset._take_rows(range(10)) + + def test_append_then_take(balanced_dataset, tmp_path, big_val): blob_dir = tmp_path / "test_ds" / "_blobs" / "data" assert len(list(blob_dir.iterdir())) == 8 diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index af74f9b2a6a..73df37f81aa 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -30,6 +30,7 @@ from lance._dataset.sharded_batch_iterator import ShardedBatchIterator from lance.commit import CommitConflictError from lance.debug import format_fragment +from lance.schema import LanceSchema from lance.util import validate_vector_index # Various valid inputs for write_dataset @@ -160,6 +161,39 @@ def test_dataset_from_record_batch_iterable(tmp_path: Path): assert list(dataset.to_batches())[0].to_pylist() == test_pylist +def test_schema_metadata(tmp_path: Path): + schema = pa.schema( + [ + pa.field("a", pa.int64(), metadata={b"thisis": "a"}), + pa.field("b", pa.int64(), metadata={b"thisis": "b"}), + ], + metadata={b"foo": b"bar", b"baz": b"qux"}, + ) + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}, schema=schema) + ds = lance.write_dataset(table, tmp_path) + # Original schema + assert ds.schema.metadata == {b"foo": b"bar", b"baz": b"qux"} + assert ds.schema.field("a").metadata == {b"thisis": b"a"} + assert ds.schema.field("b").metadata == {b"thisis": b"b"} + + # Replace schema metadata + ds.replace_schema_metadata({"foo": "baz"}) + assert ds.schema.metadata == {b"foo": b"baz"} + assert ds.schema.field("a").metadata == {b"thisis": b"a"} + assert ds.schema.field("b").metadata == {b"thisis": b"b"} + + # Replace field metadata + ds.replace_field_metadata("a", {"thisis": "c"}) + assert ds.schema.field("a").metadata == {b"thisis": b"c"} + assert ds.schema.field("b").metadata == {b"thisis": b"b"} + + # Overwrite overwrites metadata + ds = lance.write_dataset(table, tmp_path, mode="overwrite") + assert ds.schema.metadata == {b"foo": b"bar", b"baz": b"qux"} + assert ds.schema.field("a").metadata == {b"thisis": b"a"} + assert ds.schema.field("b").metadata == {b"thisis": b"b"} + + def test_versions(tmp_path: Path): table1 = pa.Table.from_pylist([{"a": 1, "b": 2}, {"a": 10, "b": 20}]) base_dir = tmp_path / "test" @@ -261,6 +295,41 @@ def test_asof_checkout(tmp_path: Path): assert len(ds.to_table()) == 9 +def test_enable_move_stable_row_ids(tmp_path: Path): + table = pa.Table.from_pylist( + [{"name": "Alice", "age": 20}, {"name": "Bob", "age": 30}] + ) + lance.write_dataset(table, tmp_path, enable_move_stable_row_ids=True) + ds = lance.write_dataset( + table, tmp_path, enable_move_stable_row_ids=True, mode="append" + ) + table_before = ds.scanner(with_row_id=True, with_row_address=True).to_table() + assert len(table_before) == 4 + assert table_before["_rowid"][0].as_py() == 0 + assert table_before["_rowid"][1].as_py() == 1 + assert table_before["_rowid"][2].as_py() == 2 + assert table_before["_rowid"][3].as_py() == 3 + + assert table_before["_rowaddr"][0].as_py() == 0 + assert table_before["_rowaddr"][1].as_py() == 1 + assert table_before["_rowaddr"][2].as_py() == (1 << 32) + 0 + assert table_before["_rowaddr"][3].as_py() == (1 << 32) + 1 + + ds.optimize.compact_files() + + table_after = ds.scanner(with_row_id=True, with_row_address=True).to_table() + assert len(table_after) == 4 + assert table_after["_rowid"][0].as_py() == 0 + assert table_after["_rowid"][1].as_py() == 1 + assert table_after["_rowid"][2].as_py() == 2 + assert table_after["_rowid"][3].as_py() == 3 + + assert table_after["_rowaddr"][0].as_py() == (2 << 32) + 0 + assert table_after["_rowaddr"][1].as_py() == (2 << 32) + 1 + assert table_after["_rowaddr"][2].as_py() == (2 << 32) + 2 + assert table_after["_rowaddr"][3].as_py() == (2 << 32) + 3 + + def test_v2_manifest_paths(tmp_path: Path): lance.write_dataset( pa.table({"a": range(100)}), tmp_path, enable_v2_manifest_paths=True @@ -451,9 +520,11 @@ def test_limit_offset(tmp_path: Path, data_storage_version: str): # test just limit assert dataset.to_table(limit=10) == table.slice(0, 10) + assert dataset.to_table(limit=100) == table.slice(0, 100) # test just offset - assert dataset.to_table(offset=10) == table.slice(10, 100) + assert dataset.to_table(offset=0) == table.slice(0, 100) + assert dataset.to_table(offset=10) == table.slice(10, 90) # test both assert dataset.to_table(offset=10, limit=10) == table.slice(10, 10) @@ -468,7 +539,18 @@ def test_limit_offset(tmp_path: Path, data_storage_version: str): assert dataset.to_table(offset=50, limit=25) == table.slice(50, 25) # Limit past the end - assert dataset.to_table(offset=50, limit=100) == table.slice(50, 50) + assert dataset.to_table(limit=101) == table.slice(0, 100) + + # Limit with offset past the end + assert dataset.to_table(offset=50, limit=51) == table.slice(50, 50) + + # Offset past the end + assert dataset.to_table(offset=100) == table.slice(100, 0) # Empty table + assert dataset.to_table(offset=101) == table.slice(100, 0) # Empty table + + # Offset with limit past the end + assert dataset.to_table(offset=100, limit=1) == table.slice(100, 0) # Empty table + assert dataset.to_table(offset=101, limit=1) == table.slice(100, 0) # Empty table # Invalid limit / offset with pytest.raises(ValueError, match="Offset must be non-negative"): @@ -697,6 +779,68 @@ def test_count_rows(tmp_path: Path): assert dataset.count_rows(filter="a < 50") == 50 +def test_select_none(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + ds = lance.write_dataset(table, base_dir) + + assert "projection=[a]" in ds.scanner( + columns=[], filter="a < 50", with_row_id=True + ).explain_plan(True) + + +def test_analyze_filtered_scan(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + ds = lance.write_dataset(table, base_dir) + plan = ds.scanner(columns=[], filter="a < 50", with_row_id=True).analyze_plan() + print(plan) + assert re.search(r"^\s*LanceScan:.*output_rows=100.*$", plan, re.MULTILINE) + assert re.search(r"^\s*FilterExec:.*output_rows=50.*$", plan, re.MULTILINE) + + +def test_analyze_index_scan(tmp_path: Path): + table = pa.table({"filter": range(100)}) + dataset = lance.write_dataset(table, tmp_path) + dataset.create_scalar_index("filter", "BTREE") + plan = dataset.scanner(filter="filter = 10").analyze_plan() + assert "MaterializeIndex: query=filter = 10, metrics=[output_rows=1" in plan + + +def test_analyze_scan(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + dataset = lance.write_dataset(table, tmp_path) + plan = dataset.scanner().analyze_plan() + # The bytes_read part might get brittle if we change file versions a lot + # future us are free to ignore that part. + assert "bytes_read=3643, iops=3, requests=3" in plan + + +def test_analyze_take(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + dataset = lance.write_dataset(table, tmp_path) + dataset.create_scalar_index("a", "BTREE") + plan = dataset.scanner(filter="a = 50").analyze_plan() + assert "bytes_read=16, iops=2, requests=2" in plan + + +def test_analyze_vector_search(tmp_path: Path): + table = pa.Table.from_pydict( + { + "id": [i for i in range(10)], + "vector": pa.array( + [[1.0, 1.0] for _ in range(10)], pa.list_(pa.float32(), 2) + ), + } + ) + dataset = lance.write_dataset(table, tmp_path / "dataset", mode="create") + dataset.delete("id = 0") + plan = dataset.scanner( + nearest={"column": "vector", "k": 10, "q": [1.0, 1.0]} + ).analyze_plan() + assert "KNNVectorDistance: metric=l2, metrics=[output_rows=10" in plan + + def test_get_fragments(tmp_path: Path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" @@ -934,6 +1078,31 @@ def test_restore_with_commit(tmp_path: Path): assert tbl == table +def test_merge_insert_with_commit(): + table = pa.table({"id": range(10), "updated": [False] * 10}) + dataset = lance.write_dataset(table, "memory://test") + + updates = pa.Table.from_pylist([{"id": 1, "updated": True}]) + transaction, stats = ( + dataset.merge_insert(on="id") + .when_matched_update_all() + .execute_uncommitted(updates) + ) + + assert isinstance(stats, dict) + assert stats["num_updated_rows"] == 1 + assert stats["num_inserted_rows"] == 0 + assert stats["num_deleted_rows"] == 0 + + assert isinstance(transaction, lance.Transaction) + assert isinstance(transaction.operation, lance.LanceOperation.Update) + + dataset = lance.LanceDataset.commit(dataset, transaction) + assert dataset.to_table().sort_by("id") == pa.table( + {"id": range(10), "updated": [False] + [True] + [False] * 8} + ) + + def test_merge_with_commit(tmp_path: Path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" @@ -1013,12 +1182,14 @@ def test_data_files(tmp_path: Path): base_dir = tmp_path / "test" fragment = lance.fragment.LanceFragment.create(base_dir, table) - data_files = fragment.data_files() + data_files = fragment.files assert len(data_files) == 1 # it is a valid uuid - uuid.UUID(os.path.splitext(data_files[0].path())[0]) + with pytest.warns(DeprecationWarning): + path = data_files[0].path() + uuid.UUID(os.path.splitext(path)[0]) - assert fragment.deletion_file() is None + assert fragment.deletion_file is None def test_deletion_file(tmp_path: Path): @@ -1035,8 +1206,10 @@ def test_deletion_file(tmp_path: Path): assert fragment.deletion_file() is None # New fragment has deletion file - assert new_fragment.deletion_file() is not None - assert re.match("_deletions/0-1-[0-9]{1,32}.arrow", new_fragment.deletion_file()) + assert new_fragment.deletion_file is not None + assert re.match( + "_deletions/0-1-[0-9]{1,32}.arrow", new_fragment.deletion_file.path(0) + ) operation = lance.LanceOperation.Overwrite(table.schema, [new_fragment]) dataset = lance.LanceDataset.commit(base_dir, operation) assert dataset.count_rows() == 90 @@ -1055,6 +1228,9 @@ def test_commit_fragments_via_scanner(tmp_path: Path): pickled = pickle.dumps(fragment_metadata) unpickled = pickle.loads(pickled) assert fragment_metadata == unpickled + with pytest.warns(DeprecationWarning): + path = fragment_metadata.files[0].path() + assert path == unpickled.files[0].path() operation = lance.LanceOperation.Overwrite(table.schema, [fragment_metadata]) dataset = lance.LanceDataset.commit(base_dir, operation) @@ -1081,6 +1257,18 @@ def test_load_scanner_from_fragments(tmp_path: Path): assert scanner.to_table().num_rows == 2 * 100 +def test_write_unstable_data_version(tmp_path: Path, capfd): + # Note: this test will only work if no earlier test attempts + # to use an unstable version. If we need that later we can find a way to + # run this test in a separate process (pytest-xdist?) + tab = pa.table({"a": range(100), "b": range(100)}) + ds = lance.write_dataset( + tab, tmp_path / "dataset", mode="append", data_storage_version="next" + ) + assert ds.to_table() == tab + assert "You have requested an unstable format version" in capfd.readouterr().err + + def test_merge_data(tmp_path: Path): tab = pa.table({"a": range(100), "b": range(100)}) lance.write_dataset(tab, tmp_path / "dataset", mode="append") @@ -1193,12 +1381,23 @@ def check_merge_stats(merge_dict, expected): def test_merge_insert(tmp_path: Path): nrows = 1000 + # Create a schema with some metadata to regress an issue where the metadata + # caused schema comparison problems in merge_insert. + schema = pa.schema( + [ + pa.field("a", pa.int64()), + pa.field("b", pa.int64()), + pa.field("c", pa.int64()), + ], + metadata={"foo": "bar"}, + ) table = pa.Table.from_pydict( { "a": range(nrows), "b": [1 for _ in range(nrows)], "c": [x % 2 for x in range(nrows)], - } + }, + schema=schema, ) dataset = lance.write_dataset( table, tmp_path / "dataset", mode="create", max_rows_per_file=100 @@ -1210,7 +1409,8 @@ def test_merge_insert(tmp_path: Path): "a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)], "c": [0 for _ in range(nrows)], - } + }, + schema=schema, ) is_new = pc.field("b") == 2 @@ -1327,14 +1527,37 @@ def test_merge_insert_subcols(tmp_path: Path): assert fragments[1].fragment_id == original_fragments[1].fragment_id assert len(fragments[0].data_files()) == 2 - assert str(fragments[0].data_files()[0]) == str( - original_fragments[0].data_files()[0] + assert ( + fragments[0].data_files()[0].path == original_fragments[0].data_files()[0].path ) assert len(fragments[1].data_files()) == 1 assert str(fragments[1].data_files()[0]) == str( original_fragments[1].data_files()[0] ) + new_values = pa.table( + { + "a": range(9, 12), + "b": range(30, 33), + } + ) + ( + dataset.merge_insert("a") + .when_not_matched_insert_all() + .when_matched_update_all() + .execute(new_values) + ) + + assert dataset.count_rows() == 12 + expected = pa.table( + { + "a": range(0, 12), + "b": [0, 1, 2, 20, 21, 5, 6, 7, 8, 30, 31, 32], + "c": list(range(10, 20)) + [None] * 2, + } + ) + assert dataset.to_table().sort_by("a") == expected + def test_flat_vector_search_with_delete(tmp_path: Path): table = pa.Table.from_pydict( @@ -1355,6 +1578,19 @@ def test_flat_vector_search_with_delete(tmp_path: Path): ) +def test_null_reader_with_deletes(tmp_path: Path): + full_schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("other", pa.int64()), + ] + ) + ds = lance.write_dataset([], tmp_path, schema=full_schema, mode="create") + ds.insert(pa.table({"id": [1, 2, 3, 4, 5]})) + ds.delete("id in (1, 2)") + ds.to_table() + + def test_merge_insert_conditional_upsert_example(tmp_path: Path): table = pa.Table.from_pydict( { @@ -1496,34 +1732,6 @@ def test_merge_insert_multiple_keys(tmp_path: Path): check_merge_stats(merge_dict, (0, 350, 0)) -def test_merge_insert_incompatible_schema(tmp_path: Path): - nrows = 1000 - table = pa.Table.from_pydict( - { - "a": range(nrows), - "b": [1 for _ in range(nrows)], - } - ) - dataset = lance.write_dataset( - table, tmp_path / "dataset", mode="create", max_rows_per_file=100 - ) - - new_table = pa.Table.from_pydict( - { - "a": range(300, 300 + nrows), - } - ) - - with pytest.raises(OSError): - merge_dict = ( - dataset.merge_insert("a") - .when_matched_update_all() - .when_not_matched_insert_all() - .execute(new_table) - ) - check_merge_stats(merge_dict, (None, None, None)) - - def test_merge_insert_vector_column(tmp_path: Path): table = pa.Table.from_pydict( { @@ -1562,6 +1770,115 @@ def test_merge_insert_vector_column(tmp_path: Path): check_merge_stats(merge_dict, (1, 1, 0)) +def test_merge_insert_large(): + # Doing subcolumns update with merge insert triggers this error. + # Data needs to be large enough to make DataFusion create multiple batches + # when outputting join results. + # https://github.com/lancedb/lance/issues/3406 + # This test is in Python because for whatever reason, the error doesn't + # reproduce in the equivalent Rust test. + dims = 32 + nrows = 20_000 + data = pa.table({"id": range(nrows), "num": [str(i) for i in range(nrows)]}) + + ds = lance.write_dataset(data, "memory://") + + ds.add_columns({"vector": f"arrow_cast(NULL, 'FixedSizeList({dims}, Float32)')"}) + + batch_size = 10_000 + other_columns = pa.table( + { + "id": range(batch_size), + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(batch_size * dims).cast(pa.float32()), dims + ), + } + ) + + ( + ds.merge_insert(on="id") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(other_columns) + ) + + +def test_merge_insert_empty_index(): + # Reported in https://github.com/lancedb/lancedb/issues/2285 + empty_table = pa.table({"id": pa.array([], type=pa.float64())}) + empty_ds = lance.write_dataset(empty_table, "memory://") + + empty_ds.create_scalar_index("id", "BTREE") + + df = pa.table({"id": [1.0, 2.0, 3.0]}) + + empty_ds.merge_insert("id").when_not_matched_insert_all().execute(df) + + +def test_add_null_columns(tmp_path: Path): + data = pa.table({"id": [1, 2, 4]}) + ds = lance.write_dataset(data, tmp_path) + fragments = ds.get_fragments() + assert len(fragments) == 1 + assert len(fragments[0].data_files()) == 1 + + ds.add_columns(pa.field("f1", pa.float32())) + fragments = ds.get_fragments() + assert len(fragments) == 1 + assert len(fragments[0].data_files()) == 1 + assert ds.schema == pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("f1", pa.float32()), + ] + ) + + ds.add_columns( + [pa.field("v2", pa.list_(pa.float32(), 32)), pa.field("v3", pa.int32())] + ) + fragments = ds.get_fragments() + assert len(fragments) == 1 + assert len(fragments[0].data_files()) == 1 + assert ds.schema == pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("f1", pa.float32()), + pa.field("v2", pa.list_(pa.float32(), 32)), + pa.field("v3", pa.int32()), + ] + ) + + ds.add_columns( + pa.schema([pa.field("s6", pa.struct([("a", pa.int32()), ("b", pa.bool_())]))]) + ) + fragments = ds.get_fragments() + assert len(fragments) == 1 + assert len(fragments[0].data_files()) == 1 + assert ds.schema == pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("f1", pa.float32()), + pa.field("v2", pa.list_(pa.float32(), 32)), + pa.field("v3", pa.int32()), + pa.field("s6", pa.struct([("a", pa.int32()), ("b", pa.bool_())])), + ] + ) + + +def test_add_null_columns_with_conflict_names(tmp_path: Path): + data = pa.table({"id": [1, 2, 4]}) + ds = lance.write_dataset(data, tmp_path) + fragments = ds.get_fragments() + assert len(fragments) == 1 + assert len(fragments[0].data_files()) == 1 + + with pytest.raises(Exception, match=".*Column id already exists in the dataset.*"): + ds.add_columns(pa.field("id", pa.float32())) + + with pytest.raises(Exception, match=".*Column id already exists in the dataset.*"): + ds.add_columns([pa.field("id", pa.float32()), pa.field("good", pa.int32())]) + + def check_update_stats(update_dict, expected): assert (update_dict["num_rows_updated"],) == expected @@ -2047,7 +2364,7 @@ def test_scan_count_rows(tmp_path: Path): df = pd.DataFrame({"a": range(42), "b": range(42)}) dataset = lance.write_dataset(df, base_dir) - assert dataset.scanner().count_rows() == 42 + assert dataset.scanner(columns=[], with_row_id=True).count_rows() == 42 assert dataset.count_rows(filter="a < 10") == 10 assert dataset.count_rows(filter=pa_ds.field("a") < 20) == 20 @@ -2063,6 +2380,44 @@ def test_scanner_schemas(tmp_path: Path): assert scanner.projected_schema == pa.schema([pa.field("a", pa.int64())]) +def test_scan_deleted_rows(tmp_path: Path): + base_dir = tmp_path / "dataset" + df = pd.DataFrame({"a": range(100), "b": range(100)}) + ds = lance.write_dataset(df, base_dir, max_rows_per_file=25) + ds.create_scalar_index("b", "BTREE") + ds.delete("a < 30") + + assert ds.count_rows() == 70 + + assert ds.scanner(with_row_id=True).to_table().num_rows == 70 + with_deleted = ds.scanner(with_row_id=True, include_deleted_rows=True).to_table() + + assert with_deleted.num_rows == 75 + + assert with_deleted.slice(0, 5) == pa.table( + { + "a": range(25, 30), + "b": range(25, 30), + "_rowid": pa.array([None] * 5, pa.uint64()), + } + ) + + assert ( + ds.scanner(with_row_id=True, include_deleted_rows=True, filter="a < 32") + .to_table() + .num_rows + == 7 + ) + + with pytest.raises(ValueError, match="Cannot include deleted rows"): + ds.scanner( + include_deleted_rows=True, with_row_id=True, filter="b < 30" + ).to_table() + + with pytest.raises(ValueError, match="with_row_id is false"): + ds.scanner(include_deleted_rows=True, filter="a < 30").to_table() + + def test_custom_commit_lock(tmp_path: Path): called_lock = False called_release = False @@ -2647,23 +3002,53 @@ def test_use_scalar_index(tmp_path: Path): EXPECTED_MINOR_VERSION = 0 +def test_stats(tmp_path: Path): + table = pa.table({"x": [1, 2, 3, 4], "y": ["foo", "bar", "baz", "qux"]}) + dataset = lance.write_dataset(table, tmp_path) + stats = dataset.stats.dataset_stats() + + assert stats["num_deleted_rows"] == 0 + assert stats["num_fragments"] == 1 + assert stats["num_small_files"] == 1 + + data_stats = dataset.stats.data_stats() + + assert data_stats.fields[0].id == 0 + assert data_stats.fields[0].bytes_on_disk == 32 + assert data_stats.fields[1].id == 1 + assert data_stats.fields[1].bytes_on_disk == 44 # 12 bytes data + 32 bytes offset + + dataset.add_columns({"z": "y"}) + + dataset.insert(pa.table({"x": [5], "z": ["quux"]})) + + data_stats = dataset.stats.data_stats() + + assert data_stats.fields[0].id == 0 + assert data_stats.fields[0].bytes_on_disk == 40 + assert data_stats.fields[1].id == 1 + assert data_stats.fields[1].bytes_on_disk == 44 # 12 bytes data + 32 bytes offset + assert data_stats.fields[2].id == 2 + assert data_stats.fields[2].bytes_on_disk == 56 # 16 bytes data + 40 bytes offset + + def test_default_storage_version(tmp_path: Path): table = pa.table({"x": [0]}) dataset = lance.write_dataset(table, tmp_path) assert dataset.data_storage_version == EXPECTED_DEFAULT_STORAGE_VERSION frag = lance.LanceFragment.create(dataset.uri, table) - sample_file = frag.to_json()["files"][0] - assert sample_file["file_major_version"] == EXPECTED_MAJOR_VERSION - assert sample_file["file_minor_version"] == EXPECTED_MINOR_VERSION + sample_file = frag.files[0] + assert sample_file.file_major_version == EXPECTED_MAJOR_VERSION + assert sample_file.file_minor_version == EXPECTED_MINOR_VERSION from lance.fragment import write_fragments frags = write_fragments(table, dataset.uri) frag = frags[0] - sample_file = frag.to_json()["files"][0] - assert sample_file["file_major_version"] == EXPECTED_MAJOR_VERSION - assert sample_file["file_minor_version"] == EXPECTED_MINOR_VERSION + sample_file = frag.files[0] + assert sample_file.file_major_version == EXPECTED_MAJOR_VERSION + assert sample_file.file_minor_version == EXPECTED_MINOR_VERSION def test_no_detached_v1(tmp_path: Path): @@ -2715,3 +3100,139 @@ def test_detached_commits(tmp_path: Path): ) assert detached2.to_table() == pa.table({"x": [0, 1, 3]}) + + +def test_dataset_drop(tmp_path: Path): + table = pa.table({"x": [0]}) + lance.write_dataset(table, tmp_path) + assert Path(tmp_path).exists() + lance.LanceDataset.drop(tmp_path) + assert not Path(tmp_path).exists() + + +def test_dataset_schema(tmp_path: Path): + table = pa.table({"x": [0]}) + ds = lance.write_dataset(table, str(tmp_path)) # noqa: F841 + ds._default_scan_options = {"with_row_id": True} + assert ds.schema == ds.to_table().schema + + +def test_data_replacement(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(100, 200)}) + fragment = lance.fragment.LanceFragment.create(base_dir, table) + data_file = fragment.files[0] + data_replacement = lance.LanceOperation.DataReplacement( + [lance.LanceOperation.DataReplacementGroup(0, data_file)] + ) + dataset = lance.LanceDataset.commit(dataset, data_replacement, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "a": list(range(100, 200)), + "b": list(range(100, 200)), + } + ) + assert tbl == expected + + +def test_schema_project_drop_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"a": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "a": list(range(100, 200)), + } + ) + assert tbl == expected + + +def test_schema_project_rename_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"c": range(1), "d": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "c": list(range(100, 200)), + "d": list(range(300, 400)), + } + ) + assert tbl == expected + + +def test_schema_project_swap_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"b": range(1), "a": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "b": list(range(100, 200)), + "a": list(range(300, 400)), + } + ) + assert tbl == expected + + +def test_empty_structs(tmp_path): + schema = pa.schema([pa.field("id", pa.int32()), pa.field("empties", pa.struct([]))]) + table = pa.table({"id": [0, 1, 2], "empties": [{}] * 3}, schema=schema) + ds = lance.write_dataset(table, tmp_path) + res = ds.take([2, 0, 1]) + assert res.num_rows == 3 + assert res == table.take([2, 0, 1]) + + +def test_create_table_from_pylist(tmp_path): + data = [ + {"foo": 1, "bar": "one"}, + {"foo": 3, "bar": "three"}, + ] + ds = lance.write_dataset(data, tmp_path) + + assert ds.to_table() == pa.Table.from_pylist(data) + + +def test_create_table_from_pydict(tmp_path): + dat = { + "foo": [1, 3], + "bar": ["one", "three"], + } + ds = lance.write_dataset(dat, tmp_path) + assert ds.to_table() == pa.Table.from_pydict(dat) diff --git a/python/python/tests/test_f16.py b/python/python/tests/test_f16.py index 266e1ef1628..d06703593d0 100644 --- a/python/python/tests/test_f16.py +++ b/python/python/tests/test_f16.py @@ -6,11 +6,18 @@ import lance import numpy as np import pyarrow as pa +import pytest +torch = pytest.importorskip("torch") -def test_f16_embeddings(tmp_path: Path): - DIM = 32 - TOTAL = 1000 + +@pytest.mark.parametrize("accelerator", [None, "cuda"]) +def test_f16_embeddings(tmp_path: Path, accelerator: str): + if not torch.cuda.is_available() and accelerator == "cuda": + pytest.skip("CUDA not available") + + DIM = 16 + TOTAL = 256 values = np.random.random(TOTAL * DIM).astype(np.float16) fsl = pa.FixedSizeListArray.from_arrays(values, DIM) data = pa.Table.from_arrays([fsl, np.arange(TOTAL)], names=["vec", "id"]) @@ -19,7 +26,12 @@ def test_f16_embeddings(tmp_path: Path): assert ds.schema.field("vec").type.value_type == pa.float16() ds = ds.create_index( - "vec", "IVF_PQ", replace=True, num_partitions=2, num_sub_vectors=2 + "vec", + "IVF_PQ", + replace=True, + num_partitions=2, + num_sub_vectors=2, + accelerator=accelerator, ) # Can use float32 to search diff --git a/python/python/tests/test_file.py b/python/python/tests/test_file.py index 45c53a8c20c..da330d315a2 100644 --- a/python/python/tests/test_file.py +++ b/python/python/tests/test_file.py @@ -3,6 +3,7 @@ import os +import numpy as np import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -214,6 +215,34 @@ def test_metadata(tmp_path): assert len(page.encoding) > 0 +def test_file_stat(tmp_path): + path = tmp_path / "foo.lance" + schema = pa.schema( + [pa.field("a", pa.int64()), pa.field("b", pa.list_(pa.float64(), 8))] + ) + + num_rows = 1_000_000 + + data1 = pa.array(range(num_rows)) + + # Create a fixed-size list of float64 with dimension 8 + fixed_size_list = [np.random.rand(8).tolist() for _ in range(num_rows)] + data2 = pa.array(fixed_size_list, type=pa.list_(pa.float64(), 8)) + + with LanceFileWriter(str(path), schema) as writer: + writer.write_batch(pa.table({"a": data1, "b": data2})) + reader = LanceFileReader(str(path)) + file_stat = reader.file_statistics() + + assert len(file_stat.columns) == 2 + + assert file_stat.columns[0].num_pages == 1 + assert file_stat.columns[0].size_bytes == 8_000_000 + + assert file_stat.columns[1].num_pages == 2 + assert file_stat.columns[1].size_bytes == 64_000_000 + + def test_round_trip_parquet(tmp_path): pq_path = tmp_path / "foo.parquet" table = pa.table({"int": [1, 2], "list_str": [["x", "yz", "abc"], ["foo", "bar"]]}) @@ -330,6 +359,17 @@ def round_trip(arr): assert round_tripped.type == dict_arr.type +def test_empty_structs(tmp_path): + schema = pa.schema([pa.field("empties", pa.struct([]))]) + table = pa.table({"empties": [{}] * 3}, schema=schema) + path = tmp_path / "foo.lance" + with LanceFileWriter(str(path)) as writer: + writer.write_batch(table) + reader = LanceFileReader(str(path)) + round_tripped = reader.read_all().to_table() + assert round_tripped == table + + def test_write_read_global_buffer(tmp_path): table = pa.table({"a": [1, 2, 3]}) path = tmp_path / "foo.lance" diff --git a/python/python/tests/test_filter.py b/python/python/tests/test_filter.py index e9096599c5c..d74a383501e 100644 --- a/python/python/tests/test_filter.py +++ b/python/python/tests/test_filter.py @@ -81,10 +81,15 @@ def test_sql_predicates(dataset): ("int >= 50", 50), ("int = 50", 1), ("int != 50", 99), + ("int BETWEEN 50 AND 60", 11), ("float < 30.0", 45), ("str = 'aa'", 16), ("str in ('aa', 'bb')", 26), ("rec.bool", 50), + ("rec.bool is true", 50), + ("rec.bool is not true", 50), + ("rec.bool is false", 50), + ("rec.bool is not false", 50), ("rec.date = cast('2021-01-01' as date)", 1), ("rec.dt = cast('2021-01-01 00:00:00' as timestamp(6))", 1), ("rec.dt = cast('2021-01-01 00:00:00' as timestamp)", 1), @@ -103,6 +108,13 @@ def test_sql_predicates(dataset): assert dataset.to_table(filter=expr).num_rows == expected_num_rows +def test_illegal_predicates(dataset): + predicates_nrows = ["str BETWEEN 10 AND 20", "str > 10"] + for expr in predicates_nrows: + with pytest.raises(ValueError, match="Invalid user input: *"): + dataset.to_table(filter=expr) + + def test_compound(dataset): predicates = [ pc.field("int") >= 50, @@ -256,3 +268,19 @@ def test_duckdb(tmp_path): expected = duckdb.query("SELECT id, meta, price FROM ds").to_df() expected = expected[expected.meta == "aa"].reset_index(drop=True) tm.assert_frame_equal(actual, expected) + + +def test_struct_field_order(tmp_path): + """ + This test regresses some old behavior where the order of struct fields would get + messed up due to late materialization and we would get {y,x} instead of {x,y} + """ + data = pa.table({"struct": [{"x": i, "y": i} for i in range(10)]}) + dataset = lance.write_dataset(data, tmp_path) + + for late_materialization in [True, False]: + result = dataset.to_table( + filter="struct.y > 5", late_materialization=late_materialization + ) + expected = pa.table({"struct": [{"x": i, "y": i} for i in range(6, 10)]}) + assert result == expected diff --git a/python/python/tests/test_fragment.py b/python/python/tests/test_fragment.py index c14cba5a731..9f82678ba96 100644 --- a/python/python/tests/test_fragment.py +++ b/python/python/tests/test_fragment.py @@ -3,12 +3,14 @@ import json import multiprocessing +import pickle import uuid from pathlib import Path import lance import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pytest from helper import ProgressForTest from lance import ( @@ -32,8 +34,14 @@ def test_write_fragment(tmp_path: Path): df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) frag = LanceFragment.create(tmp_path, df) - meta = frag.to_json() + assert len(frag.files) == 1 + assert frag.files[0].fields == [0] + assert frag.physical_rows == 5 + assert frag.row_id_meta is None + assert frag.deletion_file is None + + meta = frag.to_json() assert "id" in meta assert "files" in meta assert meta["files"][0]["fields"] == [0] @@ -63,11 +71,11 @@ def test_write_fragment_two_phases(tmp_path: Path): def test_write_legacy_fragment(tmp_path: Path): tab = pa.table({"a": range(1024)}) frag = LanceFragment.create(tmp_path, tab, data_storage_version="legacy") - assert "file_major_version: 2" not in str(frag) + assert "file_major_version=2" not in str(frag) tab = pa.table({"a": range(1024)}) frag = LanceFragment.create(tmp_path, tab, data_storage_version="stable") - assert "file_major_version: 2" in str(frag) + assert "file_major_version=2" in str(frag) def test_scan_fragment(tmp_path: Path): @@ -133,9 +141,9 @@ def test_write_fragments_schema_holes(tmp_path: Path): dataset.drop_columns(["b"]) def get_field_ids(fragment): - return [id for f in fragment.data_files() for id in f.field_ids()] + return [id for f in fragment.files for id in f.fields] - field_ids = get_field_ids(dataset.get_fragments()[0]) + field_ids = get_field_ids(dataset.get_fragments()[0].metadata) data = pa.table({"a": range(3, 6), "c": range(5, 8)}) fragment = LanceFragment.create(tmp_path, data) @@ -204,10 +212,10 @@ def test_dataset_progress(tmp_path: Path): assert len(metadata["files"]) == 1 # Fragments aren't exactly equal, because the file was written before # physical_rows was known. However, the paths should be the same. - assert len(fragment.data_files()) == 1 + assert len(fragment.files) == 1 deserialized = FragmentMetadata.from_json(json.dumps(metadata)) - assert len(deserialized.data_files()) == 1 - assert fragment.data_files()[0].path() == deserialized.data_files()[0].path() + assert len(deserialized.files) == 1 + assert fragment.files[0].path == deserialized.files[0].path ctx = multiprocessing.get_context("spawn") p = ctx.Process(target=failing_write, args=(progress_uri, dataset_uri)) @@ -246,16 +254,17 @@ def test_fragment_meta(): meta = FragmentMetadata.from_json(json.dumps(data)) assert meta.id == 0 - assert len(meta.data_files()) == 2 - assert meta.data_files()[0].path() == "0.lance" - assert meta.data_files()[1].path() == "1.lance" + assert len(meta.files) == 2 + with pytest.warns(DeprecationWarning): + assert meta.files[0].path() == "0.lance" + assert meta.files[1].path == "1.lance" assert repr(meta) == ( - 'Fragment { id: 0, files: [DataFile { path: "0.lance", fields: [0], ' - "column_indices: [], file_major_version: 0, file_minor_version: 0 }, " - 'DataFile { path: "1.lance", fields: [1], column_indices: [], ' - "file_major_version: 0, file_minor_version: 0 }], deletion_file: None, " - "row_id_meta: None, physical_rows: Some(100) }" + "FragmentMetadata(id=0, files=[DataFile(path='0.lance', fields=[0], " + "column_indices=[], file_major_version=0, file_minor_version=0), " + "DataFile(path='1.lance', fields=[1], column_indices=[], " + "file_major_version=0, file_minor_version=0)], physical_rows=100, " + "deletion_file=None, row_id_meta=None)" ) @@ -317,7 +326,7 @@ def test_create_from_file(tmp_path): frag = LanceFragment.create_from_file(fragment_name, dataset, 0) op = LanceOperation.Append([frag]) - dataset = lance.LanceDataset.commit(dataset.uri, op, dataset.version) + dataset = lance.LanceDataset.commit(dataset.uri, op, read_version=dataset.version) frag = dataset.get_fragments()[0] assert frag.fragment_id == 0 @@ -331,7 +340,7 @@ def test_create_from_file(tmp_path): frag = LanceFragment.create_from_file(fragment_name, dataset, 0) op = LanceOperation.Append([frag]) - dataset = lance.LanceDataset.commit(dataset.uri, op, dataset.version) + dataset = lance.LanceDataset.commit(dataset.uri, op, read_version=dataset.version) frag = dataset.get_fragments()[1] assert frag.fragment_id == 1 @@ -349,8 +358,104 @@ def test_create_from_file(tmp_path): new_fragments=[frag], ) op = LanceOperation.Rewrite(groups=[group], rewritten_indices=[]) - dataset = lance.LanceDataset.commit(dataset.uri, op, dataset.version) + dataset = lance.LanceDataset.commit(dataset.uri, op, read_version=dataset.version) assert dataset.count_rows() == 1600 assert len(dataset.get_fragments()) == 1 assert dataset.get_fragments()[0].fragment_id == 2 + + +def test_fragment_merge(tmp_path): + schema = pa.schema([pa.field("a", pa.string())]) + batches = pa.RecordBatchReader.from_batches( + schema, + [ + pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]), + pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]), + ], + ) + + progress = ProgressForTest() + fragments = write_fragments( + batches, + tmp_path, + max_rows_per_group=512, + max_bytes_per_file=1024, + progress=progress, + ) + + operation = lance.LanceOperation.Overwrite(schema, fragments) + dataset = lance.LanceDataset.commit(tmp_path, operation) + merged = [] + schema = None + for fragment in dataset.get_fragments(): + table = fragment.scanner(with_row_id=True, columns=[]).to_table() + table = table.add_column(0, "b", [[i for i in range(len(table))]]) + fragment, schema = fragment.merge(table, "_rowid") + merged.append(fragment) + + merge = lance.LanceOperation.Merge(merged, schema) + dataset = lance.LanceDataset.commit( + tmp_path, merge, read_version=dataset.latest_version + ) + + merged = [] + schema = None + for fragment in dataset.get_fragments(): + table = fragment.scanner(with_row_address=True, columns=[]).to_table() + table = table.add_column(0, "c", [[i + 1 for i in range(len(table))]]) + fragment, schema = fragment.merge(table, "_rowaddr") + merged.append(fragment) + + merge = lance.LanceOperation.Merge(merged, schema) + dataset = lance.LanceDataset.commit( + tmp_path, merge, read_version=dataset.latest_version + ) + + merged = [] + for fragment in dataset.get_fragments(): + table = fragment.scanner(columns=["b"]).to_table() + table = table.add_column(0, "d", [[i + 2 for i in range(len(table))]]) + fragment, schema = fragment.merge(table, "b") + merged.append(fragment) + + merge = lance.LanceOperation.Merge(merged, schema) + dataset = lance.LanceDataset.commit( + tmp_path, merge, read_version=dataset.latest_version + ) + assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"] + + +def test_fragment_count_rows(tmp_path: Path): + data = pa.table({"a": range(800), "b": range(800)}) + ds = write_dataset(data, tmp_path) + + fragments = ds.get_fragments() + assert len(fragments) == 1 + + assert fragments[0].count_rows() == 800 + assert fragments[0].count_rows("a < 200") == 200 + assert fragments[0].count_rows(pc.field("a") < 200) == 200 + + +@pytest.mark.parametrize("enable_move_stable_row_ids", [False, True]) +def test_fragment_metadata_pickle(tmp_path: Path, enable_move_stable_row_ids: bool): + ds = write_dataset( + pa.table({"a": range(100)}), + tmp_path, + enable_move_stable_row_ids=enable_move_stable_row_ids, + ) + # Create a deletion file + ds.delete("a < 50") + fragment = ds.get_fragments()[0] + + frag_meta = fragment.metadata + + assert frag_meta.deletion_file is not None + if enable_move_stable_row_ids: + assert frag_meta.row_id_meta is not None + + # Pickle and unpickle the fragment metadata + round_trip = pickle.loads(pickle.dumps(frag_meta)) + + assert frag_meta == round_trip diff --git a/python/python/tests/test_lance.py b/python/python/tests/test_lance.py index 3e7360f916e..1a33c1b17b3 100644 --- a/python/python/tests/test_lance.py +++ b/python/python/tests/test_lance.py @@ -239,3 +239,12 @@ def test_roundtrip_schema(tmp_path): data = pa.table({"a": [1.0, 2.0]}).to_batches() dataset = lance.write_dataset(data, tmp_path, schema=schema) assert dataset.schema == schema + + +def test_io_counters(tmp_path): + starting_iops = lance.iops_counter() + starting_bytes = lance.bytes_read_counter() + dataset = lance.write_dataset(pa.table({"a": [1, 2, 3]}), tmp_path) + dataset.to_table() + assert lance.iops_counter() > starting_iops + assert lance.bytes_read_counter() > starting_bytes diff --git a/python/python/tests/test_log.py b/python/python/tests/test_log.py new file mode 100644 index 00000000000..e7c1b1a310b --- /dev/null +++ b/python/python/tests/test_log.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import logging +import os +from unittest import mock + +import pytest +from lance.log import ENV_NAME_PYLANCE_LOGGING_LEVEL, LOGGER, get_log_level, set_logger + + +@pytest.fixture(autouse=True) +def teardown_logger(): + yield + while LOGGER.handlers: + LOGGER.handlers.pop() + + +@pytest.mark.parametrize( + "env_value, expected", + [ + ("DEBUG", "DEBUG"), + ("INFO", "INFO"), + ("WARNING", "WARNING"), + ("DEBUG,INFO", "DEBUG"), + ("", "INFO"), + ("lance-core=debug,WARNING", "WARNING"), + ("DEBUG,lance-core=WARNING", "DEBUG"), + ], +) +def test_get_log_level(env_value, expected): + with mock.patch.dict(os.environ, {ENV_NAME_PYLANCE_LOGGING_LEVEL: env_value}): + assert get_log_level() == expected + + +def test_default_logger_level(): + assert LOGGER.level == logging.INFO + + +def test_set_logger_with_defaults(tmp_path): + log_file = tmp_path / "test.log" + set_logger(file_path=str(log_file)) + assert LOGGER.level == logging.INFO + assert len(LOGGER.handlers) == 1 + assert isinstance(LOGGER.handlers[0], logging.FileHandler) + assert LOGGER.handlers[0].baseFilename == str(log_file) + + +def test_set_logger_with_custom_level(tmp_path): + log_file = tmp_path / "test.log" + set_logger(file_path=str(log_file), level=logging.DEBUG) + assert LOGGER.level == logging.DEBUG + + +def test_set_logger_with_custom_format(tmp_path): + log_file = tmp_path / "test.log" + custom_format = "%(levelname)s: %(message)s" + set_logger(file_path=str(log_file), format_string=custom_format) + print(LOGGER.handlers[0].formatter._fmt) + assert LOGGER.handlers[0].formatter._fmt == custom_format + + +def test_set_logger_with_custom_handler(tmp_path): + custom_handler = logging.StreamHandler() + set_logger(log_handler=custom_handler) + assert LOGGER.handlers[0] == custom_handler + + +def test_logger_output(tmp_path, caplog): + log_file = tmp_path / "test.log" + set_logger(file_path=str(log_file)) + with caplog.at_level(logging.INFO): + LOGGER.info("Test log message") + assert "Test log message" in caplog.text diff --git a/python/python/tests/test_migration.py b/python/python/tests/test_migration.py index 1dcfa0dfffc..97ae4398e22 100644 --- a/python/python/tests/test_migration.py +++ b/python/python/tests/test_migration.py @@ -62,3 +62,19 @@ def test_fix_data_storage_version(tmp_path: Path): OSError, match="The dataset contains a mixture of file versions" ): ds.delete("false") + + +def test_old_btree_bitmap_indices(tmp_path: Path): + """ + In versions below 0.21.0 we used the legacy file format for btree and bitmap + indices. In version 0.21.0 we switched to the new format. This test ensures + that we can still read the old indices. + """ + ds = prep_dataset(tmp_path, "v0.20.0", "old_btree_bitmap_indices.lance") + + assert ds.to_table(filter="bitmap > 2") == pa.table( + {"bitmap": [3, 4], "btree": [3, 4]} + ) + assert ds.to_table(filter="btree > 2") == pa.table( + {"bitmap": [3, 4], "btree": [3, 4]} + ) diff --git a/python/python/tests/test_ray.py b/python/python/tests/test_ray.py index b85f185affa..b0b44f408ec 100644 --- a/python/python/tests/test_ray.py +++ b/python/python/tests/test_ray.py @@ -4,12 +4,13 @@ from pathlib import Path import lance +import pandas as pd import pyarrow as pa import pytest ray = pytest.importorskip("ray") - +from lance.ray.fragment_api import add_columns # noqa: E402 from lance.ray.sink import ( # noqa: E402 LanceCommitter, LanceDatasink, @@ -20,7 +21,7 @@ # Use this hook until we have official DataSink in Ray. _register_hooks() -ray.init() +ray.init(ignore_reinit_error=True) def test_ray_sink(tmp_path: Path): @@ -116,3 +117,80 @@ def test_ray_empty_write_lance(tmp_path: Path): # empty write would not generate dataset. with pytest.raises(ValueError): lance.dataset(tmp_path) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_ray_write_lance_none_str(tmp_path: Path): + def f(row): + return { + "id": row["id"], + "str": None, + } + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())]) + (ray.data.range(10).map(f).write_lance(tmp_path, schema=schema)) + + ds = lance.dataset(tmp_path) + ds.count_rows() == 10 + assert ds.schema == schema + + tbl = ds.to_table() + pylist = tbl["str"].to_pylist() + assert len(pylist) == 10 + for item in pylist: + assert item is None + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_ray_write_lance_none_str_datasink(tmp_path: Path): + def f(row): + return { + "id": row["id"], + "str": None, + } + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())]) + + sink = LanceDatasink(tmp_path, schema=schema) + (ray.data.range(10).map(f).write_datasink(sink)) + ds = lance.dataset(tmp_path) + ds.count_rows() == 10 + assert ds.schema == schema + + tbl = ds.to_table() + pylist = tbl["str"].to_pylist() + assert len(pylist) == 10 + for item in pylist: + assert item is None + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_lance_parallel_merge_columns(tmp_path: Path): + def generate_label(batch: pa.RecordBatch) -> pa.RecordBatch: + heights = batch.column("height").to_pylist() + tags = ["big" if height > 5 else "small" for height in heights] + df = pd.DataFrame({"size_labels": tags}) + + return pa.RecordBatch.from_pandas( + df, schema=pa.schema([pa.field("size_labels", pa.string())]) + ) + + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("height", pa.int64()), + pa.field("weight", pa.int64()), + ] + ) + ( + ray.data.range(11) + .repartition(1) + .map(lambda x: {"id": x["id"], "height": (x["id"] + 5), "weight": x["id"]}) + .write_lance(tmp_path, schema=schema) + ) + lance_ds = lance.dataset(tmp_path) + add_columns(lance_ds, generate_label, ["height"]) + ds = lance.dataset(tmp_path) + tbl = ds.to_table() + size_labels = sorted(tbl.column("size_labels").to_pylist()) + assert size_labels[:5] == ["big"] * 5 diff --git a/python/python/tests/test_s3_ddb.py b/python/python/tests/test_s3_ddb.py index 9e006fec60e..bcf08fcf4d9 100644 --- a/python/python/tests/test_s3_ddb.py +++ b/python/python/tests/test_s3_ddb.py @@ -24,11 +24,11 @@ # These are all keys that are accepted by storage_options CONFIG = { "allow_http": "true", - "aws_access_key_id": "ACCESSKEY", - "aws_secret_access_key": "SECRETKEY", - "aws_endpoint": "http://localhost:9000", - "dynamodb_endpoint": "http://localhost:8000", - "aws_region": "us-west-2", + "aws_access_key_id": "ACCESS_KEY", + "aws_secret_access_key": "SECRET_KEY", + "aws_endpoint": "http://localhost:4566", + "dynamodb_endpoint": "http://localhost:4566", + "aws_region": "us-east-1", } @@ -287,3 +287,35 @@ def test_file_writer_reader(s3_bucket: str): bytes(reader.read_global_buffer(global_buffer_pos)).decode() == global_buffer_text ) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.integration +@pytest.mark.skipif(not _RAY_AVAILABLE, reason="ray is not available") +def test_ray_read_lance(s3_bucket: str): + storage_options = copy.deepcopy(CONFIG) + table = pa.table({"a": [1, 2], "b": ["a", "b"]}) + path = f"s3://{s3_bucket}/test_ray_read.lance" + lance.write_dataset(table, path, storage_options=storage_options) + ds = ray.data.read_lance(path, storage_options=storage_options, concurrency=1) + ds.take(1) + + +@pytest.mark.integration +def test_append_fragment(s3_bucket: str): + storage_options = copy.deepcopy(CONFIG) + table = pa.table({"a": [1, 2], "b": ["a", "b"]}) + lance.fragment.LanceFragment.create( + f"s3://{s3_bucket}/test_append.lance", table, storage_options=storage_options + ) + + +@pytest.mark.integration +def test_s3_drop(s3_bucket: str): + storage_options = copy.deepcopy(CONFIG) + table_name = uuid.uuid4().hex + tmp_path = f"s3://{s3_bucket}/{table_name}.lance" + table = pa.table({"x": [0]}) + dataset = lance.write_dataset(table, tmp_path, storage_options=storage_options) + dataset.validate() + lance.LanceDataset.drop(tmp_path, storage_options=storage_options) diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 2dad325f967..c679d99e722 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -3,7 +3,9 @@ import os import random +import shutil import string +import zipfile from datetime import date, datetime, timedelta from pathlib import Path @@ -11,6 +13,7 @@ import numpy as np import pyarrow as pa import pytest +from lance.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery from lance.vector import vec_to_table @@ -34,6 +37,27 @@ def gen_str(n, split="", char_set=string.ascii_letters + string.digits): return tbl +def set_language_model_path(): + os.environ["LANCE_LANGUAGE_MODEL_HOME"] = os.path.join( + os.path.dirname(__file__), "models" + ) + + +@pytest.fixture() +def lindera_ipadic(): + set_language_model_path() + model_path = os.path.join(os.path.dirname(__file__), "models", "lindera", "ipadic") + cwd = os.getcwd() + try: + os.chdir(model_path) + with zipfile.ZipFile("main.zip", "r") as zip_ref: + zip_ref.extractall() + os.chdir(cwd) + yield + finally: + shutil.rmtree(os.path.join(model_path, "main")) + + @pytest.fixture() def dataset(tmp_path): tbl = create_table() @@ -86,6 +110,42 @@ def test_indexed_scalar_scan(indexed_dataset: lance.LanceDataset, data_table: pa assert actual_price == expected_price +def test_indexed_between(tmp_path): + dataset = lance.write_dataset(pa.table({"val": range(0, 10000)}), tmp_path) + dataset.create_scalar_index("val", index_type="BTREE") + + scanner = dataset.scanner(filter="val BETWEEN 10 AND 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + scanner = dataset.scanner(filter="val >= 10 AND val <= 20", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 11 + + # The following cases are slightly ill-formed since end is before start + # but we should handle them gracefully and simply return an empty result + # (previously we panicked here) + scanner = dataset.scanner(filter="val >= 5000 AND val <= 0", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 0 + + scanner = dataset.scanner(filter="val BETWEEN 5000 AND 0", prefilter=True) + + assert "MaterializeIndex" in scanner.explain_plan() + + actual_data = scanner.to_table() + assert actual_data.num_rows == 0 + + def test_temporal_index(tmp_path): # Timestamps now = datetime.now() @@ -172,6 +232,49 @@ def test_indexed_vector_scan_postfilter( assert scanner.to_table().num_rows == 0 +def test_fixed_size_binary(tmp_path): + arr = pa.array([b"0123012301230123", b"2345234523452345"], pa.uuid()) + + ds = lance.write_dataset(pa.table({"uuid": arr}), tmp_path) + + ds.create_scalar_index("uuid", "BTREE") + + query = ( + "uuid = arrow_cast(0x32333435323334353233343532333435, 'FixedSizeBinary(16)')" + ) + assert "MaterializeIndex" in ds.scanner(filter=query).explain_plan() + + table = ds.scanner(filter=query).to_table() + assert table.num_rows == 1 + assert table.column("uuid").to_pylist() == arr.slice(1, 1).to_pylist() + + +def test_index_take_batch_size(tmp_path): + dataset = lance.write_dataset( + pa.table({"ints": range(1024)}), tmp_path, max_rows_per_file=100 + ) + dataset.create_scalar_index("ints", index_type="BTREE") + batches = dataset.scanner( + with_row_id=True, filter="ints > 0", batch_size=50 + ).to_batches() + batches = list(batches) + assert len(batches) == 21 + + dataset = lance.write_dataset( + pa.table({"strings": [f"string-{i}" for i in range(1024)]}), + tmp_path, + max_rows_per_file=100, + mode="overwrite", + ) + dataset.create_scalar_index("strings", index_type="NGRAM") + filter = "contains(strings, 'ing')" + batches = dataset.scanner( + with_row_id=True, filter=filter, batch_size=50, limit=1024 + ).to_batches() + batches = list(batches) + assert len(batches) == 21 + + def test_all_null_chunk(tmp_path): def gen_string(idx: int): if idx % 2 == 0: @@ -191,25 +294,26 @@ def gen_string(idx: int): # environment variable. This test ensures that the environment variable # is respected. def test_lance_mem_pool_env_var(tmp_path): - strings = pa.array([f"string-{i}" * 10 for i in range(100 * 1024)]) - table = pa.Table.from_arrays([strings], ["str"]) + ints = pa.array([i * 10 for i in range(100 * 1024)]) + table = pa.Table.from_arrays([ints], ["int"]) dataset = lance.write_dataset(table, tmp_path) # Should succeed - dataset.create_scalar_index("str", index_type="BTREE") + dataset.create_scalar_index("int", index_type="BTREE") try: # Should fail if we intentionally use a very small memory pool os.environ["LANCE_MEM_POOL_SIZE"] = "1024" with pytest.raises(Exception): - dataset.create_scalar_index("str", index_type="BTREE", replace=True) + dataset.create_scalar_index("int", index_type="BTREE", replace=True) # Should succeed again since bypassing spilling takes precedence os.environ["LANCE_BYPASS_SPILLING"] = "1" - dataset.create_scalar_index("str", index_type="BTREE", replace=True) + dataset.create_scalar_index("int", index_type="BTREE", replace=True) finally: del os.environ["LANCE_MEM_POOL_SIZE"] - del os.environ["LANCE_BYPASS_SPILLING"] + if "LANCE_BYPASS_SPILLING" in os.environ: + del os.environ["LANCE_BYPASS_SPILLING"] @pytest.mark.parametrize("with_position", [True, False]) @@ -229,6 +333,11 @@ def test_full_text_search(dataset, with_position): for row in results: assert query in row.as_py() + with pytest.raises(ValueError, match="Cannot include deleted rows"): + dataset.to_table( + with_row_id=True, full_text_query=query, include_deleted_rows=True + ) + def test_filter_with_fts_index(dataset): dataset.create_scalar_index("doc", index_type="INVERTED", with_position=False) @@ -261,6 +370,15 @@ def test_indexed_filter_with_fts_index(tmp_path): ds.create_scalar_index("text", "INVERTED") ds.create_scalar_index("sentiment", "BITMAP") + # append more data to test flat FTS + data = pa.table( + { + "text": ["flat", "search"], + "sentiment": ["positive", "positive"], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="append") + results = ds.to_table( full_text_query="puppy", filter="sentiment='positive'", @@ -270,6 +388,163 @@ def test_indexed_filter_with_fts_index(tmp_path): assert results["_rowid"].to_pylist() == [2, 3] +def test_fts_stats(dataset): + dataset.create_scalar_index( + "doc", index_type="INVERTED", with_position=False, remove_stop_words=True + ) + stats = dataset.stats.index_stats("doc_idx") + assert stats["index_type"] == "Inverted" + stats = stats["indices"][0] + params = stats["params"] + + assert params["with_position"] is False + assert params["base_tokenizer"] == "simple" + assert params["language"] == "English" + assert params["max_token_length"] == 40 + assert params["lower_case"] is True + assert params["stem"] is False + assert params["remove_stop_words"] is True + assert params["ascii_folding"] is False + + +def test_fts_on_list(tmp_path): + data = pa.table( + { + "text": [ + ["lance database", "the", "search"], + ["lance database"], + ["lance", "search"], + ["database", "search"], + ["unrelated", "doc"], + ] + } + ) + ds = lance.write_dataset(data, tmp_path) + ds.create_scalar_index("text", "INVERTED", with_position=True) + + results = ds.to_table(full_text_query="lance") + assert results.num_rows == 3 + + results = ds.to_table(full_text_query=PhraseQuery("lance database", "text")) + assert results.num_rows == 2 + + +def test_fts_fuzzy_query(tmp_path): + data = pa.table( + { + "text": [ + "fa", + "fo", # spellchecker:disable-line + "fob", + "focus", + "foo", + "food", + "foul", + ] + } + ) + + ds = lance.write_dataset(data, tmp_path) + ds.create_scalar_index("text", "INVERTED") + + results = ds.to_table( + full_text_query=MatchQuery("foo", "text", fuzziness=1), + ) + assert results.num_rows == 4 + assert set(results["text"].to_pylist()) == { + "foo", + "fo", # 1 deletion # spellchecker:disable-line + "fob", # 1 substitution + "food", # 1 insertion + } + + results = ds.to_table( + full_text_query=MatchQuery("foo", "text", fuzziness=1, max_expansions=3), + ) + assert results.num_rows == 3 + + +def test_fts_phrase_query(tmp_path): + data = pa.table( + { + "text": [ + "frodo was a puppy", + "frodo was a happy puppy", + "frodo was a very happy puppy", + "frodo was a puppy with a tail", + ] + } + ) + + ds = lance.write_dataset(data, tmp_path) + ds.create_scalar_index("text", "INVERTED") + + results = ds.to_table( + full_text_query='"frodo was a puppy"', + ) + assert results.num_rows == 2 + assert set(results["text"].to_pylist()) == { + "frodo was a puppy", + "frodo was a puppy with a tail", + } + + results = ds.to_table( + full_text_query=PhraseQuery("frodo was a puppy", "text"), + ) + assert results.num_rows == 2 + assert set(results["text"].to_pylist()) == { + "frodo was a puppy", + "frodo was a puppy with a tail", + } + + +def test_fts_boost_query(tmp_path): + data = pa.table( + { + "text": [ + "frodo was a puppy", + "frodo was a happy puppy", + "frodo was a puppy with a tail", + ] + } + ) + + ds = lance.write_dataset(data, tmp_path) + ds.create_scalar_index("text", "INVERTED") + results = ds.to_table( + full_text_query=BoostQuery( + MatchQuery("puppy", "text"), + MatchQuery("happy", "text"), + negative_boost=0.5, + ), + ) + assert results.num_rows == 3 + assert set(results["text"].to_pylist()) == { + "frodo was a puppy", + "frodo was a puppy with a tail", + "frodo was a happy puppy", + } + + +def test_fts_multi_match_query(tmp_path): + data = pa.table( + { + "title": ["title common", "title hello", "title vector"], + "content": ["content world", "content database", "content common"], + } + ) + + ds = lance.write_dataset(data, tmp_path) + ds.create_scalar_index("title", "INVERTED") + ds.create_scalar_index("content", "INVERTED") + + results = ds.to_table( + full_text_query=MultiMatchQuery("common", ["title", "content"]), + ) + assert set(results["title"].to_pylist()) == {"title common", "title vector"} + assert set(results["content"].to_pylist()) == {"content world", "content common"} + + def test_fts_with_postfilter(tmp_path): tab = pa.table({"text": ["Frodo the puppy"] * 100, "id": range(100)}) dataset = lance.write_dataset(tab, tmp_path) @@ -307,6 +582,170 @@ def test_fts_all_deleted(dataset): dataset.to_table(full_text_query=first_row_doc) +def test_indexed_filter_with_fts_index_with_lindera_ipadic_jp_tokenizer( + tmp_path, lindera_ipadic +): + os.environ["LANCE_LANGUAGE_MODEL_HOME"] = os.path.join( + os.path.dirname(__file__), "models" + ) + data = pa.table( + { + "text": [ + "æˆç”°å›½éš›ç©ºæ¸¯", + "æ±äº¬å›½éš›ç©ºæ¸¯", + "羽田空港", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="lindera/ipadic") + + results = ds.to_table( + full_text_query="æˆç”°", + prefilter=True, + with_row_id=True, + ) + assert results["_rowid"].to_pylist() == [0] + + +def test_lindera_ipadic_jp_tokenizer_invalid_user_dict_path(tmp_path, lindera_ipadic): + data = pa.table( + { + "text": [ + "æˆç”°å›½éš›ç©ºæ¸¯", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + with pytest.raises(OSError): + ds.create_scalar_index( + "text", "INVERTED", base_tokenizer="lindera/invalid_dict" + ) + + +def test_lindera_ipadic_jp_tokenizer_csv_user_dict_without_type( + tmp_path, lindera_ipadic +): + data = pa.table( + { + "text": [ + "æˆç”°å›½éš›ç©ºæ¸¯", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + with pytest.raises(OSError): + ds.create_scalar_index( + "text", "INVERTED", base_tokenizer="lindera/invalid_dict2" + ) + + +def test_lindera_ipadic_jp_tokenizer_csv_user_dict(tmp_path, lindera_ipadic): + data = pa.table( + { + "text": [ + "æˆç”°å›½éš›ç©ºæ¸¯", + "æ±äº¬å›½éš›ç©ºæ¸¯", + "羽田空港", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="lindera/user_dict") + results = ds.to_table( + full_text_query="æˆç”°", + prefilter=True, + with_row_id=True, + ) + assert len(results) == 0 + results = ds.to_table( + full_text_query="æˆç”°å›½éš›ç©ºæ¸¯", + prefilter=True, + with_row_id=True, + ) + assert results["_rowid"].to_pylist() == [0] + + +def test_lindera_ipadic_jp_tokenizer_bin_user_dict(tmp_path, lindera_ipadic): + data = pa.table( + { + "text": [ + "æˆç”°å›½éš›ç©ºæ¸¯", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="lindera/user_dict2") + + +def test_jieba_tokenizer(tmp_path): + set_language_model_path() + data = pa.table( + { + "text": ["我们都有光明的å‰é€”", "光明的å‰é€”"], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="jieba/default") + results = ds.to_table( + full_text_query="我们", + prefilter=True, + with_row_id=True, + ) + assert results["_rowid"].to_pylist() == [0] + + +def test_jieba_invalid_user_dict_tokenizer(tmp_path): + set_language_model_path() + data = pa.table( + { + "text": [ + "我们都有光明的å‰é€”", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + with pytest.raises(OSError): + ds.create_scalar_index("text", "INVERTED", base_tokenizer="jieba/invalid_dict") + + +def test_jieba_invalid_main_dict_tokenizer(tmp_path): + set_language_model_path() + data = pa.table( + { + "text": [ + "我们都有光明的å‰é€”", + ], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + with pytest.raises(OSError): + ds.create_scalar_index("text", "INVERTED", base_tokenizer="jieba/invalid_dict2") + + +def test_jieba_user_dict_tokenizer(tmp_path): + set_language_model_path() + data = pa.table( + { + "text": ["我们都有光明的å‰é€”", "光明的å‰é€”"], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="jieba/user_dict") + results = ds.to_table( + full_text_query="çš„å‰", + prefilter=True, + with_row_id=True, + ) + assert len(results) == 0 + results = ds.to_table( + full_text_query="光明的å‰é€”", + prefilter=True, + with_row_id=True, + ) + assert results["_rowid"].to_pylist() == [1, 0] + + def test_bitmap_index(tmp_path: Path): """Test create bitmap index""" tbl = pa.Table.from_arrays( @@ -319,6 +758,146 @@ def test_bitmap_index(tmp_path: Path): assert indices[0]["type"] == "Bitmap" +def test_ngram_index(tmp_path: Path): + """Test create ngram index""" + tbl = pa.Table.from_arrays( + [ + pa.array( + [["apple", "apples", "banana", "coconut"][i % 4] for i in range(100)] + ) + ], + names=["words"], + ) + dataset = lance.write_dataset(tbl, tmp_path / "dataset") + dataset.create_scalar_index("words", index_type="NGRAM") + indices = dataset.list_indices() + assert len(indices) == 1 + assert indices[0]["type"] == "NGram" + + scan_plan = dataset.scanner(filter="contains(words, 'apple')").explain_plan(True) + assert "MaterializeIndex" in scan_plan + + assert dataset.to_table(filter="contains(words, 'apple')").num_rows == 50 + assert dataset.to_table(filter="contains(words, 'banana')").num_rows == 25 + assert dataset.to_table(filter="contains(words, 'coconut')").num_rows == 25 + assert dataset.to_table(filter="contains(words, 'apples')").num_rows == 25 + assert ( + dataset.to_table( + filter="contains(words, 'apple') AND contains(words, 'banana')" + ).num_rows + == 0 + ) + assert ( + dataset.to_table( + filter="contains(words, 'apple') OR contains(words, 'banana')" + ).num_rows + == 75 + ) + + +def test_null_handling(tmp_path: Path): + tbl = pa.table( + { + "x": [1, 2, None, 3], + } + ) + dataset = lance.write_dataset(tbl, tmp_path / "dataset") + + def check(has_index: bool): + assert dataset.to_table(filter="x IS NULL").num_rows == 1 + assert dataset.to_table(filter="x IS NOT NULL").num_rows == 3 + assert dataset.to_table(filter="x > 0").num_rows == 3 + assert dataset.to_table(filter="x < 5").num_rows == 3 + assert dataset.to_table(filter="x IN (1, 2)").num_rows == 2 + # Note: there is a bit of discrepancy here. Datafusion does not consider + # NULL==NULL when doing an IN operation due to classic SQL shenanigans. + # We should decide at some point which behavior we want and make this + # consistent. + if has_index: + assert dataset.to_table(filter="x IN (1, 2, NULL)").num_rows == 3 + else: + assert dataset.to_table(filter="x IN (1, 2, NULL)").num_rows == 2 + + check(False) + dataset.create_scalar_index("x", index_type="BITMAP") + check(True) + dataset.create_scalar_index("x", index_type="BTREE") + check(True) + + +def test_nan_handling(tmp_path: Path): + tbl = pa.table( + { + "x": [ + 1.0, + float("-nan"), + float("infinity"), + float("-infinity"), + 2.0, + float("nan"), + 3.0, + ], + } + ) + dataset = lance.write_dataset(tbl, tmp_path / "dataset") + + # There is no way, in DF, to query for NAN / INF, that I'm aware of. + # So the best we can do here is make sure that the presence of NAN / INF + # doesn't interfere with normal operation of the btree. + def check(has_index: bool): + assert dataset.to_table(filter="x IS NULL").num_rows == 0 + assert dataset.to_table(filter="x IS NOT NULL").num_rows == 7 + assert dataset.to_table(filter="x > 0").num_rows == 5 + assert dataset.to_table(filter="x < 5").num_rows == 5 + assert dataset.to_table(filter="x IN (1, 2)").num_rows == 2 + + check(False) + dataset.create_scalar_index("x", index_type="BITMAP") + check(True) + dataset.create_scalar_index("x", index_type="BTREE") + check(True) + + +def test_scalar_index_with_nulls(tmp_path): + # Create a test dataframe with 50% null values. + test_table_size = 10_000 + test_table = pa.table( + { + "item_id": list(range(test_table_size)), + "inner_id": list(range(test_table_size)), + "category": ["a", None] * (test_table_size // 2), + "numeric_int": [1, None] * (test_table_size // 2), + "numeric_float": [0.1, None] * (test_table_size // 2), + "boolean_col": [True, None] * (test_table_size // 2), + "timestamp_col": [datetime(2023, 1, 1), None] * (test_table_size // 2), + "ngram_col": ["apple", None] * (test_table_size // 2), + } + ) + ds = lance.write_dataset(test_table, tmp_path) + ds.create_scalar_index("inner_id", index_type="BTREE") + ds.create_scalar_index("category", index_type="BTREE") + ds.create_scalar_index("boolean_col", index_type="BTREE") + ds.create_scalar_index("timestamp_col", index_type="BTREE") + ds.create_scalar_index("ngram_col", index_type="NGRAM") + # Test querying with filters on columns with nulls. + k = test_table_size // 2 + result = ds.to_table(filter="category = 'a'", limit=k) + assert len(result) == k + # Booleans should be stored as strings in the table for backwards compatibility. + result = ds.to_table(filter="boolean_col IS TRUE", limit=k) + assert len(result) == k + result = ds.to_table(filter="timestamp_col IS NOT NULL", limit=k) + assert len(result) == k + + # Ensure ngram index works with nulls + result = ds.to_table(filter="ngram_col = 'apple'") + assert len(result) == k + result = ds.to_table(filter="ngram_col IS NULL") + assert len(result) == k + result = ds.to_table(filter="contains(ngram_col, 'appl')") + assert len(result) == k + + def test_label_list_index(tmp_path: Path): tags = pa.array(["tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"]) tag_list = pa.ListArray.from_arrays([0, 2, 4], tags) @@ -328,3 +907,171 @@ def test_label_list_index(tmp_path: Path): indices = dataset.list_indices() assert len(indices) == 1 assert indices[0]["type"] == "LabelList" + + +def test_create_index_empty_dataset(tmp_path: Path): + # Creating an index on an empty dataset is (currently) not terribly useful but + # we shouldn't return strange errors. + schema = pa.schema( + [ + pa.field("btree", pa.int32()), + pa.field("bitmap", pa.int32()), + pa.field("label_list", pa.list_(pa.string())), + pa.field("inverted", pa.string()), + pa.field("ngram", pa.string()), + ] + ) + ds = lance.write_dataset([], tmp_path, schema=schema) + + for index_type in ["BTREE", "BITMAP", "LABEL_LIST", "INVERTED", "NGRAM"]: + ds.create_scalar_index(index_type.lower(), index_type=index_type) + + # Make sure the empty index doesn't cause searches to fail + ds.insert( + pa.table( + { + "btree": pa.array([1], pa.int32()), + "bitmap": pa.array([1], pa.int32()), + "label_list": [["foo", "bar"]], + "inverted": ["blah"], + "ngram": ["apple"], + } + ) + ) + + def test_searches(): + assert ds.to_table(filter="btree = 1").num_rows == 1 + assert ds.to_table(filter="btree = 0").num_rows == 0 + assert ds.to_table(filter="bitmap = 1").num_rows == 1 + assert ds.to_table(filter="bitmap = 0").num_rows == 0 + assert ds.to_table(filter="array_has_any(label_list, ['foo'])").num_rows == 1 + assert ds.to_table(filter="array_has_any(label_list, ['oof'])").num_rows == 0 + assert ds.to_table(filter="inverted = 'blah'").num_rows == 1 + assert ds.to_table(filter="inverted = 'halb'").num_rows == 0 + assert ds.to_table(filter="contains(ngram, 'apple')").num_rows == 1 + assert ds.to_table(filter="contains(ngram, 'banana')").num_rows == 0 + assert ds.to_table(filter="ngram = 'apple'").num_rows == 1 + + test_searches() + + # Make sure fetching index stats on empty index is ok + for idx in ds.list_indices(): + ds.stats.index_stats(idx["name"]) + + # Make sure updating empty indices is ok + ds.optimize.optimize_indices() + + # Finally, make sure we can still search after updating + test_searches() + + +def test_optimize_no_new_data(tmp_path: Path): + tbl = pa.table( + { + "btree": pa.array([None], pa.int64()), + "bitmap": pa.array([None], pa.int64()), + "ngram": pa.array([None], pa.string()), + } + ) + dataset = lance.write_dataset(tbl, tmp_path) + dataset.create_scalar_index("btree", index_type="BTREE") + dataset.create_scalar_index("bitmap", index_type="BITMAP") + dataset.create_scalar_index("ngram", index_type="NGRAM") + + assert dataset.to_table(filter="btree IS NULL").num_rows == 1 + assert dataset.to_table(filter="bitmap IS NULL").num_rows == 1 + assert dataset.to_table(filter="ngram IS NULL").num_rows == 1 + + dataset.insert([], schema=tbl.schema) + dataset.optimize.optimize_indices() + + assert dataset.to_table(filter="btree IS NULL").num_rows == 1 + assert dataset.to_table(filter="bitmap IS NULL").num_rows == 1 + assert dataset.to_table(filter="ngram IS NULL").num_rows == 1 + + dataset.insert(pa.table({"btree": [2]})) + dataset.optimize.optimize_indices() + + assert dataset.to_table(filter="btree IS NULL").num_rows == 1 + assert dataset.to_table(filter="bitmap IS NULL").num_rows == 2 + assert dataset.to_table(filter="ngram IS NULL").num_rows == 2 + + dataset.insert(pa.table({"bitmap": [2]})) + dataset.optimize.optimize_indices() + + assert dataset.to_table(filter="btree IS NULL").num_rows == 2 + assert dataset.to_table(filter="bitmap IS NULL").num_rows == 2 + assert dataset.to_table(filter="ngram IS NULL").num_rows == 3 + + dataset.insert(pa.table({"ngram": ["apple"]})) + + assert dataset.to_table(filter="btree IS NULL").num_rows == 3 + assert dataset.to_table(filter="bitmap IS NULL").num_rows == 3 + assert dataset.to_table(filter="ngram IS NULL").num_rows == 3 + + +def test_drop_index(tmp_path): + test_table_size = 100 + test_table = pa.table( + { + "btree": list(range(test_table_size)), + "bitmap": list(range(test_table_size)), + "fts": ["a" for _ in range(test_table_size)], + "ngram": ["a" for _ in range(test_table_size)], + } + ) + ds = lance.write_dataset(test_table, tmp_path) + ds.create_scalar_index("btree", index_type="BTREE") + ds.create_scalar_index("bitmap", index_type="BITMAP") + ds.create_scalar_index("fts", index_type="INVERTED") + ds.create_scalar_index("ngram", index_type="NGRAM") + + assert len(ds.list_indices()) == 4 + + # Attempt to drop index (name does not exist) + with pytest.raises(RuntimeError, match="index not found"): + ds.drop_index("nonexistent_name") + + for idx in ds.list_indices(): + idx_name = idx["name"] + ds.drop_index(idx_name) + + assert len(ds.list_indices()) == 0 + + # Ensure we can still search columns + assert ds.to_table(filter="btree = 1").num_rows == 1 + assert ds.to_table(filter="bitmap = 1").num_rows == 1 + assert ds.to_table(filter="fts = 'a'").num_rows == test_table_size + assert ds.to_table(filter="contains(ngram, 'a')").num_rows == test_table_size + + +def test_index_prewarm(tmp_path: Path): + scan_stats = None + + def scan_stats_callback(stats: lance.ScanStatistics): + nonlocal scan_stats + scan_stats = stats + + test_table_size = 100 + test_table = pa.table( + { + "fts": ["a" for _ in range(test_table_size)], + } + ) + + # Write index, cache should not be populated + ds = lance.write_dataset(test_table, tmp_path) + ds.create_scalar_index("fts", index_type="INVERTED") + ds.scanner(scan_stats_callback=scan_stats_callback, full_text_query="a").to_table() + assert scan_stats.parts_loaded > 0 + + # Fresh load, no prewarm, cache should not be populated + ds = lance.dataset(tmp_path) + ds.scanner(scan_stats_callback=scan_stats_callback, full_text_query="a").to_table() + assert scan_stats.parts_loaded > 0 + + # Prewarm index, cache should be populated + ds = lance.dataset(tmp_path) + ds.prewarm_index("fts_idx") + ds.scanner(scan_stats_callback=scan_stats_callback, full_text_query="a").to_table() + assert scan_stats.parts_loaded == 0 diff --git a/python/python/tests/test_schema.py b/python/python/tests/test_schema.py index a8c379aff4d..fcff283ebe2 100644 --- a/python/python/tests/test_schema.py +++ b/python/python/tests/test_schema.py @@ -30,3 +30,33 @@ def test_lance_schema(tmp_path: Path): assert schema.to_pyarrow() == data.schema assert LanceSchema.from_pyarrow(data.schema) == schema + + fields = schema.fields() + assert len(fields) == 3 + assert fields[0].name() == "x" + assert fields[0].id() == 0 + assert fields[1].name() == "s" + assert fields[1].id() == 1 + + s_children = fields[1].children() + assert len(s_children) == 2 + assert s_children[0].name() == "a" + assert s_children[0].id() == 2 + assert s_children[1].name() == "b" + assert s_children[1].id() == 3 + + assert fields[2].name() == "y" + assert fields[2].id() == 4 + + l_children = fields[2].children() + assert len(l_children) == 1 + assert l_children[0].name() == "item" + assert l_children[0].id() == 5 + + # Changing column name does not change the id + dataset.alter_columns({"path": "s.a", "name": "new_name"}) + schema = dataset.lance_schema + fields = schema.fields() + s_fields = fields[1].children() + assert s_fields[0].name() == "new_name" + assert s_fields[0].id() == 2 diff --git a/python/python/tests/test_schema_evolution.py b/python/python/tests/test_schema_evolution.py index b2b6acbcfef..6560d8c7e7d 100644 --- a/python/python/tests/test_schema_evolution.py +++ b/python/python/tests/test_schema_evolution.py @@ -512,3 +512,31 @@ def some_udf(batch): with pytest.raises(ValueError, match="A checkpoint file cannot be used"): frag.merge_columns(some_udf, columns=["a"]) + + +def test_add_cols_all_null_with_sql(tmp_path: Path): + tab = pa.table( + { + "a": range(100), + } + ) + dataset = lance.write_dataset( + tab, tmp_path, max_rows_per_file=50, data_storage_version="stable" + ) + fragments_before = dataset.get_fragments() + dataset.add_columns({"b": "CAST(NULL AS INT)"}) + fragments_after = dataset.get_fragments() + + # assert this was a metadata only operation and no data was written + assert len(fragments_before) == len(fragments_after) + for frag_before, frag_after in zip(fragments_before, fragments_after): + assert frag_before.fragment_id == frag_after.fragment_id + assert frag_before.data_files() == frag_after.data_files() + + # assert the schema is as expected + assert dataset.schema == pa.schema( + { + "a": pa.int64(), + "b": pa.int32(), + } + ) diff --git a/python/python/tests/test_torch.py b/python/python/tests/test_torch.py new file mode 100644 index 00000000000..e949399b135 --- /dev/null +++ b/python/python/tests/test_torch.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import lance +import numpy as np +import pyarrow as pa +import pytest +from lance.torch.data import SafeLanceDataset, get_safe_loader # noqa: E402 + + +@pytest.fixture(scope="module") +def temp_lance_dataset(tmp_path_factory): + """Create temporary Lance dataset for testing""" + test_dir = tmp_path_factory.mktemp("lance_data") + dataset_path = test_dir / "test_dataset.lance" + + # Generate test data with batch_size aligned sample count + num_samples = 96 # 16 samples/batch * 6 batches + data = pa.table( + { + "id": range(num_samples), + "embedding": [ + np.random.rand(128).astype(np.float32).tobytes() + for _ in range(num_samples) + ], + } + ) + + lance.write_dataset(data, dataset_path) + yield str(dataset_path) + + +def test_dataset_initialization(temp_lance_dataset): + """Verify dataset basic functionality""" + ds = SafeLanceDataset(temp_lance_dataset) + + # Validate metadata + assert len(ds) == 96, "Sample count should match configured size" + + # Validate single sample format + sample = ds[0] + assert isinstance(sample, dict), "Sample should be dictionary type" + assert {"id", "embedding"}.issubset(sample.keys()), "Missing required fields" + + +def test_multiprocess_loading(temp_lance_dataset, capsys): + """Verify multi-worker data loading""" + dataset = SafeLanceDataset(temp_lance_dataset) + loader = get_safe_loader( + dataset, + num_workers=2, + batch_size=16, + drop_last=False, # Ensure full batches + ) + + total_samples = 0 + for batch in loader: + assert batch["id"].shape == (16,), "Batch dimension mismatch" + total_samples += batch["id"].shape[0] + + # Validate complete dataset loading + assert total_samples == 96, "Should load all samples" diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 7a855863794..487d5f14ca1 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -13,6 +13,7 @@ import pyarrow.compute as pc import pytest from lance import LanceFragment +from lance.dataset import VectorIndexReader torch = pytest.importorskip("torch") from lance.util import validate_vector_index # noqa: E402 @@ -48,6 +49,48 @@ def gen_str(n): return tbl +def create_multivec_table( + nvec=1000, nvec_per_row=5, ndim=128, nans=0, nullify=False, dtype=np.float32 +): + mat = np.random.randn(nvec, nvec_per_row, ndim) + if nans > 0: + nans_mat = np.empty((nans, ndim)) + nans_mat[:] = np.nan + mat = np.concatenate((mat, nans_mat), axis=0) + mat = mat.astype(dtype) + price = np.random.rand(nvec + nans) * 100 + + def gen_str(n): + return "".join(random.choices(string.ascii_letters + string.digits, k=n)) + + meta = np.array([gen_str(100) for _ in range(nvec + nans)]) + + multi_vec_type = pa.list_(pa.list_(pa.float32(), ndim)) + tbl = pa.Table.from_arrays( + [ + pa.array((mat[i].tolist() for i in range(nvec)), type=multi_vec_type), + ], + schema=pa.schema( + [ + pa.field("vector", pa.list_(pa.list_(pa.float32(), ndim))), + ] + ), + ) + tbl = ( + tbl.append_column("price", pa.array(price)) + .append_column("meta", pa.array(meta)) + .append_column("id", pa.array(range(nvec + nans))) + ) + if nullify: + idx = tbl.schema.get_field_index("vector") + vecs = tbl[idx].to_pylist() + nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)] + field = tbl.schema.field(idx) + vecs = pa.array(nullified, field.type) + tbl = tbl.set_column(idx, field, vecs) + return tbl + + @pytest.fixture() def dataset(tmp_path): tbl = create_table() @@ -63,6 +106,23 @@ def indexed_dataset(tmp_path): ) +@pytest.fixture() +def multivec_dataset(): + tbl = create_multivec_table() + yield lance.write_dataset(tbl, "memory://") + + +@pytest.fixture() +def indexed_multivec_dataset(multivec_dataset): + yield multivec_dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + metric="cosine", + ) + + def run(ds, q=None, assert_func=None): if q is None: q = np.random.randn(128) @@ -194,6 +254,22 @@ def test_index_with_nans(tmp_path): validate_vector_index(dataset, "vector") +def test_torch_index_with_nans(tmp_path): + # 1024 rows, the entire table should be sampled + tbl = create_table(nvec=1000, nans=24) + + dataset = lance.write_dataset(tbl, tmp_path) + dataset = dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + accelerator=torch.device("cpu"), + one_pass_ivfpq=True, + ) + validate_vector_index(dataset, "vector") + + def test_index_with_no_centroid_movement(tmp_path): # this test makes the centroids essentially [1..] # this makes sure the early stop condition in the index building code @@ -373,6 +449,37 @@ def test_has_index(dataset, tmp_path): assert ann_ds.list_indices()[0]["fields"] == ["vector"] +def test_index_type(dataset, tmp_path): + ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") + + ann_ds = ann_ds.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + replace=True, + ) + assert ann_ds.list_indices()[0]["type"] == "IVF_PQ" + + ann_ds = ann_ds.create_index( + "vector", + index_type="IVF_HNSW_SQ", + num_partitions=4, + num_sub_vectors=16, + replace=True, + ) + assert ann_ds.list_indices()[0]["type"] == "IVF_HNSW_SQ" + + ann_ds = ann_ds.create_index( + "vector", + index_type="IVF_HNSW_PQ", + num_partitions=4, + num_sub_vectors=16, + replace=True, + ) + assert ann_ds.list_indices()[0]["type"] == "IVF_HNSW_PQ" + + def test_create_dot_index(dataset, tmp_path): assert not dataset.has_index ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") @@ -386,6 +493,44 @@ def test_create_dot_index(dataset, tmp_path): assert ann_ds.has_index +def test_create_4bit_ivf_pq_index(dataset, tmp_path): + assert not dataset.has_index + ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") + ann_ds = ann_ds.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=1, + num_sub_vectors=16, + num_bits=4, + metric="l2", + ) + index = ann_ds.stats.index_stats("vector_idx") + assert index["indices"][0]["sub_index"]["nbits"] == 4 + + +def test_ivf_flat_over_binary_vector(tmp_path): + dim = 128 + nvec = 1000 + data = np.random.randint(0, 256, (nvec, dim // 8)).tolist() + array = pa.array(data, type=pa.list_(pa.uint8(), dim // 8)) + tbl = pa.Table.from_pydict({"vector": array}) + ds = lance.write_dataset(tbl, tmp_path) + ds.create_index("vector", index_type="IVF_FLAT", num_partitions=4, metric="hamming") + stats = ds.stats.index_stats("vector_idx") + assert stats["indices"][0]["metric_type"] == "hamming" + assert stats["index_type"] == "IVF_FLAT" + + query = np.random.randint(0, 256, dim // 8).astype(np.uint8) + ds.to_table( + nearest={ + "column": "vector", + "q": query, + "k": 10, + "metric": "hamming", + } + ) + + def test_create_ivf_hnsw_pq_index(dataset, tmp_path): assert not dataset.has_index ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") @@ -410,11 +555,57 @@ def test_create_ivf_hnsw_sq_index(dataset, tmp_path): assert ann_ds.list_indices()[0]["fields"] == ["vector"] +def test_multivec_ann(indexed_multivec_dataset: lance.LanceDataset): + query = np.random.rand(5, 128) + results = indexed_multivec_dataset.scanner( + nearest={"column": "vector", "q": query, "k": 100} + ).to_table() + assert results.num_rows == 100 + assert results["vector"].type == pa.list_(pa.list_(pa.float32(), 128)) + assert len(results["vector"][0]) == 5 + + # query with single vector also works + query = np.random.rand(128) + results = indexed_multivec_dataset.to_table( + nearest={"column": "vector", "q": query, "k": 100} + ) + # we don't verify the number of results here, + # because for multivector, it's not guaranteed to return k results + assert results["vector"].type == pa.list_(pa.list_(pa.float32(), 128)) + assert len(results["vector"][0]) == 5 + + query = [query, query] + doubled_results = indexed_multivec_dataset.to_table( + nearest={"column": "vector", "q": query, "k": 100} + ) + assert len(results) == len(doubled_results) + for i in range(len(results)): + assert ( + results["_distance"][i].as_py() * 2 + == doubled_results["_distance"][i].as_py() + ) + + # query with a vector that dim not match + query = np.random.rand(256) + with pytest.raises(ValueError, match="does not match index column size"): + indexed_multivec_dataset.to_table( + nearest={"column": "vector", "q": query, "k": 100} + ) + + # query with a list of vectors that some dim not match + query = [np.random.rand(128)] * 5 + [np.random.rand(256)] + with pytest.raises(ValueError, match="All query vectors must have the same length"): + indexed_multivec_dataset.to_table( + nearest={"column": "vector", "q": query, "k": 100} + ) + + def test_pre_populated_ivf_centroids(dataset, tmp_path: Path): centroids = np.random.randn(5, 128).astype(np.float32) # IVF5 dataset_with_index = dataset.create_index( ["vector"], index_type="IVF_PQ", + metric="cosine", ivf_centroids=centroids, num_partitions=5, num_sub_vectors=8, @@ -439,7 +630,7 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path): "index_type": "IVF_PQ", "uuid": index_uuid, "uri": expected_filepath, - "metric_type": "l2", + "metric_type": "cosine", "num_partitions": 5, "sub_index": { "dimension": 128, @@ -447,6 +638,7 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path): "metric_type": "l2", "nbits": 8, "num_sub_vectors": 8, + "transposed": True, }, } @@ -467,6 +659,7 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path): idx_stats = actual_statistics["indices"][0] partitions = idx_stats.pop("partitions") idx_stats.pop("centroids") + idx_stats.pop("loss") assert idx_stats == expected_statistics assert len(partitions) == 5 partition_keys = {"size"} @@ -626,8 +819,8 @@ def has_target(target, results): def check_index(has_knn_combined, delete_has_happened): for query in sample_queries: - results = dataset.to_table(nearest=query).column("vector") - assert has_target(query["q"], results) + results = dataset.to_table(nearest=query) + assert has_target(query["q"], results["vector"]) plan = dataset.scanner(nearest=query).explain_plan(verbose=True) assert ("KNNVectorDistance" in plan) == has_knn_combined for query in sample_delete_queries: @@ -943,3 +1136,94 @@ def test_optimize_indices(indexed_dataset): indexed_dataset.optimize.optimize_indices(num_indices_to_merge=0) indices = indexed_dataset.list_indices() assert len(indices) == 2 + + +def test_retrain_indices(indexed_dataset): + data = create_table() + indexed_dataset = lance.write_dataset(data, indexed_dataset.uri, mode="append") + indices = indexed_dataset.list_indices() + assert len(indices) == 1 + + indexed_dataset.optimize.optimize_indices(num_indices_to_merge=0) + indices = indexed_dataset.list_indices() + assert len(indices) == 2 + + stats = indexed_dataset.stats.index_stats("vector_idx") + centroids = stats["indices"][0]["centroids"] + delta_centroids = stats["indices"][1]["centroids"] + assert centroids == delta_centroids + + indexed_dataset.optimize.optimize_indices(retrain=True) + new_centroids = indexed_dataset.stats.index_stats("vector_idx")["indices"][0][ + "centroids" + ] + indices = indexed_dataset.list_indices() + assert len(indices) == 1 + assert centroids != new_centroids + + +def test_no_include_deleted_rows(indexed_dataset): + with pytest.raises(ValueError, match="Cannot include deleted rows"): + indexed_dataset.to_table( + nearest={ + "column": "vector", + "q": np.random.randn(128), + "k": 10, + }, + with_row_id=True, + include_deleted_rows=True, + ) + + +def test_drop_indices(indexed_dataset): + idx_name = indexed_dataset.list_indices()[0]["name"] + + indexed_dataset.drop_index(idx_name) + indices = indexed_dataset.list_indices() + assert len(indices) == 0 + + test_vec = ( + indexed_dataset.take([0], columns=["vector"]).column("vector").to_pylist()[0] + ) + + # make sure we can still search the column (will do flat search) + results = indexed_dataset.to_table( + nearest={ + "column": "vector", + "q": test_vec, + "k": 15, + "nprobes": 1, + }, + ) + + assert len(results) == 15 + + +def test_read_partition(indexed_dataset): + idx_name = indexed_dataset.list_indices()[0]["name"] + reader = VectorIndexReader(indexed_dataset, idx_name) + + num_rows = indexed_dataset.count_rows() + row_sum = 0 + for part_id in range(reader.num_partitions()): + res = reader.read_partition(part_id) + row_sum += res.num_rows + assert "_rowid" in res.column_names + assert row_sum == num_rows + + row_sum = 0 + for part_id in range(reader.num_partitions()): + res = reader.read_partition(part_id, with_vector=True) + row_sum += res.num_rows + pq_column = res["__pq_code"] + assert "_rowid" in res.column_names + assert pq_column.type == pa.list_(pa.uint8(), 16) + assert row_sum == num_rows + + # error tests + with pytest.raises(IndexError, match="out of range"): + reader.read_partition(reader.num_partitions() + 1) + + with pytest.raises(ValueError, match="not vector index"): + indexed_dataset.create_scalar_index("id", index_type="BTREE") + VectorIndexReader(indexed_dataset, "id_idx") diff --git a/python/python/tests/torch_tests/test_data.py b/python/python/tests/torch_tests/test_data.py index 02f424fdaa5..9c3e92caea0 100644 --- a/python/python/tests/torch_tests/test_data.py +++ b/python/python/tests/torch_tests/test_data.py @@ -277,3 +277,49 @@ def test_convert_int_tensors(tmp_path: Path, dtype): first = next(iter(torch_ds)) assert first["vec"].dtype == torch.uint8 if dtype == np.uint8 else torch.int64 assert first["vec"].shape == (4, 32) + + +def test_blob_api(tmp_path: Path): + ints = pa.array(range(100), type=pa.int64()) + vals = pa.array([b"0" * 1024 for _ in range(100)], pa.large_binary()) + schema = pa.schema( + [ + pa.field("int", ints.type), + pa.field( + "val", pa.large_binary(), metadata={"lance-encoding:blob": "true"} + ), + ] + ) + tbl = pa.Table.from_arrays([ints, vals], schema=schema) + + ds = lance.write_dataset(tbl, tmp_path / "data.lance") + torch_ds = LanceDataset( + ds, + batch_size=4, + ) + with pytest.raises(NotImplementedError): + next(iter(torch_ds)) + + def to_tensor_fn(batch, *args, **kwargs): + ints = torch.tensor(batch["int"].to_numpy()) + vals = [] + for blob in batch["val"]: + blob.seek(100) + data = blob.read(100) + tensor = torch.tensor(np.frombuffer(data, dtype=np.uint8)) + vals.append(tensor) + + # vals.append(torch.tensor(blob)) + vals = torch.stack(vals) + return {"int": ints, "val": vals} + + torch_ds = LanceDataset( + ds, + batch_size=4, + to_tensor_fn=to_tensor_fn, + ) + first = next(iter(torch_ds)) + assert first["int"].dtype == torch.int64 + assert first["int"].shape == (4,) + assert first["val"].dtype == torch.uint8 + assert first["val"].shape == (4, 100) diff --git a/python/src/arrow.rs b/python/src/arrow.rs index bf3fb5f68c4..d7e30c01a63 100644 --- a/python/src/arrow.rs +++ b/python/src/arrow.rs @@ -33,7 +33,7 @@ impl BFloat16 { } #[classmethod] - fn from_bytes(_cls: &PyType, bytes: &[u8]) -> PyResult { + fn from_bytes(_cls: &Bound<'_, PyType>, bytes: &[u8]) -> PyResult { if bytes.len() != 2 { PyValueError::new_err(format!( "BFloat16::from_bytes: expected 2 bytes, got {}", diff --git a/python/src/datagen.rs b/python/src/datagen.rs index c23949b2031..b0a3c4e1b44 100644 --- a/python/src/datagen.rs +++ b/python/src/datagen.rs @@ -2,7 +2,11 @@ use arrow::pyarrow::PyArrowType; use arrow_array::RecordBatch; use arrow_schema::Schema; use lance_datagen::{BatchCount, ByteCount}; -use pyo3::{pyfunction, types::PyModule, wrap_pyfunction, PyResult, Python}; +use pyo3::{ + pyfunction, + types::{PyModule, PyModuleMethods}, + wrap_pyfunction, Bound, PyResult, Python, +}; const DEFAULT_BATCH_SIZE_BYTES: u64 = 32 * 1024; const DEFAULT_BATCH_COUNT: u32 = 4; @@ -13,6 +17,7 @@ pub fn is_datagen_supported() -> bool { } #[pyfunction] +#[pyo3(signature=(schema, batch_count=None, bytes_in_batch=None))] pub fn rand_batches( schema: PyArrowType, batch_count: Option, @@ -35,10 +40,10 @@ pub fn rand_batches( .collect::>>>() } -pub fn register_datagen(py: Python, m: &PyModule) -> PyResult<()> { +pub fn register_datagen(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let datagen = PyModule::new(py, "datagen")?; datagen.add_wrapped(wrap_pyfunction!(is_datagen_supported))?; datagen.add_wrapped(wrap_pyfunction!(rand_batches))?; - m.add_submodule(datagen)?; + m.add_submodule(&datagen)?; Ok(()) } diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 40636558bef..270da8f2803 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1,52 +1,50 @@ -// Copyright 2023 Lance Developers. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors use std::collections::HashMap; use std::str; use std::sync::Arc; +use arrow::array::AsArray; +use arrow::datatypes::UInt8Type; use arrow::ffi_stream::ArrowArrayStreamReader; use arrow::pyarrow::*; -use arrow_array::{Float32Array, RecordBatch, RecordBatchReader}; +use arrow_array::Array; +use arrow_array::{make_array, RecordBatch, RecordBatchReader}; use arrow_data::ArrayData; use arrow_schema::{DataType, Schema as ArrowSchema}; use async_trait::async_trait; use blob::LanceBlobFile; use chrono::Duration; - -use arrow_array::Array; use futures::{StreamExt, TryFutureExt}; + use lance::dataset::builder::DatasetBuilder; use lance::dataset::refs::{Ref, TagContents}; -use lance::dataset::scanner::MaterializationStyle; -use lance::dataset::transaction::{ - RewriteGroup as LanceRewriteGroup, RewrittenIndex as LanceRewrittenIndex, Transaction, +use lance::dataset::scanner::{ + DatasetRecordBatchStream, ExecutionStatsCallback, MaterializationStyle, }; +use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt}; use lance::dataset::{ - fragment::FileFragment as LanceFileFragment, progress::WriteFragmentProgress, - scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation, + fragment::FileFragment as LanceFileFragment, + progress::WriteFragmentProgress, + scanner::Scanner as LanceScanner, + transaction::{Operation, Transaction}, Dataset as LanceDataset, MergeInsertBuilder as LanceMergeInsertBuilder, ReadParams, UpdateBuilder, Version, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, WriteMode, WriteParams, }; use lance::dataset::{ - BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination, + BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore, + WriteDestination, }; use lance::dataset::{ColumnAlteration, ProjectionRequest}; +use lance::index::vector::utils::get_vector_type; use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt}; use lance_arrow::as_fixed_size_list_array; -use lance_core::datatypes::Schema; +use lance_index::metrics::NoOpMetricsCollector; +use lance_index::scalar::inverted::query::{ + BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, Operator, PhraseQuery, +}; use lance_index::scalar::InvertedIndexParams; use lance_index::{ optimize::OptimizeOptions, @@ -60,25 +58,28 @@ use lance_index::{ use lance_io::object_store::ObjectStoreParams; use lance_linalg::distance::MetricType; use lance_table::format::Fragment; -use lance_table::format::Index; use lance_table::io::commit::CommitHandler; +use log::error; use object_store::path::Path; -use pyo3::exceptions::{PyNotImplementedError, PyStopIteration, PyTypeError}; -use pyo3::types::{PyBytes, PyInt, PyList, PySet, PyString, PyTuple}; +use pyo3::exceptions::{PyStopIteration, PyTypeError}; +use pyo3::types::{PyBytes, PyInt, PyList, PySet, PyString}; use pyo3::{ exceptions::{PyIOError, PyKeyError, PyValueError}, + pybacked::PyBackedStr, pyclass, types::{IntoPyDict, PyDict}, PyObject, PyResult, }; -use pyo3::{intern, prelude::*}; -use snafu::{location, Location}; -use uuid::Uuid; +use pyo3::{prelude::*, IntoPyObjectExt}; +use snafu::location; use crate::error::PythonErrorExt; -use crate::fragment::{FileFragment, FragmentMetadata}; +use crate::file::object_store_from_uri_or_path; +use crate::fragment::FileFragment; +use crate::scanner::ScanStatistics; use crate::schema::LanceSchema; use crate::session::Session; +use crate::utils::PyLance; use crate::RT; use crate::{LanceReader, Scanner}; @@ -89,32 +90,12 @@ pub mod blob; pub mod cleanup; pub mod commit; pub mod optimize; +pub mod stats; const DEFAULT_NPROBS: usize = 1; const DEFAULT_INDEX_CACHE_SIZE: usize = 256; const DEFAULT_METADATA_CACHE_SIZE: usize = 256; -#[pyclass(name = "_Operation", module = "_lib")] -#[derive(Clone)] -pub struct Operation(LanceOperation); - -fn into_fragments(fragments: Vec) -> Vec { - fragments - .into_iter() - .map(|f| f.inner) - .collect::>() -} - -fn convert_schema(arrow_schema: &ArrowSchema) -> PyResult { - // Note: the field ids here are wrong. - Schema::try_from(arrow_schema).map_err(|e| { - PyValueError::new_err(format!( - "Failed to convert Arrow schema to Lance schema: {}", - e - )) - }) -} - fn convert_reader(reader: &Bound) -> PyResult> { let py = reader.py(); if reader.is_instance_of::() { @@ -139,22 +120,23 @@ pub struct MergeInsertBuilder { #[pymethods] impl MergeInsertBuilder { #[new] - pub fn new(dataset: &PyAny, on: &PyAny) -> PyResult { + pub fn new(dataset: &Bound<'_, PyAny>, on: &Bound<'_, PyAny>) -> PyResult { let dataset: Py = dataset.extract()?; let ds = dataset.borrow(on.py()).ds.clone(); // Either a single string, which we put in a vector or an iterator // of strings, which we collect into a vector - let on = PyAny::downcast::(on) + let on = on + .downcast::() .map(|val| vec![val.to_string()]) .or_else(|_| { - let iterator = on.iter().map_err(|_| { + let iterator = on.try_iter().map_err(|_| { PyTypeError::new_err( "The `on` argument to merge_insert must be a str or iterable of str", ) })?; let mut keys = Vec::new(); for key in iterator { - keys.push(PyAny::downcast::(key?)?.to_string()); + keys.push(key?.downcast::()?.to_string()); } PyResult::Ok(keys) })?; @@ -170,6 +152,7 @@ impl MergeInsertBuilder { Ok(Self { builder, dataset }) } + #[pyo3(signature=(condition=None))] pub fn when_matched_update_all<'a>( mut slf: PyRefMut<'a, Self>, condition: Option<&str>, @@ -190,6 +173,7 @@ impl MergeInsertBuilder { Ok(slf) } + #[pyo3(signature=(expr=None))] pub fn when_not_matched_by_source_delete<'a>( mut slf: PyRefMut<'a, Self>, expr: Option<&str>, @@ -214,185 +198,51 @@ impl MergeInsertBuilder { .try_build() .map_err(|err| PyValueError::new_err(err.to_string()))?; - let new_self = RT + let (new_dataset, stats) = RT .spawn(Some(py), job.execute_reader(new_data))? .map_err(|err| PyIOError::new_err(err.to_string()))?; - let dataset = self.dataset.as_ref(py); - - dataset.borrow_mut().ds = new_self.0; - let merge_stats = new_self.1; - let merge_dict = PyDict::new(py); - merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?; - merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?; - merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?; - - Ok(merge_dict.into()) - } -} - -#[pyclass(name = "_RewriteGroup", module = "_lib")] -#[derive(Clone)] -pub struct RewriteGroup(LanceRewriteGroup); - -#[pymethods] -impl RewriteGroup { - #[new] - pub fn new(old_fragments: Vec, new_fragments: Vec) -> Self { - let old_fragments = into_fragments(old_fragments); - let new_fragments = into_fragments(new_fragments); - Self(LanceRewriteGroup { - old_fragments, - new_fragments, - }) - } -} - -#[pyclass(name = "_RewrittenIndex", module = "_lib")] -#[derive(Clone)] -pub struct RewrittenIndex(LanceRewrittenIndex); - -#[pymethods] -impl RewrittenIndex { - #[new] - pub fn new(old_index: String, new_index: String) -> PyResult { - let old_id: Uuid = old_index - .parse() - .map_err(|e: uuid::Error| PyValueError::new_err(e.to_string()))?; - let new_id: Uuid = new_index - .parse() - .map_err(|e: uuid::Error| PyValueError::new_err(e.to_string()))?; - Ok(Self(LanceRewrittenIndex { old_id, new_id })) - } -} + let dataset = self.dataset.bind(py); -#[pymethods] -impl Operation { - fn __repr__(&self) -> String { - format!("{:?}", self.0) - } + dataset.borrow_mut().ds = new_dataset; - #[staticmethod] - fn overwrite( - schema: PyArrowType, - fragments: Vec, - ) -> PyResult { - let schema = convert_schema(&schema.0)?; - let fragments = into_fragments(fragments); - let op = LanceOperation::Overwrite { - fragments, - schema, - config_upsert_values: None, - }; - Ok(Self(op)) + Ok(Self::build_stats(&stats, py)?.into()) } - #[staticmethod] - fn append(fragments: Vec) -> PyResult { - let fragments = into_fragments(fragments); - let op = LanceOperation::Append { fragments }; - Ok(Self(op)) - } + pub fn execute_uncommitted<'a>( + &mut self, + new_data: &Bound<'a, PyAny>, + ) -> PyResult<(PyLance, Bound<'a, PyDict>)> { + let py = new_data.py(); + let new_data = convert_reader(new_data)?; - #[staticmethod] - fn delete( - updated_fragments: Vec, - deleted_fragment_ids: Vec, - predicate: String, - ) -> PyResult { - let updated_fragments = into_fragments(updated_fragments); - let op = LanceOperation::Delete { - updated_fragments, - deleted_fragment_ids, - predicate, - }; - Ok(Self(op)) - } + let job = self + .builder + .try_build() + .map_err(|err| PyValueError::new_err(err.to_string()))?; - #[staticmethod] - fn merge(fragments: Vec, schema: LanceSchema) -> PyResult { - let schema = schema.0; - let fragments = into_fragments(fragments); - let op = LanceOperation::Merge { fragments, schema }; - Ok(Self(op)) - } + let (transaction, stats) = RT + .spawn(Some(py), job.execute_uncommitted(new_data))? + .map_err(|err| PyIOError::new_err(err.to_string()))?; - #[staticmethod] - fn restore(version: u64) -> PyResult { - let op = LanceOperation::Restore { version }; - Ok(Self(op)) - } + let stats = Self::build_stats(&stats, py)?; - #[staticmethod] - fn rewrite( - groups: Vec, - rewritten_indices: Vec, - ) -> PyResult { - let groups = groups.into_iter().map(|g| g.0).collect(); - let rewritten_indices = rewritten_indices.into_iter().map(|r| r.0).collect(); - let op = LanceOperation::Rewrite { - groups, - rewritten_indices, - }; - Ok(Self(op)) + Ok((PyLance(transaction), stats)) } +} - #[staticmethod] - fn create_index( - uuid: String, - name: String, - fields: Vec, - dataset_version: u64, - fragment_ids: &PySet, - ) -> PyResult { - let fragment_ids: Vec = fragment_ids - .iter() - .map(|item| item.extract::()) - .collect::>>()?; - let new_indices = vec![Index { - uuid: Uuid::parse_str(&uuid).map_err(|e| PyValueError::new_err(e.to_string()))?, - name, - fields, - dataset_version, - fragment_bitmap: Some(fragment_ids.into_iter().collect()), - // TODO: we should use lance::dataset::Dataset::commit_existing_index once - // we have a way to determine index details from an existing index. - index_details: None, - }]; - let op = LanceOperation::CreateIndex { - new_indices, - removed_indices: vec![], - }; - Ok(Self(op)) - } - - /// Convert to a pydict that can be used as kwargs into the Operation dataclasses - fn to_dict<'a>(&self, py: Python<'a>) -> PyResult> { - let dict = PyDict::new_bound(py); - match &self.0 { - LanceOperation::Append { fragments } => { - let fragments = fragments - .iter() - .cloned() - .map(FragmentMetadata::new) - .map(|f| f.into_py(py)) - .collect::>(); - dict.set_item("fragments", fragments).unwrap(); - } - _ => { - return Err(PyNotImplementedError::new_err(format!( - "Operation.to_dict is not implemented for this operation: {:?}", - self.0 - ))); - } - } - +impl MergeInsertBuilder { + fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult> { + let dict = PyDict::new(py); + dict.set_item("num_inserted_rows", stats.num_inserted_rows)?; + dict.set_item("num_updated_rows", stats.num_updated_rows)?; + dict.set_item("num_deleted_rows", stats.num_deleted_rows)?; Ok(dict) } } -pub fn transforms_from_python(transforms: &PyAny) -> PyResult { - if let Ok(transforms) = transforms.extract::<&PyDict>() { +pub fn transforms_from_python(transforms: &Bound<'_, PyAny>) -> PyResult { + if let Ok(transforms) = transforms.downcast::() { let expressions = transforms .iter() .map(|(k, v)| { @@ -410,7 +260,7 @@ pub fn transforms_from_python(transforms: &PyAny) -> PyResult = transforms.getattr("cache")?.extract()?; let result_checkpoint = result_checkpoint.map(|c| PyBatchUDFCheckpointWrapper { inner: c }); - let udf_obj = transforms.to_object(transforms.py()); + let udf_obj = transforms.into_py_any(transforms.py())?; let mapper = move |batch: &RecordBatch| -> lance::Result { Python::with_gil(|py| { let py_batch: PyArrowType = PyArrowType(batch.clone()); @@ -448,6 +298,7 @@ pub struct Dataset { impl Dataset { #[allow(clippy::too_many_arguments)] #[new] + #[pyo3(signature=(uri, version=None, block_size=None, index_cache_size=None, metadata_cache_size=None, commit_handler=None, storage_options=None, manifest=None))] fn new( py: Python, uri: String, @@ -476,11 +327,11 @@ impl Dataset { let mut builder = DatasetBuilder::from_uri(&uri).with_read_params(params); if let Some(ver) = version { - if let Ok(i) = ver.downcast::(py) { + if let Ok(i) = ver.downcast_bound::(py) { let v: u64 = i.extract()?; builder = builder.with_version(v); - } else if let Ok(v) = ver.downcast::(py) { - let t: &str = v.extract()?; + } else if let Ok(v) = ver.downcast_bound::(py) { + let t: &str = &v.to_string_lossy(); builder = builder.with_tag(t); } else { return Err(PyIOError::new_err( @@ -511,6 +362,11 @@ impl Dataset { self.clone() } + #[getter(max_field_id)] + fn max_field_id(self_: PyRef<'_, Self>) -> PyResult { + Ok(self_.ds.manifest().max_field_id()) + } + #[getter(schema)] fn schema(self_: PyRef<'_, Self>) -> PyResult { let arrow_schema = ArrowSchema::from(self_.ds.schema()); @@ -522,6 +378,32 @@ impl Dataset { LanceSchema(self_.ds.schema().clone()) } + fn replace_schema_metadata(&mut self, metadata: HashMap) -> PyResult<()> { + let mut new_self = self.ds.as_ref().clone(); + RT.block_on(None, new_self.replace_schema_metadata(metadata))? + .map_err(|err| PyIOError::new_err(err.to_string()))?; + self.ds = Arc::new(new_self); + Ok(()) + } + + fn replace_field_metadata( + &mut self, + field_name: &str, + metadata: HashMap, + ) -> PyResult<()> { + let mut new_self = self.ds.as_ref().clone(); + let field = new_self + .schema() + .field(field_name) + .ok_or_else(|| PyKeyError::new_err(format!("Field \"{}\" not found", field_name)))?; + let new_field_meta: HashMap> = + HashMap::from_iter(vec![(field.id as u32, metadata)]); + RT.block_on(None, new_self.replace_field_metadata(new_field_meta))? + .map_err(|err| PyIOError::new_err(err.to_string()))?; + self.ds = Arc::new(new_self); + Ok(()) + } + #[getter(data_storage_version)] fn data_storage_version(&self) -> PyResult { Ok(self.ds.manifest().data_storage_format.version.clone()) @@ -563,23 +445,19 @@ impl Dataset { let idx_schema = schema.project_by_ids(idx.fields.as_slice(), true); - let is_vector = idx_schema - .fields - .iter() - .any(|f| matches!(f.data_type(), DataType::FixedSizeList(_, _))); - - let idx_type = if is_vector { - IndexType::Vector - } else { - let ds = self_.ds.clone(); - RT.block_on(Some(self_.py()), async { - let scalar_idx = ds - .open_scalar_index(&idx_schema.fields[0].name, &idx.uuid.to_string()) + let ds = self_.ds.clone(); + let idx_type = RT + .block_on(Some(self_.py()), async { + let idx = ds + .open_generic_index( + &idx_schema.fields[0].name, + &idx.uuid.to_string(), + &NoOpMetricsCollector, + ) .await?; - Ok::<_, lance::Error>(scalar_idx.index_type()) + Ok::<_, lance::Error>(idx.index_type()) })? - .map_err(|e| PyIOError::new_err(e.to_string()))? - }; + .map_err(|e| PyIOError::new_err(e.to_string()))?; let field_names = idx_schema .fields @@ -603,12 +481,13 @@ impl Dataset { dict.set_item("fields", field_names).unwrap(); dict.set_item("version", idx.dataset_version).unwrap(); dict.set_item("fragment_ids", fragment_set).unwrap(); - Ok(dict.to_object(py)) + dict.into_py_any(py) }) .collect::>>() } #[allow(clippy::too_many_arguments)] + #[pyo3(signature=(columns=None, columns_with_transform=None, filter=None, prefilter=None, limit=None, offset=None, nearest=None, batch_size=None, io_buffer_size=None, batch_readahead=None, fragment_readahead=None, scan_in_order=None, fragments=None, with_row_id=None, with_row_address=None, use_stats=None, substrait_filter=None, fast_search=None, full_text_query=None, late_materialization=None, use_scalar_index=None, include_deleted_rows=None, scan_stats_callback=None))] fn scanner( self_: PyRef<'_, Self>, columns: Option>, @@ -629,9 +508,11 @@ impl Dataset { use_stats: Option, substrait_filter: Option>, fast_search: Option, - full_text_query: Option<&PyDict>, + full_text_query: Option<&Bound<'_, PyAny>>, late_materialization: Option, use_scalar_index: Option, + include_deleted_rows: Option, + scan_stats_callback: Option<&Bound<'_, PyAny>>, ) -> PyResult { let mut scanner: LanceScanner = self_.ds.scan(); match (columns, columns_with_transform) { @@ -663,27 +544,65 @@ impl Dataset { .map_err(|err| PyValueError::new_err(err.to_string()))?; } if let Some(full_text_query) = full_text_query { - let query = full_text_query - .get_item("query")? - .ok_or_else(|| PyKeyError::new_err("Need column for full text search"))? - .to_string(); - let columns = if let Some(columns) = full_text_query.get_item("columns")? { - if columns.is_none() { - None + let fts_query = if let Ok(full_text_query) = full_text_query.downcast::() { + let mut query = full_text_query + .get_item("query")? + .ok_or_else(|| PyKeyError::new_err("query must be specified"))? + .to_string(); + let columns = if let Some(columns) = full_text_query.get_item("columns")? { + if columns.is_none() { + None + } else { + Some( + columns + .downcast::()? + .iter() + .map(|c| c.extract::()) + .collect::>>()?, + ) + } } else { - Some( - PyAny::downcast::(columns)? - .iter() - .map(|c| c.extract::()) - .collect::>>()?, - ) + None + }; + + let is_phrase = query.len() >= 2 && query.starts_with('"') && query.ends_with('"'); + let is_multi_match = columns.as_ref().map(|cols| cols.len() > 1).unwrap_or(false); + + if is_phrase { + // Remove the surrounding quotes for phrase queries + query = query[1..query.len() - 1].to_string(); } + + let query: FtsQuery = match (is_phrase, is_multi_match) { + (false, _) => MatchQuery::new(query).into(), + (true, false) => PhraseQuery::new(query).into(), + (true, true) => { + return Err(PyValueError::new_err( + "Phrase queries cannot be used with multiple columns.", + )); + } + }; + let mut query = FullTextSearchQuery::new_query(query); + if let Some(cols) = columns { + query = query.with_columns(&cols).map_err(|e| { + PyValueError::new_err(format!( + "Failed to set full text search columns: {}", + e + )) + })?; + } + query + } else if let Ok(query) = full_text_query.downcast::() { + let query = query.borrow(); + FullTextSearchQuery::new_query(query.inner.clone()) } else { - None + return Err(PyValueError::new_err( + "query must be a string or a Query object", + )); }; - let full_text_query = FullTextSearchQuery::new(query).columns(columns); + scanner - .full_text_search(full_text_query) + .full_text_search(fts_query) .map_err(|err| PyValueError::new_err(err.to_string()))?; } if let Some(f) = substrait_filter { @@ -729,6 +648,10 @@ impl Dataset { scanner.fast_search(); } + if let Some(true) = include_deleted_rows { + scanner.include_deleted_rows(); + } + if let Some(fragments) = fragments { let fragments = fragments .into_iter() @@ -740,6 +663,11 @@ impl Dataset { scanner.with_fragments(fragments); } + if let Some(scan_stats_callback) = scan_stats_callback { + let callback = Self::make_scan_stats_callback(scan_stats_callback.clone())?; + scanner.scan_stats_callback(callback); + } + if let Some(late_materialization) = late_materialization { if let Ok(style_as_bool) = late_materialization.extract::(self_.py()) { if style_as_bool { @@ -773,7 +701,7 @@ impl Dataset { .get_item("q")? .ok_or_else(|| PyKeyError::new_err("Need q for nearest"))?; let data = ArrayData::from_pyarrow_bound(&qval)?; - let q = Float32Array::from(data); + let q = make_array(data); let k: usize = if let Some(k) = nearest.get_item("k")? { if k.is_none() { @@ -839,8 +767,19 @@ impl Dataset { None }; + let (_, element_type) = get_vector_type(self_.ds.schema(), &column) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + let scanner = match element_type { + DataType::UInt8 => { + let q = arrow::compute::cast(&q, &DataType::UInt8).map_err(|e| { + PyValueError::new_err(format!("Failed to cast q to binary vector: {}", e)) + })?; + let q = q.as_primitive::(); + scanner.nearest(&column, q, k) + } + _ => scanner.nearest(&column, &q, k), + }; scanner - .nearest(column.as_str(), &q, k) .map(|s| { let mut s = s.nprobs(nprobes); if let Some(factor) = refine_factor { @@ -862,12 +801,14 @@ impl Dataset { Ok(Scanner::new(scan)) } + #[pyo3(signature=(filter=None))] fn count_rows(&self, filter: Option) -> PyResult { RT.runtime .block_on(self.ds.count_rows(filter)) .map_err(|err| PyIOError::new_err(err.to_string())) } + #[pyo3(signature=(row_indices, columns = None, columns_with_transform = None))] fn take( self_: PyRef<'_, Self>, row_indices: Vec, @@ -894,6 +835,7 @@ impl Dataset { batch.to_pyarrow(self_.py()) } + #[pyo3(signature=(row_indices, columns = None, columns_with_transform = None))] fn take_rows( self_: PyRef<'_, Self>, row_indices: Vec, @@ -981,7 +923,7 @@ impl Dataset { Ok(PyArrowType(Box::new(LanceReader::from_stream(stream)))) } - fn alter_columns(&mut self, alterations: &PyList) -> PyResult<()> { + fn alter_columns(&mut self, alterations: &Bound<'_, PyList>) -> PyResult<()> { let alterations = alterations .iter() .map(|obj| { @@ -1066,7 +1008,12 @@ impl Dataset { Ok(()) } - fn update(&mut self, updates: &PyDict, predicate: Option<&str>) -> PyResult { + #[pyo3(signature=(updates, predicate=None))] + fn update( + &mut self, + updates: &Bound<'_, PyDict>, + predicate: Option<&str>, + ) -> PyResult { let mut builder = UpdateBuilder::new(self.ds.clone()); if let Some(predicate) = predicate { builder = builder @@ -1075,11 +1022,11 @@ impl Dataset { } for (key, value) in updates { - let column: &str = key.extract()?; - let expr: &str = value.extract()?; + let column: PyBackedStr = key.downcast::()?.clone().try_into()?; + let expr: PyBackedStr = value.downcast::()?.clone().try_into()?; builder = builder - .set(column, expr) + .set(column, &expr) .map_err(|err| PyValueError::new_err(err.to_string()))?; } @@ -1119,12 +1066,10 @@ impl Dataset { ) .unwrap(); let tup: Vec<(&String, &String)> = v.metadata.iter().collect(); - dict.set_item("metadata", tup.into_py_dict(py)).unwrap(); - dict.to_object(py) + dict.set_item("metadata", tup.into_py_dict(py)?).unwrap(); + dict.into_py_any(py) }) - .collect::>() - .into_iter() - .collect(); + .collect::>>()?; Ok(pyvers) }) } @@ -1140,11 +1085,11 @@ impl Dataset { } fn checkout_version(&self, py: Python, version: PyObject) -> PyResult { - if let Ok(i) = version.downcast::(py) { + if let Ok(i) = version.downcast_bound::(py) { let ref_: u64 = i.extract()?; self._checkout_version(ref_) - } else if let Ok(v) = version.downcast::(py) { - let ref_: &str = v.extract()?; + } else if let Ok(v) = version.downcast_bound::(py) { + let ref_: &str = &v.to_string_lossy(); self._checkout_version(ref_) } else { Err(PyIOError::new_err( @@ -1163,6 +1108,7 @@ impl Dataset { } /// Cleanup old versions from the dataset + #[pyo3(signature = (older_than_micros, delete_unverified = None, error_if_tagged_old_versions = None))] fn cleanup_old_versions( &self, older_than_micros: i64, @@ -1196,10 +1142,9 @@ impl Dataset { let dict = PyDict::new(py); dict.set_item("version", v.version).unwrap(); dict.set_item("manifest_size", v.manifest_size).unwrap(); - dict.to_object(py); - pytags.set_item(k, dict).unwrap(); + pytags.set_item(k, dict.into_py_any(py)?).unwrap(); } - Ok(pytags.to_object(py)) + pytags.into_py_any(py) }) } @@ -1237,7 +1182,7 @@ impl Dataset { } #[pyo3(signature = (**kwargs))] - fn optimize_indices(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { + fn optimize_indices(&mut self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> { let mut new_self = self.ds.as_ref().clone(); let mut options: OptimizeOptions = Default::default(); if let Some(kwargs) = kwargs { @@ -1251,6 +1196,9 @@ impl Dataset { .map_err(|err| PyValueError::new_err(err.to_string()))?, ); } + if let Some(retrain) = kwargs.get_item("retrain")? { + options.retrain = retrain.extract()?; + } } RT.block_on( None, @@ -1263,22 +1211,25 @@ impl Dataset { Ok(()) } + #[pyo3(signature = (columns, index_type, name = None, replace = None, storage_options = None, kwargs = None))] fn create_index( &mut self, - columns: Vec<&str>, + columns: Vec, index_type: &str, name: Option, replace: Option, storage_options: Option>, kwargs: Option<&Bound>, ) -> PyResult<()> { + let columns: Vec<&str> = columns.iter().map(|s| &**s).collect(); let index_type = index_type.to_uppercase(); let idx_type = match index_type.as_str() { "BTREE" => IndexType::Scalar, "BITMAP" => IndexType::Bitmap, + "NGRAM" => IndexType::NGram, "LABEL_LIST" => IndexType::LabelList, "INVERTED" | "FTS" => IndexType::Inverted, - "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, + "IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, _ => { return Err(PyValueError::new_err(format!( "Index type '{index_type}' is not supported." @@ -1293,6 +1244,9 @@ impl Dataset { // Temporary workaround until we add support for auto-detection of scalar index type force_index_type: Some(ScalarIndexType::Bitmap), }), + "NGRAM" => Box::new(ScalarIndexParams { + force_index_type: Some(ScalarIndexType::NGram), + }), "LABEL_LIST" => Box::new(ScalarIndexParams { force_index_type: Some(ScalarIndexType::LabelList), }), @@ -1308,9 +1262,10 @@ impl Dataset { .base_tokenizer(base_tokenizer.extract()?); } if let Some(language) = kwargs.get_item("language")? { - let language = language.extract()?; + let language: PyBackedStr = + language.downcast::()?.clone().try_into()?; params.tokenizer_config = - params.tokenizer_config.language(language).map_err(|e| { + params.tokenizer_config.language(&language).map_err(|e| { PyValueError::new_err(format!( "can't set tokenizer language to {}: {:?}", language, e @@ -1366,6 +1321,20 @@ impl Dataset { Ok(()) } + fn drop_index(&mut self, name: &str) -> PyResult<()> { + let mut new_self = self.ds.as_ref().clone(); + RT.block_on(None, new_self.drop_index(name))? + .infer_error()?; + self.ds = Arc::new(new_self); + + Ok(()) + } + + fn prewarm_index(&self, name: &str) -> PyResult<()> { + RT.block_on(None, self.ds.prewarm_index(name))? + .infer_error() + } + fn count_fragments(&self) -> usize { self.ds.count_fragments() } @@ -1375,6 +1344,12 @@ impl Dataset { .map_err(|err| PyIOError::new_err(err.to_string())) } + fn data_stats(&self) -> PyResult> { + RT.block_on(None, self.ds.calculate_data_stats())? + .infer_error() + .map(PyLance) + } + fn get_fragments(self_: PyRef<'_, Self>) -> PyResult> { let core_fragments = self_.ds.get_fragments(); @@ -1407,13 +1382,58 @@ impl Dataset { Session::new(self.ds.session()) } + #[staticmethod] + #[pyo3(signature = (dest, storage_options = None))] + fn drop(dest: String, storage_options: Option>) -> PyResult<()> { + RT.spawn(None, async move { + let (object_store, path) = + object_store_from_uri_or_path(&dest, storage_options).await?; + object_store + .remove_dir_all(path) + .await + .map_err(|e| PyIOError::new_err(e.to_string())) + })? + } + #[allow(clippy::too_many_arguments)] #[staticmethod] + #[pyo3(signature = (dest, operation, blobs_op=None, read_version = None, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] fn commit( - dest: &Bound, - operation: Operation, + dest: PyWriteDest, + operation: PyLance, + blobs_op: Option>, read_version: Option, - commit_lock: Option<&PyAny>, + commit_lock: Option<&Bound<'_, PyAny>>, + storage_options: Option>, + enable_v2_manifest_paths: Option, + detached: Option, + max_retries: Option, + ) -> PyResult { + let transaction = Transaction::new( + read_version.unwrap_or_default(), + operation.0, + blobs_op.map(|op| op.0), + None, + ); + + Self::commit_transaction( + dest, + PyLance(transaction), + commit_lock, + storage_options, + enable_v2_manifest_paths, + detached, + max_retries, + ) + } + + #[allow(clippy::too_many_arguments)] + #[staticmethod] + #[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] + fn commit_transaction( + dest: PyWriteDest, + transaction: PyLance, + commit_lock: Option<&Bound<'_, PyAny>>, storage_options: Option>, enable_v2_manifest_paths: Option, detached: Option, @@ -1427,22 +1447,16 @@ impl Dataset { ..Default::default() }); - let commit_handler = commit_lock.map(|commit_lock| { - Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py()))) - as Arc - }); - - let dest = if dest.is_instance_of::() { - let dataset: Self = dest.extract()?; - WriteDestination::Dataset(dataset.ds.clone()) - } else { - WriteDestination::Uri(dest.extract()?) - }; - - let transaction = - Transaction::new(read_version.unwrap_or_default(), operation.0, None, None); + let commit_handler = commit_lock + .as_ref() + .map(|commit_lock| { + commit_lock + .into_py_any(commit_lock.py()) + .map(|cl| Arc::new(PyCommitLock::new(cl)) as Arc) + }) + .transpose()?; - let mut builder = CommitBuilder::new(dest) + let mut builder = CommitBuilder::new(dest.as_dest()) .enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false)) .with_detached(detached.unwrap_or(false)) .with_max_retries(max_retries.unwrap_or(20)); @@ -1456,7 +1470,10 @@ impl Dataset { } let ds = RT - .block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))? + .block_on( + commit_lock.map(|cl| cl.py()), + builder.execute(transaction.0), + )? .map_err(|err| PyIOError::new_err(err.to_string()))?; let uri = ds.uri().to_string(); @@ -1467,15 +1484,16 @@ impl Dataset { } #[staticmethod] - fn commit_batch<'py>( - dest: &Bound<'py, PyAny>, - transactions: Vec>, - commit_lock: Option<&'py PyAny>, + #[pyo3(signature = (dest, transactions, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))] + fn commit_batch( + dest: PyWriteDest, + transactions: Vec>, + commit_lock: Option<&Bound<'_, PyAny>>, storage_options: Option>, enable_v2_manifest_paths: Option, detached: Option, max_retries: Option, - ) -> PyResult> { + ) -> PyResult<(Self, PyLance)> { let object_store_params = storage_options .as_ref() @@ -1484,20 +1502,15 @@ impl Dataset { ..Default::default() }); - let commit_handler = commit_lock.map(|commit_lock| { - Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py()))) - as Arc - }); - - let py = dest.py(); - let dest = if dest.is_instance_of::() { - let dataset: Dataset = dest.extract()?; - WriteDestination::Dataset(dataset.ds.clone()) - } else { - WriteDestination::Uri(dest.extract()?) - }; + let commit_handler = commit_lock + .map(|commit_lock| { + commit_lock + .into_py_any(commit_lock.py()) + .map(|cl| Arc::new(PyCommitLock::new(cl)) as Arc) + }) + .transpose()?; - let mut builder = CommitBuilder::new(dest) + let mut builder = CommitBuilder::new(dest.as_dest()) .enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false)) .with_detached(detached.unwrap_or(false)) .with_max_retries(max_retries.unwrap_or(20)); @@ -1512,20 +1525,19 @@ impl Dataset { let transactions = transactions .into_iter() - .map(|transaction| extract_transaction(&transaction)) - .collect::>>()?; + .map(|transaction| transaction.0) + .collect(); let res = RT - .block_on(Some(py), builder.execute_batch(transactions))? + .block_on(None, builder.execute_batch(transactions))? .map_err(|err| PyIOError::new_err(err.to_string()))?; let uri = res.dataset.uri().to_string(); let ds = Self { ds: Arc::new(res.dataset), uri, }; - let merged = export_transaction(&res.merged, py)?.to_object(py); - let ds = ds.into_py(py); - Ok(PyTuple::new_bound(py, [ds, merged])) + + Ok((ds, PyLance(res.merged))) } fn validate(&self) -> PyResult<()> { @@ -1541,8 +1553,9 @@ impl Dataset { Ok(()) } - fn drop_columns(&mut self, columns: Vec<&str>) -> PyResult<()> { + fn drop_columns(&mut self, columns: Vec) -> PyResult<()> { let mut new_self = self.ds.as_ref().clone(); + let columns: Vec<_> = columns.iter().map(|s| s.as_str()).collect(); RT.block_on(None, new_self.drop_columns(&columns))? .map_err(|err| match err { lance::Error::InvalidInput { source, .. } => { @@ -1554,9 +1567,10 @@ impl Dataset { Ok(()) } + #[pyo3(signature = (reader, batch_size = None))] fn add_columns_from_reader( &mut self, - reader: &Bound, + reader: &Bound<'_, PyAny>, batch_size: Option, ) -> PyResult<()> { let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; @@ -1575,9 +1589,10 @@ impl Dataset { Ok(()) } + #[pyo3(signature = (transforms, read_columns = None, batch_size = None))] fn add_columns( &mut self, - transforms: &PyAny, + transforms: &Bound<'_, PyAny>, read_columns: Option>, batch_size: Option, ) -> PyResult<()> { @@ -1596,6 +1611,59 @@ impl Dataset { Ok(()) } + + /// Add NULL columns with only ArrowSchema. + #[pyo3(signature = (schema))] + fn add_columns_with_schema(&mut self, schema: PyArrowType) -> PyResult<()> { + let arrow_schema: &ArrowSchema = &schema.0; + let transform = NewColumnTransform::AllNulls(Arc::new(arrow_schema.clone())); + + let mut new_self = self.ds.as_ref().clone(); + let new_self = RT + .spawn(None, async move { + new_self.add_columns(transform, None, None).await?; + Ok(new_self) + })? + .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; + self.ds = Arc::new(new_self); + Ok(()) + } + + #[pyo3(signature = (index_name,partition_id, with_vector=false))] + fn read_index_partition( + &self, + index_name: String, + partition_id: usize, + with_vector: bool, + ) -> PyResult>> { + let stream = RT + .block_on( + None, + self.ds + .read_index_partition(&index_name, partition_id, with_vector), + )? + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let reader = Box::new(LanceReader::from_stream(DatasetRecordBatchStream::new( + stream, + ))); + Ok(PyArrowType(reader)) + } +} + +#[derive(FromPyObject)] +pub enum PyWriteDest { + Dataset(Dataset), + Uri(PyBackedStr), +} + +impl PyWriteDest { + pub fn as_dest(&self) -> WriteDestination<'_> { + match self { + Self::Dataset(ds) => WriteDestination::Dataset(ds.ds.clone()), + Self::Uri(uri) => WriteDestination::Uri(uri), + } + } } impl Dataset { @@ -1620,34 +1688,55 @@ impl Dataset { fn list_tags(&self) -> ::lance::error::Result> { RT.runtime.block_on(self.ds.tags.list()) } + + fn make_scan_stats_callback(callback: Bound<'_, PyAny>) -> PyResult { + if !callback.is_callable() { + return Err(PyValueError::new_err("Callback must be callable")); + } + + let callback = callback.unbind(); + + Ok(Arc::new(move |stats| { + Python::with_gil(|py| { + let stats = ScanStatistics::from_lance(stats); + match callback.call1(py, (stats,)) { + Ok(_) => (), + Err(e) => { + // Don't fail scan if callback fails + error!("Error in scan stats callback: {}", e); + } + } + }); + })) + } } #[pyfunction(name = "_write_dataset")] pub fn write_dataset( - reader: &Bound, - dest: &Bound, - options: &PyDict, + reader: &Bound<'_, PyAny>, + dest: PyWriteDest, + options: &Bound<'_, PyDict>, ) -> PyResult { let params = get_write_params(options)?; let py = options.py(); - let dest = if dest.is_instance_of::() { - let dataset: Dataset = dest.extract()?; - WriteDestination::Dataset(dataset.ds.clone()) - } else { - WriteDestination::Uri(dest.extract()?) - }; let ds = if reader.is_instance_of::() { let scanner: Scanner = reader.extract()?; let batches = RT .block_on(Some(py), scanner.to_reader())? .map_err(|err| PyValueError::new_err(err.to_string()))?; - RT.block_on(Some(py), LanceDataset::write(batches, dest, params))? - .map_err(|err| PyIOError::new_err(err.to_string()))? + RT.block_on( + Some(py), + LanceDataset::write(batches, dest.as_dest(), params), + )? + .map_err(|err| PyIOError::new_err(err.to_string()))? } else { let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; - RT.block_on(Some(py), LanceDataset::write(batches, dest, params))? - .map_err(|err| PyIOError::new_err(err.to_string()))? + RT.block_on( + Some(py), + LanceDataset::write(batches, dest.as_dest(), params), + )? + .map_err(|err| PyIOError::new_err(err.to_string()))? }; Ok(Dataset { uri: ds.uri().to_string(), @@ -1664,16 +1753,16 @@ fn parse_write_mode(mode: &str) -> PyResult { } } -pub fn get_commit_handler(options: &PyDict) -> Option> { - if options.is_none() { +pub fn get_commit_handler(options: &Bound<'_, PyDict>) -> PyResult>> { + Ok(if options.is_none() { None } else if let Ok(Some(commit_handler)) = options.get_item("commit_handler") { Some(Arc::new(PyCommitLock::new( - commit_handler.to_object(options.py()), + commit_handler.into_pyobject(options.py())?.into(), ))) } else { None - } + }) } // Gets a value from the dictionary and attempts to extract it to @@ -1681,7 +1770,10 @@ pub fn get_commit_handler(options: &PyDict) -> Option> { // it were never present in the dictionary. If the value is not // None it will try and parse it and parsing failures will be // returned (e.g. a parsing failure is not considered `None`) -fn get_dict_opt<'a, D: FromPyObject<'a>>(dict: &'a PyDict, key: &str) -> PyResult> { +fn get_dict_opt<'a, 'py, D: FromPyObject<'a>>( + dict: &'a Bound<'py, PyDict>, + key: &str, +) -> PyResult> { let value = dict.get_item(key)?; value .and_then(|v| { @@ -1694,7 +1786,7 @@ fn get_dict_opt<'a, D: FromPyObject<'a>>(dict: &'a PyDict, key: &str) -> PyResul .transpose() } -pub fn get_write_params(options: &PyDict) -> PyResult> { +pub fn get_write_params(options: &Bound<'_, PyDict>) -> PyResult> { let params = if options.is_none() { None } else { @@ -1716,7 +1808,7 @@ pub fn get_write_params(options: &PyDict) -> PyResult> { p.data_storage_version = Some(data_storage_version.parse().infer_error()?); } if let Some(progress) = get_dict_opt::(options, "progress")? { - p.progress = Arc::new(PyWriteProgress::new(progress.to_object(options.py()))); + p.progress = Arc::new(PyWriteProgress::new(progress.into_py_any(options.py())?)); } if let Some(storage_options) = @@ -1728,13 +1820,18 @@ pub fn get_write_params(options: &PyDict) -> PyResult> { }); } + if let Some(enable_move_stable_row_ids) = + get_dict_opt::(options, "enable_move_stable_row_ids")? + { + p.enable_move_stable_row_ids = enable_move_stable_row_ids; + } if let Some(enable_v2_manifest_paths) = get_dict_opt::(options, "enable_v2_manifest_paths")? { p.enable_v2_manifest_paths = enable_v2_manifest_paths; } - p.commit_handler = get_commit_handler(options); + p.commit_handler = get_commit_handler(options)?; Some(p) }; @@ -1871,6 +1968,11 @@ fn prepare_vector_index_params( } match index_type { + "IVF_FLAT" => Ok(Box::new(VectorIndexParams::ivf_flat( + ivf_params.num_partitions, + m_type, + ))), + "IVF_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_pq_params( m_type, ivf_params, pq_params, ))), @@ -2052,58 +2154,76 @@ impl UDFCheckpointStore for PyBatchUDFCheckpointWrapper { } } -/// py_transaction is a dataclass with attributes -/// read_version: int -/// uuid: str -/// operation: LanceOperation.BaseOperation -/// blobs_op: Optional[LanceOperation.BaseOperation] = None -fn extract_transaction(py_transaction: &Bound) -> PyResult { - let py = py_transaction.py(); - let read_version = py_transaction.getattr("read_version")?.extract()?; - let uuid = py_transaction.getattr("uuid")?.extract()?; - let operation: Operation = py_transaction - .getattr("operation")? - .call_method0(intern!(py, "_to_inner"))? - .extract()?; - let operation = operation.0; - let blobs_op: Option = { - let blobs_op: Option> = py_transaction.getattr("blobs_op")?.extract()?; - if let Some(blobs_op) = blobs_op { - Some(blobs_op.call_method0(intern!(py, "_to_inner"))?.extract()?) - } else { - None - } - }; - let blobs_op = blobs_op.map(|op| op.0); - Ok(Transaction { - read_version, - uuid, - operation, - blobs_op, - tag: None, - }) +#[pyclass(name = "PyFullTextQuery")] +#[derive(Debug, Clone)] +pub struct PyFullTextQuery { + pub(crate) inner: FtsQuery, } -// Exports to a pydict of kwargs to instantiation the python Transaction dataclass. -fn export_transaction<'a>( - transaction: &Transaction, - py: Python<'a>, -) -> PyResult> { - let dict = PyDict::new_bound(py); - dict.set_item("read_version", transaction.read_version)?; - dict.set_item("uuid", transaction.uuid.clone())?; - dict.set_item( - "operation", - Operation(transaction.operation.clone()).to_dict(py)?, - )?; - dict.set_item( - "blobs_op", - transaction - .blobs_op - .clone() - .map(Operation) - .map(|op| op.to_dict(py)) - .transpose()?, - )?; - Ok(dict) +#[pymethods] +impl PyFullTextQuery { + #[staticmethod] + #[pyo3(signature = (query, column, boost=1.0, fuzziness=Some(0), max_expansions=50, operator="OR"))] + fn match_query( + query: String, + column: String, + boost: f32, + fuzziness: Option, + max_expansions: usize, + operator: &str, + ) -> PyResult { + Ok(Self { + inner: MatchQuery::new(query) + .with_column(Some(column)) + .with_boost(boost) + .with_fuzziness(fuzziness) + .with_max_expansions(max_expansions) + .with_operator( + Operator::try_from(operator) + .map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?, + ) + .into(), + }) + } + + #[staticmethod] + #[pyo3(signature = (query, column))] + fn phrase_query(query: String, column: String) -> PyResult { + Ok(Self { + inner: PhraseQuery::new(query).with_column(Some(column)).into(), + }) + } + + #[staticmethod] + #[pyo3(signature = (positive, negative,negative_boost=None))] + fn boost_query(positive: Self, negative: Self, negative_boost: Option) -> PyResult { + Ok(Self { + inner: BoostQuery::new(positive.inner, negative.inner, negative_boost).into(), + }) + } + + #[staticmethod] + #[pyo3(signature = (query, columns, boosts=None, operator="OR"))] + fn multi_match_query( + query: String, + columns: Vec, + boosts: Option>, + operator: &str, + ) -> PyResult { + let q = MultiMatchQuery::try_new(query, columns) + .map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?; + let q = if let Some(boosts) = boosts { + q.try_with_boosts(boosts) + .map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))? + } else { + q + }; + + let op = Operator::try_from(operator) + .map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?; + + Ok(Self { + inner: q.with_operator(op).into(), + }) + } } diff --git a/python/src/dataset/blob.rs b/python/src/dataset/blob.rs index 13d47a34e21..9205f337a48 100644 --- a/python/src/dataset/blob.rs +++ b/python/src/dataset/blob.rs @@ -59,7 +59,7 @@ impl LanceBlobFile { pub fn readall<'a>(&'a self, py: Python<'a>) -> PyResult> { let inner = self.inner.clone(); let data = RT.block_on(Some(py), inner.read())?.infer_error()?; - Ok(PyBytes::new_bound(py, &data)) + Ok(PyBytes::new(py, &data)) } pub fn read_into(&self, dst: Bound<'_, PyByteArray>) -> PyResult { diff --git a/python/src/dataset/commit.rs b/python/src/dataset/commit.rs index d34cef41120..635d2ace48d 100644 --- a/python/src/dataset/commit.rs +++ b/python/src/dataset/commit.rs @@ -15,7 +15,7 @@ use std::fmt::Debug; use lance_table::io::commit::{CommitError, CommitLease, CommitLock}; -use snafu::{location, Location}; +use snafu::location; use lance_core::Error; @@ -27,14 +27,14 @@ lazy_static! { py.import("lance") .and_then(|lance| lance.getattr("commit")) .and_then(|commit| commit.getattr("CommitConflictError")) - .map(|error| error.to_object(py)) + .map(|err| err.unbind()) }) }; } fn handle_error(py_err: PyErr, py: Python) -> CommitError { let conflict_err_type = match &*PY_CONFLICT_ERROR { - Ok(err) => err.as_ref(py).get_type(), + Ok(err) => err.bind(py).get_type(), Err(import_error) => { return CommitError::OtherError(Error::Internal { message: format!("Error importing from pylance {}", import_error), @@ -43,7 +43,7 @@ fn handle_error(py_err: PyErr, py: Python) -> CommitError { } }; - if py_err.is_instance(py, conflict_err_type) { + if py_err.is_instance(py, &conflict_err_type) { CommitError::CommitConflict } else { CommitError::OtherError(Error::Internal { diff --git a/python/src/dataset/optimize.rs b/python/src/dataset/optimize.rs index eeac1ea8a86..1fc536194f1 100644 --- a/python/src/dataset/optimize.rs +++ b/python/src/dataset/optimize.rs @@ -23,7 +23,7 @@ use pyo3::{exceptions::PyNotImplementedError, pyclass::CompareOp, types::PyTuple use super::*; -fn parse_compaction_options(options: &PyDict) -> PyResult { +fn parse_compaction_options(options: &Bound<'_, PyDict>) -> PyResult { let mut opts = CompactionOptions::default(); for (key, value) in options.into_iter() { @@ -67,15 +67,13 @@ fn unwrap_dataset(dataset: PyObject) -> PyResult> { Python::with_gil(|py| dataset.getattr(py, "_ds")?.extract::>(py)) } -fn wrap_fragment(py: Python<'_>, fragment: &Fragment) -> PyResult { +fn wrap_fragment<'py>(py: Python<'py>, fragment: &Fragment) -> PyResult> { let fragment_metadata = PyModule::import(py, "lance.fragment")?.getattr("FragmentMetadata")?; let fragment_json = serde_json::to_string(&fragment).map_err(|x| { PyValueError::new_err(format!("failed to serialize fragment metadata: {}", x)) })?; - Ok(fragment_metadata - .call_method1("from_json", (fragment_json,))? - .to_object(py)) + fragment_metadata.call_method1("from_json", (fragment_json,)) } #[pyclass(name = "CompactionMetrics", module = "lance.optimize")] @@ -190,7 +188,7 @@ impl PyCompactionPlan { pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { let state = self.json()?; - let state = PyTuple::new(py, vec![state]).extract()?; + let state = PyTuple::new(py, vec![state])?.extract()?; let from_json = PyModule::import(py, "lance.optimize")? .getattr("CompactionPlan")? .getattr("from_json")? @@ -219,7 +217,7 @@ impl PyCompactionTask { let fragment_reprs: String = self .fragments(py)? .iter() - .map(|f| f.call_method0(py, "__repr__")?.extract(py)) + .map(|f| f.call_method0("__repr__")?.extract()) .collect::>>()? .join(", "); Ok(format!( @@ -236,7 +234,7 @@ impl PyCompactionTask { /// List[lance.fragment.FragmentMetadata] : The fragments that will be compacted. #[getter] - pub fn fragments(&self, py: Python<'_>) -> PyResult> { + pub fn fragments<'py>(&self, py: Python<'py>) -> PyResult>> { self.0 .task .fragments @@ -302,7 +300,7 @@ impl PyCompactionTask { pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { let state = self.json()?; - let state = PyTuple::new(py, vec![state]).extract()?; + let state = PyTuple::new(py, vec![state])?.extract()?; let from_json = PyModule::import(py, "lance.optimize")? .getattr("CompactionTask")? .getattr("from_json")? @@ -337,13 +335,13 @@ impl PyRewriteResult { let orig_fragment_reprs: String = self .original_fragments(py)? .iter() - .map(|f| f.call_method0(py, "__repr__")?.extract(py)) + .map(|f| f.call_method0("__repr__")?.extract()) .collect::>>()? .join(", "); let new_fragment_reprs: String = self .original_fragments(py)? .iter() - .map(|f| f.call_method0(py, "__repr__")?.extract(py)) + .map(|f| f.call_method0("__repr__")?.extract()) .collect::>>()? .join(", "); @@ -361,7 +359,7 @@ impl PyRewriteResult { /// List[lance.fragment.FragmentMetadata] : The metadata for fragments that are being replaced. #[getter] - pub fn original_fragments(&self, py: Python<'_>) -> PyResult> { + pub fn original_fragments<'py>(&self, py: Python<'py>) -> PyResult>> { self.0 .original_fragments .iter() @@ -371,7 +369,7 @@ impl PyRewriteResult { /// List[lance.fragment.FragmentMetadata] : The metadata for fragments that are being added. #[getter] - pub fn new_fragments(&self, py: Python<'_>) -> PyResult> { + pub fn new_fragments<'py>(&self, py: Python<'py>) -> PyResult>> { self.0 .new_fragments .iter() @@ -417,7 +415,7 @@ impl PyRewriteResult { pub fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { let state = self.json()?; - let state = PyTuple::new(py, vec![state]).extract()?; + let state = PyTuple::new(py, vec![state])?.extract()?; let from_json = PyModule::import(py, "lance.optimize")? .getattr("RewriteResult")? .getattr("from_json")? @@ -472,7 +470,7 @@ impl PyCompaction { // Make sure we parse the options within a scoped GIL context, so we // aren't holding the GIL while blocking the thread on the operation. let opts = Python::with_gil(|py| { - let options = options.downcast::(py)?; + let options = options.downcast_bound::(py)?; parse_compaction_options(options) })?; let mut new_ds = dataset.ds.as_ref().clone(); @@ -509,7 +507,7 @@ impl PyCompaction { // Make sure we parse the options within a scoped GIL context, so we // aren't holding the GIL while blocking the thread on the operation. let opts = Python::with_gil(|py| { - let options = options.downcast::(py)?; + let options = options.downcast_bound::(py)?; parse_compaction_options(options) })?; let plan = RT diff --git a/python/src/dataset/stats.rs b/python/src/dataset/stats.rs new file mode 100644 index 00000000000..fc294727d60 --- /dev/null +++ b/python/src/dataset/stats.rs @@ -0,0 +1,55 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use lance::dataset::statistics::{DataStatistics, FieldStatistics}; +use pyo3::{intern, types::PyAnyMethods, Bound, IntoPyObject, PyAny, PyErr, Python}; + +use crate::utils::{export_vec, PyLance}; + +impl<'py> IntoPyObject<'py> for PyLance<&FieldStatistics> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance")) + .and_then(|m| m.getattr("FieldStatistics")) + .expect("FieldStatistics class not found"); + + let id = self.0.id; + let bytes_on_disk = self.0.bytes_on_disk; + + // unwrap due to infallible + Ok(cls.call1((id, bytes_on_disk)).unwrap()) + } +} + +impl<'py> IntoPyObject<'py> for PyLance { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance")) + .and_then(|m| m.getattr("DataStatistics")) + .expect("DataStatistics class not found"); + + let fields = export_vec(py, &self.0.fields)?; + + // unwrap due to infallible + Ok(cls.call1((fields,)).unwrap()) + } +} diff --git a/python/src/debug.rs b/python/src/debug.rs index 8856c1fb286..8886617e2ac 100644 --- a/python/src/debug.rs +++ b/python/src/debug.rs @@ -4,29 +4,29 @@ use std::sync::Arc; use lance::{datatypes::Schema, Error}; -use lance_table::format::{DeletionFile, Fragment as LanceFragmentMetadata}; +use lance_table::format::{DeletionFile, Fragment}; use pyo3::{exceptions::PyIOError, prelude::*}; -use crate::{Dataset, FragmentMetadata, RT}; +use crate::{utils::PyLance, Dataset, RT}; /// Format the Lance schema of a dataset as a string. /// /// This can be used to view the field ids and types in the schema. #[pyfunction] -pub fn format_schema(dataset: &PyAny) -> PyResult { +pub fn format_schema(dataset: &Bound<'_, PyAny>) -> PyResult { let py = dataset.py(); let dataset = dataset.getattr("_ds")?.extract::>()?; - let dataset_ref = &dataset.as_ref(py).borrow().ds; + let dataset_ref = &dataset.bind(py).borrow().ds; let schema = dataset_ref.schema(); Ok(format!("{:#?}", schema)) } /// Print the full Lance manifest of the dataset. #[pyfunction] -pub fn format_manifest(dataset: &PyAny) -> PyResult { +pub fn format_manifest(dataset: &Bound<'_, PyAny>) -> PyResult { let py = dataset.py(); let dataset = dataset.getattr("_ds")?.extract::>()?; - let dataset_ref = &dataset.as_ref(py).borrow().ds; + let dataset_ref = &dataset.bind(py).borrow().ds; let manifest = dataset_ref.manifest(); Ok(format!("{:#?}", manifest)) } @@ -53,7 +53,7 @@ struct PrettyPrintableDataFile { } impl PrettyPrintableFragment { - fn new(fragment: &LanceFragmentMetadata, schema: &Schema) -> Self { + fn new(fragment: &Fragment, schema: &Schema) -> Self { let files = fragment .files .iter() @@ -81,18 +81,18 @@ impl PrettyPrintableFragment { /// Debug print a LanceFragment. #[pyfunction] -pub fn format_fragment(fragment: &PyAny, dataset: &PyAny) -> PyResult { - let py = fragment.py(); - let fragment = fragment - .getattr("_metadata")? - .extract::>()?; +pub fn format_fragment( + fragment: PyLance, + dataset: &Bound<'_, PyAny>, +) -> PyResult { + let py = dataset.py(); + let fragment = fragment.0; let dataset = dataset.getattr("_ds")?.extract::>()?; - let dataset_ref = &dataset.as_ref(py).borrow().ds; + let dataset_ref = &dataset.bind(py).borrow().ds; let schema = dataset_ref.schema(); - let meta = fragment.as_ref(py).borrow().inner.clone(); - let pp_meta = PrettyPrintableFragment::new(&meta, schema); + let pp_meta = PrettyPrintableFragment::new(&fragment, schema); Ok(format!("{:#?}", pp_meta)) } @@ -104,12 +104,12 @@ pub fn format_fragment(fragment: &PyAny, dataset: &PyAny) -> PyResult { #[pyfunction] #[pyo3(signature = (dataset, /, max_transactions = 10))] pub fn list_transactions( - dataset: &PyAny, + dataset: &Bound<'_, PyAny>, max_transactions: usize, ) -> PyResult>> { let py = dataset.py(); let dataset = dataset.getattr("_ds")?.extract::>()?; - let mut dataset = dataset.as_ref(py).borrow().ds.clone(); + let mut dataset = dataset.bind(py).borrow().ds.clone(); RT.block_on(Some(py), async move { let mut transactions = vec![]; diff --git a/python/src/file.rs b/python/src/file.rs index e6e3d237d12..222965f85ea 100644 --- a/python/src/file.rs +++ b/python/src/file.rs @@ -23,7 +23,9 @@ use lance_core::cache::FileMetadataCache; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::{ v2::{ - reader::{BufferDescriptor, CachedFileMetadata, FileReader, FileReaderOptions}, + reader::{ + BufferDescriptor, CachedFileMetadata, FileReader, FileReaderOptions, FileStatistics, + }, writer::{FileWriter, FileWriterOptions}, }, version::LanceFileVersion, @@ -36,10 +38,13 @@ use lance_io::{ use object_store::path::Path; use pyo3::{ exceptions::{PyIOError, PyRuntimeError, PyValueError}, - pyclass, pymethods, IntoPy, PyObject, PyResult, Python, + pyclass, pymethods, IntoPyObjectExt, PyObject, PyResult, Python, }; use serde::Serialize; -use std::collections::HashMap; +use std::{ + collections::HashMap, + sync::{Mutex, MutexGuard}, +}; use std::{pin::Pin, sync::Arc}; use url::Url; @@ -113,6 +118,58 @@ impl LanceColumnMetadata { } } +/// Statistics summarize some of the file metadata for quick summary info +#[pyclass(get_all)] +#[derive(Clone, Debug, Serialize)] +pub struct LanceFileStatistics { + /// Statistics about each of the columns in the file + columns: Vec, +} + +#[pymethods] +impl LanceFileStatistics { + fn __repr__(&self) -> String { + let column_reprs: Vec = self.columns.iter().map(|col| col.__repr__()).collect(); + format!("FileStatistics(columns=[{}])", column_reprs.join(", ")) + } +} + +/// Summary information describing a column +#[pyclass(get_all)] +#[derive(Clone, Debug, Serialize)] +pub struct LanceColumnStatistics { + /// The number of pages in the column + num_pages: usize, + /// The total number of data & metadata bytes in the column + /// + /// This is the compressed on-disk size + size_bytes: u64, +} + +#[pymethods] +impl LanceColumnStatistics { + fn __repr__(&self) -> String { + format!( + "ColumnStatistics(num_pages={}, size_bytes={})", + self.num_pages, self.size_bytes + ) + } +} + +impl LanceFileStatistics { + fn new(inner: &FileStatistics) -> Self { + let columns = inner + .columns + .iter() + .map(|column_stat| LanceColumnStatistics { + num_pages: column_stat.num_pages, + size_bytes: column_stat.size_bytes, + }) + .collect(); + Self { columns } + } +} + #[pyclass(get_all)] #[derive(Clone, Debug, Serialize)] pub struct LanceFileMetadata { @@ -141,7 +198,7 @@ pub struct LanceFileMetadata { impl LanceFileMetadata { fn new(inner: &CachedFileMetadata, py: Python) -> Self { let arrow_schema = arrow_schema::Schema::from(inner.file_schema.as_ref()); - let schema = Some(PyArrowType(arrow_schema).into_py(py)); + let schema = PyArrowType(arrow_schema).into_py_any(py).ok(); Self { major_version: inner.major_version, minor_version: inner.minor_version, @@ -174,7 +231,7 @@ impl LanceFileMetadata { #[pyclass] pub struct LanceFileWriter { - inner: Box, + inner: Arc>>, } impl LanceFileWriter { @@ -205,14 +262,21 @@ impl LanceFileWriter { Ok(FileWriter::new_lazy(object_writer, options)) }?; Ok(Self { - inner: Box::new(inner), + inner: Arc::new(Mutex::new(Box::new(inner))), }) } + + fn inner_lock(&self) -> PyResult>> { + self.inner + .lock() + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + } } #[pymethods] impl LanceFileWriter { #[new] + #[pyo3(signature=(path, schema=None, data_cache_bytes=None, version=None, storage_options=None, keep_original_array=None))] pub fn new( path: String, schema: Option>, @@ -231,24 +295,27 @@ impl LanceFileWriter { )) } - pub fn write_batch(&mut self, batch: PyArrowType) -> PyResult<()> { + pub fn write_batch(&self, batch: PyArrowType) -> PyResult<()> { RT.runtime - .block_on(self.inner.write_batch(&batch.0)) + .block_on(self.inner_lock()?.write_batch(&batch.0)) .infer_error() } - pub fn finish(&mut self) -> PyResult { - RT.runtime.block_on(self.inner.finish()).infer_error() + pub fn finish(&self) -> PyResult { + RT.runtime + .block_on(self.inner_lock()?.finish()) + .infer_error() } - pub fn add_global_buffer(&mut self, bytes: Vec) -> PyResult { + pub fn add_global_buffer(&self, bytes: Vec) -> PyResult { RT.runtime - .block_on(self.inner.add_global_buffer(Bytes::from(bytes))) + .block_on(self.inner_lock()?.add_global_buffer(Bytes::from(bytes))) .infer_error() } - pub fn add_schema_metadata(&mut self, key: String, value: String) { - self.inner.add_schema_metadata(key, value) + pub fn add_schema_metadata(&self, key: String, value: String) -> PyResult<()> { + self.inner_lock()?.add_schema_metadata(key, value); + Ok(()) } } @@ -266,7 +333,7 @@ fn path_to_parent(path: &Path) -> PyResult<(Path, String)> { pub async fn object_store_from_uri_or_path_no_options( uri_or_path: impl AsRef, -) -> PyResult<(ObjectStore, Path)> { +) -> PyResult<(Arc, Path)> { object_store_from_uri_or_path(uri_or_path, None).await } @@ -277,7 +344,7 @@ pub async fn object_store_from_uri_or_path_no_options( pub async fn object_store_from_uri_or_path( uri_or_path: impl AsRef, storage_options: Option>, -) -> PyResult<(ObjectStore, Path)> { +) -> PyResult<(Arc, Path)> { if let Ok(mut url) = Url::parse(uri_or_path.as_ref()) { if url.scheme().len() > 1 { let path = object_store::path::Path::parse(url.path()).map_err(|e| { @@ -309,7 +376,7 @@ pub async fn object_store_from_uri_or_path( let path = Path::parse(uri_or_path.as_ref()).map_err(|e| { PyIOError::new_err(format!("Invalid path `{}`: {}", uri_or_path.as_ref(), e)) })?; - let object_store = ObjectStore::local(); + let object_store = Arc::new(ObjectStore::local()); Ok((object_store, path)) } @@ -326,7 +393,7 @@ impl LanceFileReader { let (object_store, path) = object_store_from_uri_or_path(uri_or_path, storage_options).await?; let scheduler = ScanScheduler::new( - Arc::new(object_store), + object_store, SchedulerConfig { io_buffer_size_bytes: 2 * 1024 * 1024 * 1024, }, @@ -390,6 +457,7 @@ impl LanceFileReader { #[pymethods] impl LanceFileReader { #[new] + #[pyo3(signature=(path, storage_options=None))] pub fn new(path: String, storage_options: Option>) -> PyResult { RT.runtime.block_on(Self::open(path, storage_options)) } @@ -443,6 +511,11 @@ impl LanceFileReader { LanceFileMetadata::new(inner_meta, py) } + pub fn file_statistics(&self) -> LanceFileStatistics { + let inner_stat = self.inner.file_statistics(); + LanceFileStatistics::new(&inner_stat) + } + pub fn read_global_buffer(&mut self, index: u32) -> PyResult> { let buffer_bytes = RT .runtime @@ -451,3 +524,65 @@ impl LanceFileReader { Ok(buffer_bytes.to_vec()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lance_file_statistics_repr_empty() { + let stats = LanceFileStatistics { columns: vec![] }; + + let repr_str = stats.__repr__(); + assert_eq!(repr_str, "FileStatistics(columns=[])"); + } + + #[test] + fn test_lance_file_statistics_repr_single_column() { + let stats = LanceFileStatistics { + columns: vec![LanceColumnStatistics { + num_pages: 5, + size_bytes: 1024, + }], + }; + + let repr_str = stats.__repr__(); + assert_eq!( + repr_str, + "FileStatistics(columns=[ColumnStatistics(num_pages=5, size_bytes=1024)])" + ); + } + + #[test] + fn test_lance_file_statistics_repr_multiple_columns() { + let stats = LanceFileStatistics { + columns: vec![ + LanceColumnStatistics { + num_pages: 5, + size_bytes: 1024, + }, + LanceColumnStatistics { + num_pages: 3, + size_bytes: 512, + }, + ], + }; + + let repr_str = stats.__repr__(); + assert_eq!( + repr_str, + "FileStatistics(columns=[ColumnStatistics(num_pages=5, size_bytes=1024), ColumnStatistics(num_pages=3, size_bytes=512)])" + ); + } + + #[test] + fn test_lance_column_statistics_repr() { + let column_stats = LanceColumnStatistics { + num_pages: 10, + size_bytes: 2048, + }; + + let repr_str = column_stats.__repr__(); + assert_eq!(repr_str, "ColumnStatistics(num_pages=10, size_bytes=2048)"); + } +} diff --git a/python/src/fragment.rs b/python/src/fragment.rs index bc46ce54ad3..8828eb77d71 100644 --- a/python/src/fragment.rs +++ b/python/src/fragment.rs @@ -16,23 +16,27 @@ use std::fmt::Write as _; use std::sync::Arc; use arrow::ffi_stream::ArrowArrayStreamReader; -use arrow::pyarrow::{FromPyArrow, ToPyArrow}; +use arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use arrow_array::RecordBatchReader; use arrow_schema::Schema as ArrowSchema; use futures::TryFutureExt; use lance::dataset::fragment::FileFragment as LanceFragment; -use lance::dataset::transaction::Operation; -use lance::dataset::{InsertBuilder, NewColumnTransform, WriteDestination}; +use lance::dataset::transaction::{Operation, Transaction}; +use lance::dataset::{InsertBuilder, NewColumnTransform}; use lance::Error; -use lance_table::format::{DataFile as LanceDataFile, Fragment as LanceFragmentMetadata}; +use lance_table::format::{DataFile, DeletionFile, DeletionFileType, Fragment, RowIdMeta}; use lance_table::io::deletion::deletion_file_path; -use pyo3::prelude::*; -use pyo3::{exceptions::*, pyclass::CompareOp, types::PyDict}; -use snafu::{location, Location}; - -use crate::dataset::{get_write_params, transforms_from_python}; +use object_store::path::Path; +use pyo3::basic::CompareOp; +use pyo3::types::PyTuple; +use pyo3::{exceptions::*, types::PyDict}; +use pyo3::{intern, prelude::*}; +use snafu::location; + +use crate::dataset::{get_write_params, transforms_from_python, PyWriteDest}; use crate::error::PythonErrorExt; use crate::schema::LanceSchema; +use crate::utils::{export_vec, extract_vec, PyLance}; use crate::{Dataset, Scanner, RT}; #[pyclass(name = "_Fragment", module = "_lib")] @@ -80,13 +84,13 @@ impl FileFragment { filename: &str, dataset: &Dataset, fragment_id: usize, - ) -> PyResult { + ) -> PyResult> { let metadata = RT.block_on(None, async { LanceFragment::create_from_file(filename, dataset.ds.as_ref(), fragment_id, None) .await .map_err(|err| PyIOError::new_err(err.to_string())) })??; - Ok(FragmentMetadata::new(metadata)) + Ok(PyLance(metadata)) } #[staticmethod] @@ -95,8 +99,8 @@ impl FileFragment { dataset_uri: &str, fragment_id: Option, reader: &Bound, - kwargs: Option<&PyDict>, - ) -> PyResult { + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult> { let params = if let Some(kw_params) = kwargs { get_write_params(kw_params)? } else { @@ -112,7 +116,7 @@ impl FileFragment { .await .map_err(|err| PyIOError::new_err(err.to_string()))?; - Ok(FragmentMetadata::new(metadata)) + Ok(PyLance(metadata)) }) }) } @@ -121,19 +125,21 @@ impl FileFragment { self.fragment.id() } - pub fn metadata(&self) -> FragmentMetadata { - FragmentMetadata::new(self.fragment.metadata().clone()) + pub fn metadata(&self) -> PyLance { + PyLance(self.fragment.metadata().clone()) } - fn count_rows(&self, _filter: Option) -> PyResult { + #[pyo3(signature=(filter=None))] + fn count_rows(&self, filter: Option) -> PyResult { RT.runtime.block_on(async { self.fragment - .count_rows() + .count_rows(filter) .await .map_err(|e| PyIOError::new_err(e.to_string())) }) } + #[pyo3(signature=(row_indices, columns=None))] fn take( self_: PyRef<'_, Self>, row_indices: Vec, @@ -159,6 +165,7 @@ impl FileFragment { } #[allow(clippy::too_many_arguments)] + #[pyo3(signature=(columns=None, columns_with_transform=None, batch_size=None, filter=None, limit=None, offset=None, with_row_id=None, with_row_address=None, batch_readahead=None))] fn scanner( self_: PyRef<'_, Self>, columns: Option>, @@ -168,6 +175,7 @@ impl FileFragment { limit: Option, offset: Option, with_row_id: Option, + with_row_address: Option, batch_readahead: Option, ) -> PyResult { let mut scanner = self_.fragment.scan(); @@ -207,6 +215,9 @@ impl FileFragment { if with_row_id.unwrap_or(false) { scanner.with_row_id(); } + if with_row_address.unwrap_or(false) { + scanner.with_row_address(); + } if let Some(batch_readahead) = batch_readahead { scanner.batch_readahead(batch_readahead); } @@ -215,11 +226,12 @@ impl FileFragment { Ok(Scanner::new(scn)) } + #[pyo3(signature=(reader, batch_size=None))] fn add_columns_from_reader( &mut self, reader: &Bound, batch_size: Option, - ) -> PyResult<(FragmentMetadata, LanceSchema)> { + ) -> PyResult<(PyLance, LanceSchema)> { let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; let transforms = NewColumnTransform::Reader(Box::new(batches)); @@ -231,15 +243,16 @@ impl FileFragment { })? .infer_error()?; - Ok((FragmentMetadata::new(fragment), LanceSchema(schema))) + Ok((PyLance(fragment), LanceSchema(schema))) } + #[pyo3(signature=(transforms, read_columns=None, batch_size=None))] fn add_columns( &mut self, - transforms: &PyAny, + transforms: &Bound<'_, PyAny>, read_columns: Option>, batch_size: Option, - ) -> PyResult<(FragmentMetadata, LanceSchema)> { + ) -> PyResult<(PyLance, LanceSchema)> { let transforms = transforms_from_python(transforms)?; let fragment = self.fragment.clone(); @@ -251,7 +264,26 @@ impl FileFragment { })? .infer_error()?; - Ok((FragmentMetadata::new(fragment), LanceSchema(schema))) + Ok((PyLance(fragment), LanceSchema(schema))) + } + + fn merge( + &mut self, + reader: PyArrowType, + left_on: String, + right_on: String, + max_field_id: i32, + ) -> PyResult<(PyLance, LanceSchema)> { + let mut fragment = self.fragment.clone(); + let (fragment, schema) = RT + .spawn(None, async move { + fragment + .merge_columns(reader.0, &left_on, &right_on, max_field_id) + .await + })? + .infer_error()?; + + Ok((PyLance(fragment), LanceSchema(schema))) } fn delete(&self, predicate: &str) -> PyResult> { @@ -273,13 +305,13 @@ impl FileFragment { } /// Returns the data file objects associated with this fragment. - fn data_files(self_: PyRef<'_, Self>) -> PyResult> { - let data_files: Vec = self_ + fn data_files(self_: PyRef<'_, Self>) -> PyResult>> { + let data_files: Vec<_> = self_ .fragment .metadata() .files .iter() - .map(|f| DataFile::new(f.clone())) + .map(|f| PyLance(f.clone())) .collect(); Ok(data_files) } @@ -309,218 +341,354 @@ impl From for LanceFragment { } } -/// Metadata of a DataFile. -#[pyclass(name = "_DataFile", module = "_lib")] -pub struct DataFile { - pub(crate) inner: LanceDataFile, +fn do_write_fragments( + dest: PyWriteDest, + reader: &Bound, + kwargs: Option<&Bound<'_, PyDict>>, +) -> PyResult { + let batches = convert_reader(reader)?; + + let params = kwargs + .and_then(|params| get_write_params(params).transpose()) + .transpose()? + .unwrap_or_default(); + + RT.block_on( + Some(reader.py()), + InsertBuilder::new(dest.as_dest()) + .with_params(¶ms) + .execute_uncommitted_stream(batches), + )? + .map_err(|err| PyIOError::new_err(err.to_string())) +} + +#[pyfunction(name = "_write_fragments")] +#[pyo3(signature = (dest, reader, **kwargs))] +pub fn write_fragments( + dest: PyWriteDest, + reader: &Bound, + kwargs: Option<&Bound<'_, PyDict>>, +) -> PyResult> { + let written = do_write_fragments(dest, reader, kwargs)?; + + assert!( + written.blobs_op.is_none(), + "Blob writing is not yet supported by the python _write_fragments API" + ); + + let get_fragments = |operation| match operation { + Operation::Overwrite { fragments, .. } => Ok(fragments), + Operation::Append { fragments, .. } => Ok(fragments), + _ => Err(Error::Internal { + message: "Unexpected operation".into(), + location: location!(), + }), + }; + let fragments = + get_fragments(written.operation).map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + + export_vec(reader.py(), &fragments) } -impl DataFile { - fn new(inner: LanceDataFile) -> Self { - Self { inner } +#[pyfunction(name = "_write_fragments_transaction")] +#[pyo3(signature = (dest, reader, **kwargs))] +pub fn write_fragments_transaction<'py>( + dest: PyWriteDest, + reader: &'py Bound<'py, PyAny>, + kwargs: Option<&Bound<'py, PyDict>>, +) -> PyResult> { + let written = do_write_fragments(dest, reader, kwargs)?; + + PyLance(written).into_pyobject(reader.py()) +} + +fn convert_reader(reader: &Bound) -> PyResult> { + if reader.is_instance_of::() { + let scanner: Scanner = reader.extract()?; + let reader = RT.block_on( + Some(reader.py()), + scanner + .to_reader() + .map_err(|err| PyValueError::new_err(err.to_string())), + )??; + Ok(Box::new(reader)) + } else { + Ok(Box::new(ArrowArrayStreamReader::from_pyarrow_bound( + reader, + )?)) } } +#[pyclass(name = "DeletionFile", module = "lance.fragment")] +pub struct PyDeletionFile(pub DeletionFile); + #[pymethods] -impl DataFile { - fn __repr__(&self) -> String { - format!("DataFile({})", self.path()) +impl PyDeletionFile { + #[new] + fn new(read_version: u64, id: u64, file_type: &str, num_deleted_rows: usize) -> PyResult { + let file_type = match file_type { + "array" => DeletionFileType::Array, + "bitmap" => DeletionFileType::Bitmap, + _ => { + return Err(PyValueError::new_err(format!( + "file_type must be either 'array' or 'bitmap', got '{}'", + file_type + ))) + } + }; + Ok(Self(DeletionFile { + read_version, + id, + file_type, + num_deleted_rows: Some(num_deleted_rows), + })) } - fn path(&self) -> String { - self.inner.path.clone() - } + fn asdict(slf: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(slf.py()); - fn field_ids(&self) -> Vec { - self.inner.fields.clone() + dict.set_item(intern!(slf.py(), "read_version"), slf.0.read_version)?; + dict.set_item(intern!(slf.py(), "id"), slf.0.id)?; + dict.set_item(intern!(slf.py(), "file_type"), slf.file_type())?; + dict.set_item( + intern!(slf.py(), "num_deleted_rows"), + slf.0.num_deleted_rows, + )?; + + Ok(dict) } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { - match op { - CompareOp::Eq => Ok(self.inner == other.inner), - CompareOp::Ne => Ok(self.inner != other.inner), - _ => Err(PyNotImplementedError::new_err( - "Only == and != are supported for DataFile", - )), + fn __repr__(&self) -> String { + let mut repr = "DeletionFile(".to_string(); + write!(repr, "type='{}'", self.file_type()).unwrap(); + if let Some(num_deleted_rows) = self.0.num_deleted_rows { + write!(repr, ", num_deleted_rows={}", num_deleted_rows).unwrap(); } + write!(repr, ")").unwrap(); + repr } -} -#[pyclass(name = "_FragmentMetadata", module = "lance")] -#[derive(Clone, Debug)] -pub struct FragmentMetadata { - pub(crate) inner: LanceFragmentMetadata, -} + #[getter] + fn read_version(&self) -> u64 { + self.0.read_version + } -impl FragmentMetadata { - pub(crate) fn new(inner: LanceFragmentMetadata) -> Self { - Self { inner } + #[getter] + fn id(&self) -> u64 { + self.0.id } -} -#[pymethods] -impl FragmentMetadata { - #[new] - fn init() -> Self { - Self { - inner: LanceFragmentMetadata::new(0), + #[getter] + fn num_deleted_rows(&self) -> Option { + self.0.num_deleted_rows + } + + #[getter] + fn file_type(&self) -> &str { + match self.0.file_type { + DeletionFileType::Array => "array", + DeletionFileType::Bitmap => "bitmap", } } + #[pyo3(signature = (fragment_id, base_uri=None))] + fn path(&self, fragment_id: u64, base_uri: Option<&str>) -> PyResult { + let base_path = if let Some(base_uri) = base_uri { + Path::from_url_path(base_uri).map_err(|e| { + PyValueError::new_err(format!("Invalid base URI: {}: {}", base_uri, e)) + })? + } else { + Path::default() + }; + Ok(deletion_file_path(&base_path, fragment_id, &self.0).to_string()) + } + + pub fn json(&self) -> PyResult { + serde_json::to_string(&self.0).map_err(|err| { + PyValueError::new_err(format!( + "Could not dump CompactionPlan due to error: {}", + err + )) + }) + } + #[staticmethod] - fn from_json(json: &str) -> PyResult { - let metadata = LanceFragmentMetadata::from_json(json).map_err(|err| { - PyValueError::new_err(format!("Invalid metadata json payload: {json}: {}", err)) + pub fn from_json(json: String) -> PyResult { + let deletion_file = serde_json::from_str(&json).map_err(|err| { + PyValueError::new_err(format!("Could not load DeletionFile due to error: {}", err)) })?; + Ok(Self(deletion_file)) + } - Ok(Self { inner: metadata }) + fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { + let state = self.json()?; + let state = PyTuple::new(py, vec![state])?.extract()?; + let from_json = PyModule::import(py, "lance.fragment")? + .getattr("DeletionFile")? + .getattr("from_json")? + .extract()?; + Ok((from_json, state)) } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + pub fn __richcmp__(&self, other: PyRef<'_, Self>, op: CompareOp) -> PyResult { match op { - CompareOp::Lt => Ok(self.inner.id < other.inner.id), - CompareOp::Le => Ok(self.inner.id <= other.inner.id), - CompareOp::Eq => Ok(self.inner == other.inner), - CompareOp::Ne => self.__richcmp__(other, CompareOp::Eq).map(|v| !v), - CompareOp::Gt => self.__richcmp__(other, CompareOp::Le).map(|v| !v), - CompareOp::Ge => self.__richcmp__(other, CompareOp::Lt).map(|v| !v), + CompareOp::Eq => Ok(self.0 == other.0), + CompareOp::Ne => Ok(self.0 != other.0), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported for CompactionTask", + )), } } +} - fn __repr__(&self) -> String { - format!("{:?}", self.inner) +#[pyclass(name = "RowIdMeta", module = "lance.fragment")] +pub struct PyRowIdMeta(pub RowIdMeta); + +#[pymethods] +impl PyRowIdMeta { + fn asdict(&self) -> PyResult> { + Err(PyNotImplementedError::new_err( + "PyRowIdMeta.asdict is not yet supported.s", + )) + } + + pub fn json(&self) -> PyResult { + serde_json::to_string(&self.0).map_err(|err| { + PyValueError::new_err(format!( + "Could not dump CompactionPlan due to error: {}", + err + )) + }) } - fn json(self_: PyRef<'_, Self>) -> PyResult { - let json = serde_json::to_string(&self_.inner).map_err(|e| { - PyValueError::new_err(format!("Unable to serialize FragmentMetadata: {}", e)) + #[staticmethod] + pub fn from_json(json: String) -> PyResult { + let row_id_meta = serde_json::from_str(&json).map_err(|err| { + PyValueError::new_err(format!("Could not load RowIdMeta due to error: {}", err)) })?; - Ok(json) + Ok(Self(row_id_meta)) } - /// Returns the data file objects associated with this fragment. - fn data_files(self_: PyRef<'_, Self>) -> PyResult> { - let data_files: Vec = self_ - .inner - .files - .iter() - .map(|f| DataFile::new(f.clone())) - .collect(); - Ok(data_files) + fn __reduce__(&self, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { + let state = self.json()?; + let state = PyTuple::new(py, vec![state])?.extract()?; + let from_json = PyModule::import(py, "lance.fragment")? + .getattr("RowIdMeta")? + .getattr("from_json")? + .extract()?; + Ok((from_json, state)) } - fn deletion_file(&self) -> PyResult> { - let deletion = self.inner.deletion_file.clone(); - Ok( - deletion - .map(|d| deletion_file_path(&Default::default(), self.inner.id, &d).to_string()), - ) + pub fn __richcmp__(&self, other: PyRef<'_, Self>, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self.0 == other.0), + CompareOp::Ne => Ok(self.0 != other.0), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported for CompactionTask", + )), + } } +} - /// Get the physical rows statistic. - /// - /// This represents the original number of rows in the fragment - /// before any deletions. - /// - /// If this is None, it is unavailable. - #[getter] - fn physical_rows(&self) -> Option { - self.inner.physical_rows - } +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &pyo3::Bound<'_, PyAny>) -> PyResult { + let files = extract_vec(&ob.getattr("files")?)?; - /// Get the number of tombstoned rows in the fragment. - /// - /// If this is None, this statistic is unavailable. It does not necessarily - /// mean there are no deletions. - #[getter] - fn num_deletions(&self) -> Option { - self.inner - .deletion_file - .as_ref() - .and_then(|d| d.num_deleted_rows) - } + let deletion_file: Option> = + ob.getattr("deletion_file")?.extract()?; + let deletion_file = deletion_file.map(|f| f.0.clone()); - /// Get the number of rows in the fragment. - /// - /// This is equivalent to physical_rows minus num_deletions. - /// - /// If this is None, this statistic is unavailable. - #[getter] - fn num_rows(&self) -> Option { - self.inner.num_rows() - } + let row_id_meta: Option> = ob.getattr("row_id_meta")?.extract()?; + let row_id_meta = row_id_meta.map(|r| r.0.clone()); - #[getter] - fn id(&self) -> u64 { - self.inner.id + Ok(Self(Fragment { + id: ob.getattr("id")?.extract()?, + files, + deletion_file, + physical_rows: ob.getattr("physical_rows")?.extract()?, + row_id_meta, + })) } } -#[pyfunction(name = "_write_fragments")] -#[pyo3(signature = (dest, reader, **kwargs))] -pub fn write_fragments( - dest: &Bound, - reader: &Bound, - kwargs: Option<&PyDict>, -) -> PyResult> { - let batches = convert_reader(reader)?; +impl<'py> IntoPyObject<'py> for PyLance<&Fragment> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; - let params = kwargs - .and_then(|params| get_write_params(params).transpose()) - .transpose()? - .unwrap_or_default(); + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance.fragment")) + .and_then(|m| m.getattr("FragmentMetadata")) + .expect("FragmentMetadata class not found"); - let dest = if dest.is_instance_of::() { - let dataset: Dataset = dest.extract()?; - WriteDestination::Dataset(dataset.ds.clone()) - } else { - WriteDestination::Uri(dest.extract()?) - }; + let files = export_vec(py, &self.0.files)?; + let deletion_file = self + .0 + .deletion_file + .as_ref() + .map(|f| PyDeletionFile(f.clone())); + let row_id_meta = self.0.row_id_meta.as_ref().map(|r| PyRowIdMeta(r.clone())); - let written = RT - .block_on( - Some(reader.py()), - InsertBuilder::new(dest) - .with_params(¶ms) - .execute_uncommitted_stream(batches), - )? - .map_err(|err| PyIOError::new_err(err.to_string()))?; + cls.call1(( + self.0.id, + files, + self.0.physical_rows, + deletion_file, + row_id_meta, + )) + } +} - assert!( - written.blobs_op.is_none(), - "Blob writing is not yet supported by the python _write_fragments API" - ); +impl<'py> IntoPyObject<'py> for PyLance { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; - let get_fragments = |operation| match operation { - Operation::Overwrite { fragments, .. } => Ok(fragments), - Operation::Append { fragments, .. } => Ok(fragments), - _ => Err(Error::Internal { - message: "Unexpected operation".into(), - location: location!(), - }), - }; - let fragments = - get_fragments(written.operation).map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + fn into_pyobject(self, py: Python<'py>) -> Result { + PyLance(&self.0).into_pyobject(py) + } +} - fragments - .into_iter() - .map(|f| Ok(FragmentMetadata::new(f))) - .collect() +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &pyo3::Bound<'_, PyAny>) -> PyResult { + Ok(Self(DataFile { + path: ob.getattr("path")?.extract()?, + fields: ob.getattr("fields")?.extract()?, + column_indices: ob.getattr("column_indices")?.extract()?, + file_major_version: ob.getattr("file_major_version")?.extract()?, + file_minor_version: ob.getattr("file_minor_version")?.extract()?, + })) + } } -fn convert_reader(reader: &Bound) -> PyResult> { - if reader.is_instance_of::() { - let scanner: Scanner = reader.extract()?; - let reader = RT.block_on( - Some(reader.py()), - scanner - .to_reader() - .map_err(|err| PyValueError::new_err(err.to_string())), - )??; - Ok(Box::new(reader)) - } else { - Ok(Box::new(ArrowArrayStreamReader::from_pyarrow_bound( - reader, - )?)) +impl<'py> IntoPyObject<'py> for PyLance<&DataFile> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance.fragment")) + .and_then(|m| m.getattr("DataFile")) + .expect("DataFile class not found"); + + cls.call1(( + &self.0.path, + self.0.fields.clone(), + self.0.column_indices.clone(), + self.0.file_major_version, + self.0.file_minor_version, + )) + } +} + +impl<'py> IntoPyObject<'py> for PyLance { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + PyLance(&self.0).into_pyobject(py) } } diff --git a/python/src/indices.rs b/python/src/indices.rs index 9b7b315e8f9..3edb948356e 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -13,6 +13,8 @@ use lance_index::vector::{ }; use lance_linalg::distance::DistanceType; use pyo3::exceptions::PyValueError; +use pyo3::types::PyModuleMethods; +use pyo3::Bound; use pyo3::{ pyfunction, types::{PyList, PyModule}, @@ -136,6 +138,7 @@ fn train_pq_model( centroids: Some(ivf_centroids), offsets: vec![], lengths: vec![], + loss: None, }; let codebook = RT.block_on( Some(py), @@ -165,7 +168,7 @@ async fn do_transform_vectors( partitions_ds_uri: Option<&str>, ) -> PyResult<()> { let num_rows = dataset.ds.count_rows(None).await.infer_error()?; - let fragments = fragments.iter().map(|item| item.metadata().inner).collect(); + let fragments = fragments.iter().map(|item| item.metadata().0).collect(); let transform_input = dataset .ds .scan() @@ -198,6 +201,7 @@ async fn do_transform_vectors( #[pyfunction] #[allow(clippy::too_many_arguments)] +#[pyo3(signature=(dataset, column, dimension, num_subvectors, distance_type, ivf_centroids, pq_codebook, dst_uri, fragments, partitions_ds_uri=None))] pub fn transform_vectors( py: Python<'_>, dataset: &Dataset, @@ -238,6 +242,7 @@ pub fn transform_vectors( )? } +#[allow(deprecated)] async fn do_shuffle_transformed_vectors( unsorted_filenames: Vec, dir_path: &str, @@ -284,10 +289,7 @@ pub fn shuffle_transformed_vectors( )?; match result { - Ok(partition_files) => { - let py_list = PyList::new(py, partition_files); - Ok(py_list.into()) - } + Ok(partition_files) => PyList::new(py, partition_files).map(|py_list| py_list.into()), Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(e.to_string())), } } @@ -329,6 +331,7 @@ async fn do_load_shuffled_vectors( } #[pyfunction] +#[pyo3(signature=(filenames, dir_path, dataset, column, ivf_centroids, pq_codebook, pq_dimension, num_subvectors, distance_type, index_name=None))] #[allow(clippy::too_many_arguments)] pub fn load_shuffled_vectors( filenames: Vec, @@ -353,6 +356,7 @@ pub fn load_shuffled_vectors( centroids: Some(ivf_centroids), offsets: vec![], lengths: vec![], + loss: None, }; let codebook = pq_codebook.0; @@ -375,13 +379,13 @@ pub fn load_shuffled_vectors( )? } -pub fn register_indices(py: Python, m: &PyModule) -> PyResult<()> { +pub fn register_indices(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let indices = PyModule::new(py, "indices")?; indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; indices.add_wrapped(wrap_pyfunction!(transform_vectors))?; indices.add_wrapped(wrap_pyfunction!(shuffle_transformed_vectors))?; indices.add_wrapped(wrap_pyfunction!(load_shuffled_vectors))?; - m.add_submodule(indices)?; + m.add_submodule(&indices)?; Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index ec39c834fd9..b4b5d353761 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -18,6 +18,10 @@ //! automatic versioning, optimized for computer vision, bioinformatics, spatial and ML data. //! [Apache Arrow](https://arrow.apache.org/) and DuckDB compatible. +// Workaround for https://github.com/rust-lang/rust-clippy/issues/12039 +// Remove after upgrading pyo3 to 0.23 +#![allow(clippy::useless_conversion)] + use std::env; use std::sync::Arc; @@ -34,16 +38,18 @@ use dataset::cleanup::CleanupStats; use dataset::optimize::{ PyCompaction, PyCompactionMetrics, PyCompactionPlan, PyCompactionTask, PyRewriteResult, }; -use dataset::MergeInsertBuilder; -use env_logger::Env; +use dataset::{MergeInsertBuilder, PyFullTextQuery}; +use env_logger::{Builder, Env}; use file::{ LanceBufferDescriptor, LanceColumnMetadata, LanceFileMetadata, LanceFileReader, - LanceFileWriter, LancePageMetadata, + LanceFileStatistics, LanceFileWriter, LancePageMetadata, }; use futures::StreamExt; use lance_index::DatasetIndexExt; +use log::Level; use pyo3::exceptions::{PyIOError, PyValueError}; use pyo3::prelude::*; +use scanner::ScanStatistics; use session::Session; #[macro_use] @@ -64,17 +70,17 @@ pub(crate) mod scanner; pub(crate) mod schema; pub(crate) mod session; pub(crate) mod tracing; +pub(crate) mod transaction; pub(crate) mod utils; pub use crate::arrow::{bfloat16_array, BFloat16}; -use crate::fragment::write_fragments; +use crate::fragment::{write_fragments, write_fragments_transaction}; pub use crate::tracing::{trace_to_chrome, TraceGuard}; use crate::utils::Hnsw; use crate::utils::KMeans; pub use dataset::write_dataset; -pub use dataset::{Dataset, Operation, RewriteGroup, RewrittenIndex}; -pub use fragment::FragmentMetadata; -use fragment::{DataFile, FileFragment}; +pub use dataset::Dataset; +use fragment::{FileFragment, PyDeletionFile, PyRowIdMeta}; pub use indices::register_indices; pub use reader::LanceReader; pub use scanner::Scanner; @@ -89,10 +95,10 @@ pub fn is_datagen_supported() -> bool { // A fallback module for when datagen is not enabled #[cfg(not(feature = "datagen"))] -fn register_datagen(py: Python, m: &PyModule) -> PyResult<()> { +fn register_datagen(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let datagen = PyModule::new(py, "datagen")?; datagen.add_wrapped(wrap_pyfunction!(is_datagen_supported))?; - m.add_submodule(datagen)?; + m.add_submodule(&datagen)?; Ok(()) } @@ -101,29 +107,40 @@ lazy_static! { static ref RT: BackgroundExecutor = BackgroundExecutor::new(); } +pub fn init_logging(mut log_builder: Builder) { + let logger = log_builder.build(); + + let max_level = logger.filter(); + + let log_level = max_level.to_level().unwrap_or(Level::Error); + + tracing::initialize_tracing(log_level); + log::set_boxed_logger(Box::new(logger)).unwrap(); + log::set_max_level(max_level); +} + #[pymodule] -fn lance(py: Python, m: &PyModule) -> PyResult<()> { +fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let env = Env::new() .filter_or("LANCE_LOG", "warn") .write_style("LANCE_LOG_STYLE"); - env_logger::init_from_env(env); + let log_builder = env_logger::Builder::from_env(env); + init_logging(log_builder); m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -133,29 +150,46 @@ fn lance(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(bfloat16_array))?; m.add_wrapped(wrap_pyfunction!(write_dataset))?; m.add_wrapped(wrap_pyfunction!(write_fragments))?; + m.add_wrapped(wrap_pyfunction!(write_fragments_transaction))?; m.add_wrapped(wrap_pyfunction!(schema_to_json))?; m.add_wrapped(wrap_pyfunction!(json_to_schema))?; m.add_wrapped(wrap_pyfunction!(infer_tfrecord_schema))?; m.add_wrapped(wrap_pyfunction!(read_tfrecord))?; m.add_wrapped(wrap_pyfunction!(trace_to_chrome))?; m.add_wrapped(wrap_pyfunction!(manifest_needs_migration))?; + m.add_wrapped(wrap_pyfunction!(language_model_home))?; + m.add_wrapped(wrap_pyfunction!(bytes_read_counter))?; + m.add_wrapped(wrap_pyfunction!(iops_counter))?; // Debug functions m.add_wrapped(wrap_pyfunction!(debug::format_schema))?; m.add_wrapped(wrap_pyfunction!(debug::format_manifest))?; m.add_wrapped(wrap_pyfunction!(debug::format_fragment))?; m.add_wrapped(wrap_pyfunction!(debug::list_transactions))?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; + register_datagen(py, m)?; register_indices(py, m)?; Ok(()) } +#[pyfunction(name = "iops_counter")] +fn iops_counter() -> PyResult { + Ok(::lance::io::iops_counter()) +} + +#[pyfunction(name = "bytes_read_counter")] +fn bytes_read_counter() -> PyResult { + Ok(::lance::io::bytes_read_counter()) +} + #[pyfunction(name = "_schema_to_json")] fn schema_to_json(schema: PyArrowType) -> PyResult { schema.0.to_json().map_err(|e| { @@ -174,6 +208,21 @@ fn json_to_schema(json: &str) -> PyResult> { Ok(schema.into()) } +#[pyfunction] +pub fn language_model_home() -> PyResult { + let Some(p) = lance_index::scalar::inverted::language_model_home() else { + return Err(pyo3::exceptions::PyValueError::new_err( + "Failed to get language model home", + )); + }; + let Some(pstr) = p.to_str() else { + return Err(pyo3::exceptions::PyValueError::new_err( + "Failed to convert language model home to str", + )); + }; + Ok(String::from(pstr)) +} + /// Infer schema from tfrecord file /// /// Parameters @@ -297,10 +346,10 @@ fn read_tfrecord( #[pyfunction] #[pyo3(signature = (dataset,))] -fn manifest_needs_migration(dataset: &PyAny) -> PyResult { +fn manifest_needs_migration(dataset: &Bound<'_, PyAny>) -> PyResult { let py = dataset.py(); let dataset = dataset.getattr("_ds")?.extract::>()?; - let dataset_ref = &dataset.as_ref(py).borrow().ds; + let dataset_ref = &dataset.bind(py).borrow().ds; let indices = RT .block_on(Some(py), dataset_ref.load_indices())? .map_err(|err| PyIOError::new_err(format!("Could not read dataset metadata: {}", err)))?; diff --git a/python/src/scanner.rs b/python/src/scanner.rs index d32c02ac983..9102335fa76 100644 --- a/python/src/scanner.rs +++ b/python/src/scanner.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use arrow::pyarrow::*; use arrow_array::RecordBatchReader; +use lance::dataset::scanner::ExecutionSummaryCounts; use pyo3::prelude::*; use pyo3::pyclass; @@ -46,6 +47,42 @@ impl Scanner { } } +#[pyclass(name = "ScanStatistics", module = "_lib", get_all)] +#[derive(Clone)] +/// Statistics about the scan. +pub struct ScanStatistics { + /// Number of IO operations performed. This may be slightly higher than + /// the actual number due to coalesced I/O + pub iops: usize, + /// Number of bytes read from disk + pub bytes_read: usize, + /// Number of indices loaded + pub indices_loaded: usize, + /// Number of index partitions loaded + pub parts_loaded: usize, +} + +impl ScanStatistics { + pub fn from_lance(stats: &ExecutionSummaryCounts) -> Self { + Self { + iops: stats.iops, + bytes_read: stats.bytes_read, + indices_loaded: stats.indices_loaded, + parts_loaded: stats.parts_loaded, + } + } +} + +#[pymethods] +impl ScanStatistics { + fn __repr__(&self) -> String { + format!( + "ScanStatistics(iops={}, bytes_read={}, indices_loaded={}, parts_loaded={})", + self.iops, self.bytes_read, self.indices_loaded, self.parts_loaded + ) + } +} + #[pymethods] impl Scanner { #[getter(schema)] @@ -68,6 +105,19 @@ impl Scanner { Ok(res) } + #[pyo3(signature = (*))] + fn analyze_plan(self_: PyRef<'_, Self>) -> PyResult { + let scanner = self_.scanner.clone(); + let res = RT + .spawn( + Some(self_.py()), + async move { scanner.analyze_plan().await }, + )? + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + Ok(res) + } + fn count_rows(self_: PyRef<'_, Self>) -> PyResult { let scanner = self_.scanner.clone(); RT.spawn(Some(self_.py()), async move { scanner.count_rows().await })? diff --git a/python/src/schema.rs b/python/src/schema.rs index 9670b482576..fa1ac296b60 100644 --- a/python/src/schema.rs +++ b/python/src/schema.rs @@ -3,7 +3,7 @@ use arrow::pyarrow::PyArrowType; use arrow_schema::Schema as ArrowSchema; -use lance::datatypes::Schema; +use lance::datatypes::{Field, Schema}; use lance_file::datatypes::{Fields, FieldsWithMeta}; use lance_file::format::pb; use prost::Message; @@ -12,8 +12,45 @@ use pyo3::{ exceptions::{PyNotImplementedError, PyValueError}, prelude::*, types::PyTuple, + IntoPyObjectExt, }; +#[pyclass(name = "LanceField", module = "lance.schema")] +#[derive(Clone)] +pub struct LanceField(pub Field); + +/// A field in a Lance schema +/// +/// Unlike a PyArrow field, a Lance field has an id in addition to the name. +#[pymethods] +impl LanceField { + pub fn __repr__(&self) -> PyResult { + Ok(format!("{:?}", self.0)) + } + + pub fn __richcmp__(&self, other: Self, op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(self.0 == other.0), + CompareOp::Ne => Ok(self.0 != other.0), + _ => Err(PyNotImplementedError::new_err( + "Only == and != are supported", + )), + } + } + + pub fn children(&self) -> PyResult> { + Ok(self.0.children.iter().cloned().map(Self).collect()) + } + + pub fn name(&self) -> PyResult { + Ok(self.0.name.clone()) + } + + pub fn id(&self) -> PyResult { + Ok(self.0.id) + } +} + /// A Lance Schema. /// /// Unlike a PyArrow schema, a Lance schema assigns every field an integer id. @@ -69,14 +106,14 @@ impl LanceSchema { let mut states = Vec::new(); let metadata_str = serde_json::to_string(&fields_with_meta.metadata) .map_err(|e| PyErr::new::(format!("{}", e)))? - .into_py(py); + .into_py_any(py)?; states.push(metadata_str); for field in fields_with_meta.fields.0.iter() { - states.push(field.encode_to_vec().into_py(py)); + states.push(field.encode_to_vec().into_py_any(py)?); } - let state = PyTuple::new(py, states).extract()?; + let state = PyTuple::new(py, states)?.extract()?; let from_protos = PyModule::import(py, "lance.schema")? .getattr("LanceSchema")? .getattr("_from_protos")? @@ -104,4 +141,8 @@ impl LanceSchema { let schema = Schema::from(fields_with_meta); Ok(Self(schema)) } + + pub fn fields(&self) -> PyResult> { + Ok(self.0.fields.iter().cloned().map(LanceField).collect()) + } } diff --git a/python/src/tracing.rs b/python/src/tracing.rs index a9373c3e50d..92d82efe9ba 100644 --- a/python/src/tracing.rs +++ b/python/src/tracing.rs @@ -15,27 +15,151 @@ // specific language governing permissions and limitations // under the License. -use pyo3::exceptions::PyAssertionError; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::RwLock; + use pyo3::exceptions::PyValueError; use pyo3::pyclass; use pyo3::pyfunction; use pyo3::pymethods; use pyo3::PyResult; +use tracing::field::Visit; +use tracing::span; use tracing::subscriber; +use tracing::Level; +use tracing::Subscriber; +use tracing_chrome::ChromeLayer; use tracing_chrome::{ChromeLayerBuilder, TraceStyle}; use tracing_subscriber::filter; +use tracing_subscriber::filter::Filtered; +use tracing_subscriber::filter::Targets; +use tracing_subscriber::layer::Layered; use tracing_subscriber::prelude::*; use tracing_subscriber::Registry; +pub type TracingSubscriber = Layered, Targets, Registry>, Registry>; + +lazy_static::lazy_static! { + static ref SUBSCRIBER: LoggingSubscriberRef = LoggingPassthrough::init(); +} + +struct LoggingPassthroughState { + inner: Option, + level: Level, +} + +impl Default for LoggingPassthroughState { + fn default() -> Self { + Self { + inner: None, + // This value doesn't matter, we'll override it in `initialize_tracing` + level: Level::INFO, + } + } +} + +#[derive(Default)] +struct LoggingPassthrough { + state: RwLock, +} + +impl LoggingPassthrough { + fn init() -> LoggingSubscriberRef { + let subscriber = LoggingSubscriberRef(Arc::new(Self::default())); + subscriber::set_global_default(subscriber.clone()).unwrap(); + subscriber + } +} + +#[derive(Default)] +struct EventToStr { + str: String, +} + +impl Visit for EventToStr { + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { + self.str += &format!("{}={:?} ", field.name(), value); + } +} + +#[derive(Clone)] +pub struct LoggingSubscriberRef(Arc); + +impl Subscriber for LoggingSubscriberRef { + fn enabled(&self, metadata: &tracing::Metadata<'_>) -> bool { + metadata.is_event() || self.0.state.read().unwrap().inner.is_some() + } + + fn new_span(&self, span: &span::Attributes<'_>) -> span::Id { + let state = self.0.state.read().unwrap(); + if let Some(inner) = &state.inner { + inner.new_span(span) + } else { + span::Id::from_u64(0) + } + } + + fn record(&self, span: &span::Id, values: &span::Record<'_>) { + let state = self.0.state.read().unwrap(); + if let Some(inner) = &state.inner { + inner.record(span, values); + } + } + + fn record_follows_from(&self, span: &span::Id, follows: &span::Id) { + let state = self.0.state.read().unwrap(); + if let Some(inner) = &state.inner { + inner.record_follows_from(span, follows); + } + } + + fn event(&self, event: &tracing::Event<'_>) { + let state = self.0.state.read().unwrap(); + + if event.metadata().level() <= &state.level { + let log_level = match *event.metadata().level() { + Level::TRACE => log::Level::Trace, + Level::DEBUG => log::Level::Debug, + Level::INFO => log::Level::Info, + Level::WARN => log::Level::Warn, + Level::ERROR => log::Level::Error, + }; + let mut fields = EventToStr::default(); + event.record(&mut fields); + log::log!(target: "lance::events", log_level, "target=\"{}\" {}", event.metadata().target(), fields.str); + } + + if let Some(inner) = &state.inner { + inner.event(event); + } + } + + fn enter(&self, span: &span::Id) { + let state = self.0.state.read().unwrap(); + if let Some(inner) = &state.inner { + inner.enter(span); + } + } + + fn exit(&self, span: &span::Id) { + let state = self.0.state.read().unwrap(); + if let Some(inner) = &state.inner { + inner.exit(span); + } + } +} + #[pyclass] pub struct TraceGuard { - guard: Option, + guard: Arc>>, } #[pymethods] impl TraceGuard { - pub fn finish_tracing(&mut self) { - self.guard.take(); + pub fn finish_tracing(&self) { + // We're exiting anyways, so discard the result + let _ = self.guard.lock().map(|mut g| g.take()); } } @@ -55,6 +179,7 @@ fn get_filter(level: Option<&str>) -> PyResult { } #[pyfunction] +#[pyo3(signature=(path=None, level=None))] pub fn trace_to_chrome(path: Option<&str>, level: Option<&str>) -> PyResult { let mut builder = ChromeLayerBuilder::new() .trace_style(TraceStyle::Async) @@ -71,9 +196,24 @@ pub fn trace_to_chrome(path: Option<&str>, level: Option<&str>) -> PyResult Level::TRACE, + log::Level::Debug => Level::DEBUG, + log::Level::Info => Level::INFO, + log::Level::Warn => Level::WARN, + log::Level::Error => Level::ERROR, + }; + + let mut state = SUBSCRIBER.0.state.write().unwrap(); + state.level = tracing_level; } diff --git a/python/src/transaction.rs b/python/src/transaction.rs new file mode 100644 index 00000000000..b09087bd066 --- /dev/null +++ b/python/src/transaction.rs @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use arrow::pyarrow::PyArrowType; +use arrow_schema::Schema as ArrowSchema; +use lance::dataset::transaction::{ + DataReplacementGroup, Operation, RewriteGroup, RewrittenIndex, Transaction, +}; +use lance::datatypes::Schema; +use lance_table::format::{DataFile, Fragment, Index}; +use pyo3::exceptions::PyValueError; +use pyo3::types::PySet; +use pyo3::{intern, prelude::*}; +use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python}; +use uuid::Uuid; + +use crate::schema::LanceSchema; +use crate::utils::{class_name, export_vec, extract_vec, PyLance}; + +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + let fragment_id = ob.getattr("fragment_id")?.extract::()?; + let new_file = &ob.getattr("new_file")?.extract::>()?; + + Ok(Self(DataReplacementGroup(fragment_id, new_file.0.clone()))) + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&DataReplacementGroup> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let namespace = py + .import(intern!(py, "lance")) + .and_then(|module| module.getattr(intern!(py, "LanceOperation"))) + .expect("Failed to import LanceOperation namespace"); + + let fragment_id = self.0 .0; + let new_file = PyLance(&self.0 .1).into_pyobject(py)?; + + let cls = namespace + .getattr("DataReplacementGroup") + .expect("Failed to get DataReplacementGroup class"); + cls.call1((fragment_id, new_file)) + } +} + +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + match class_name(ob)?.as_str() { + "Overwrite" => { + let schema = extract_schema(&ob.getattr("new_schema")?)?; + + let fragments = extract_vec(&ob.getattr("fragments")?)?; + + let op = Operation::Overwrite { + schema, + fragments, + config_upsert_values: None, + }; + Ok(Self(op)) + } + "Append" => { + let fragments = extract_vec(&ob.getattr("fragments")?)?; + let op = Operation::Append { fragments }; + Ok(Self(op)) + } + "Delete" => { + let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?; + let deleted_fragment_ids = ob.getattr("deleted_fragment_ids")?.extract()?; + let predicate = ob.getattr("predicate")?.extract()?; + + let op = Operation::Delete { + updated_fragments, + deleted_fragment_ids, + predicate, + }; + Ok(Self(op)) + } + "Update" => { + let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?; + + let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?; + + let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?; + + let op = Operation::Update { + removed_fragment_ids, + updated_fragments, + new_fragments, + }; + Ok(Self(op)) + } + "Merge" => { + let schema = extract_schema(&ob.getattr("schema")?)?; + + let fragments = ob + .getattr("fragments")? + .extract::>>()?; + let fragments = fragments.into_iter().map(|f| f.0).collect(); + + let op = Operation::Merge { schema, fragments }; + Ok(Self(op)) + } + "Restore" => { + let version = ob.getattr("version")?.extract()?; + let op = Operation::Restore { version }; + Ok(Self(op)) + } + "Rewrite" => { + let groups = extract_vec(&ob.getattr("groups")?)?; + let rewritten_indices = extract_vec(&ob.getattr("rewritten_indices")?)?; + let op = Operation::Rewrite { + groups, + rewritten_indices, + }; + Ok(Self(op)) + } + "CreateIndex" => { + let uuid = ob.getattr("uuid")?.to_string(); + let name = ob.getattr("name")?.extract()?; + let fields = ob.getattr("fields")?.extract()?; + let dataset_version = ob.getattr("dataset_version")?.extract()?; + + let fragment_ids = ob.getattr("fragment_ids")?; + let fragment_ids_ref: &Bound<'_, PySet> = fragment_ids.downcast()?; + let fragment_ids = fragment_ids_ref + .into_iter() + .map(|id| id.extract()) + .collect::>>()?; + let fragment_bitmap = Some(fragment_ids.into_iter().collect()); + + let new_indices = vec![Index { + uuid: Uuid::parse_str(&uuid) + .map_err(|e| PyValueError::new_err(e.to_string()))?, + name, + fields, + dataset_version, + fragment_bitmap, + // TODO: we should use lance::dataset::Dataset::commit_existing_index once + // we have a way to determine index details from an existing index. + index_details: None, + }]; + + let op = Operation::CreateIndex { + removed_indices: Vec::new(), + new_indices, + }; + Ok(Self(op)) + } + "DataReplacement" => { + let replacements = extract_vec(&ob.getattr("replacements")?)?; + + let op = Operation::DataReplacement { replacements }; + + Ok(Self(op)) + } + "Project" => { + let schema = extract_schema(&ob.getattr("schema")?)?; + + let op = Operation::Project { schema }; + Ok(Self(op)) + } + unsupported => Err(PyValueError::new_err(format!( + "Unsupported operation: {unsupported}", + ))), + } + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&Operation> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let namespace = py + .import(intern!(py, "lance")) + .and_then(|module| module.getattr(intern!(py, "LanceOperation"))) + .expect("Failed to import LanceOperation namespace"); + + match self.0 { + Operation::Append { ref fragments } => { + let fragments = export_vec(py, fragments.as_slice())?; + let cls = namespace + .getattr("Append") + .expect("Failed to get Append class"); + cls.call1((fragments,)) + } + Operation::Overwrite { + ref fragments, + ref schema, + .. + } => { + let fragments_py = export_vec(py, fragments.as_slice())?; + + let schema_py = LanceSchema(schema.clone()); + + let cls = namespace + .getattr("Overwrite") + .expect("Failed to get Overwrite class"); + + cls.call1((schema_py, fragments_py)) + } + Operation::Update { + removed_fragment_ids, + updated_fragments, + new_fragments, + } => { + let removed_fragment_ids = removed_fragment_ids.into_pyobject(py)?; + let updated_fragments = export_vec(py, updated_fragments.as_slice())?; + let new_fragments = export_vec(py, new_fragments.as_slice())?; + let cls = namespace + .getattr("Update") + .expect("Failed to get Update class"); + cls.call1((removed_fragment_ids, updated_fragments, new_fragments)) + } + Operation::DataReplacement { replacements } => { + let replacements = export_vec(py, replacements.as_slice())?; + let cls = namespace + .getattr("DataReplacement") + .expect("Failed to get DataReplacement class"); + cls.call1((replacements,)) + } + _ => todo!(), + } + } +} + +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &pyo3::Bound<'_, PyAny>) -> PyResult { + let read_version = ob.getattr("read_version")?.extract()?; + let uuid = ob.getattr("uuid")?.extract()?; + let operation = ob.getattr("operation")?.extract::>()?.0; + let blobs_op = ob + .getattr("blobs_op")? + .extract::>>()? + .map(|op| op.0); + Ok(Self(Transaction { + read_version, + uuid, + operation, + blobs_op, + tag: None, + })) + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&Transaction> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let namespace = py + .import(intern!(py, "lance")) + .expect("Failed to import lance module"); + + let read_version = self.0.read_version; + let uuid = &self.0.uuid; + let operation = PyLance(&self.0.operation).into_pyobject(py)?; + let blobs_op = self + .0 + .blobs_op + .as_ref() + .map(|op| PyLance(op).into_pyobject(py)) + .transpose()?; + + let cls = namespace + .getattr("Transaction") + .expect("Failed to get Transaction class"); + // Unwrap due to infallible + Ok(cls + .call1((read_version, operation, uuid, blobs_op))? + .into_pyobject(py) + .unwrap()) + } +} + +impl<'py> IntoPyObject<'py> for PyLance { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + PyLance(&self.0).into_pyobject(py) + } +} + +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + Ok(Self(RewriteGroup { + old_fragments: extract_vec(&ob.getattr("old_fragments")?)?, + new_fragments: extract_vec(&ob.getattr("new_fragments")?)?, + })) + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&RewriteGroup> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance")) + .and_then(|module| module.getattr(intern!(py, "LanceTransaction"))) + .and_then(|cls| cls.getattr(intern!(py, "RewriteGroup"))) + .expect("Failed to get RewriteGroup class"); + + let old_fragments = export_vec(py, self.0.old_fragments.as_slice())?; + let new_fragments = export_vec(py, self.0.new_fragments.as_slice())?; + + cls.call1((old_fragments, new_fragments)) + } +} + +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + let old_id: String = ob.getattr("old_id")?.extract()?; + let new_id: String = ob.getattr("new_id")?.extract()?; + let old_id = Uuid::parse_str(&old_id) + .map_err(|e| PyValueError::new_err(format!("Failed to parse UUID: {}", e)))?; + let new_id = Uuid::parse_str(&new_id) + .map_err(|e| PyValueError::new_err(format!("Failed to parse UUID: {}", e)))?; + Ok(Self(RewrittenIndex { old_id, new_id })) + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&RewrittenIndex> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let cls = py + .import(intern!(py, "lance")) + .and_then(|module| module.getattr(intern!(py, "LanceTransaction"))) + .and_then(|cls| cls.getattr(intern!(py, "RewrittenIndex"))) + .expect("Failed to get RewrittenIndex class"); + + let old_id = self.0.old_id.to_string(); + let new_id = self.0.new_id.to_string(); + cls.call1((old_id, new_id)) + } +} + +fn extract_schema(schema: &Bound<'_, PyAny>) -> PyResult { + match schema.downcast::() { + Ok(schema) => Ok(schema.borrow().0.clone()), + Err(_) => { + let arrow_schema = schema.extract::>()?.0; + convert_schema(&arrow_schema) + } + } +} + +fn convert_schema(arrow_schema: &ArrowSchema) -> PyResult { + // Note: the field ids here are wrong. + Schema::try_from(arrow_schema).map_err(|e| { + PyValueError::new_err(format!( + "Failed to convert Arrow schema to Lance schema: {}", + e + )) + }) +} diff --git a/python/src/utils.rs b/python/src/utils.rs index 9b8420e781b..f52d2844b64 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use arrow::compute::concat; +use arrow::datatypes::Float32Type; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use arrow_array::{cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array}; use arrow_data::ArrayData; @@ -26,17 +27,19 @@ use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; use lance_index::vector::hnsw::{builder::HnswBuildParams, HNSW}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_linalg::kmeans::compute_partitions; +use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat}; use lance_linalg::{ distance::DistanceType, kmeans::{KMeans as LanceKMeans, KMeansParams}, }; use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; +use pyo3::intern; use pyo3::{ exceptions::{PyIOError, PyRuntimeError, PyValueError}, prelude::*, types::PyIterator, + IntoPyObjectExt, }; use crate::RT; @@ -132,14 +135,17 @@ impl KMeans { if !matches!(fixed_size_arr.value_type(), DataType::Float32) { return Err(PyValueError::new_err("Must be a FixedSizeList of Float32")); }; - let values: Arc = fixed_size_arr.values().as_primitive().clone().into(); - let centroids: &Float32Array = kmeans.centroids.as_primitive(); - let cluster_ids = UInt32Array::from(compute_partitions( - centroids.values(), - values.values(), - kmeans.dimension, - kmeans.distance_type, - )); + let values = fixed_size_arr.values().as_primitive(); + let centroids = kmeans.centroids.as_primitive(); + let cluster_ids = UInt32Array::from( + compute_partitions::>( + centroids, + values, + kmeans.dimension, + kmeans.distance_type, + ) + .0, + ); cluster_ids.into_data().to_pyarrow(py) } @@ -242,3 +248,38 @@ impl Hnsw { self.vectors.to_data().to_pyarrow(py) } } + +/// A newtype wrapper for a Lance type. +/// +/// This is used for types that have a corresponding dataclass in Python. +pub struct PyLance(pub T); + +/// Extract a Vec of PyLance types from a Python object. +pub fn extract_vec<'a, T>(ob: &Bound<'a, PyAny>) -> PyResult> +where + PyLance: FromPyObject<'a>, +{ + ob.extract::>>() + .map(|v| v.into_iter().map(|t| t.0).collect()) +} + +/// Export a Vec of Lance types to a Python object. +pub fn export_vec<'a, T>(py: Python<'a>, vec: &'a [T]) -> PyResult> +where + PyLance<&'a T>: IntoPyObject<'a>, +{ + vec.iter() + .map(|t| PyLance(t).into_py_any(py)) + .collect::, _>>() +} + +pub fn class_name(ob: &Bound<'_, PyAny>) -> PyResult { + let full_name: String = ob + .getattr(intern!(ob.py(), "__class__"))? + .getattr(intern!(ob.py(), "__name__"))? + .extract()?; + match full_name.rsplit_once('.') { + Some((_, name)) => Ok(name.to_string()), + None => Ok(full_name), + } +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000000..71747395157 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,4 @@ +# We keep this pinned to keep clippy and rustfmt in sync between local and CI. +# Feel free to upgrade to bring in new lints. +[toolchain] +channel = "1.86.0" diff --git a/rust/lance-arrow/Cargo.toml b/rust/lance-arrow/Cargo.toml index 64dea8d8db2..d6b870965b1 100644 --- a/rust/lance-arrow/Cargo.toml +++ b/rust/lance-arrow/Cargo.toml @@ -20,6 +20,7 @@ arrow-data = { workspace = true } arrow-cast = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } +bytes = { workspace = true } half = { workspace = true } num-traits = { workspace = true } rand.workspace = true diff --git a/rust/lance-arrow/src/bfloat16.rs b/rust/lance-arrow/src/bfloat16.rs index 467da00a5aa..06079d9baaf 100644 --- a/rust/lance-arrow/src/bfloat16.rs +++ b/rust/lance-arrow/src/bfloat16.rs @@ -90,7 +90,7 @@ impl BFloat16Array { } } -impl<'a> ArrayAccessor for &'a BFloat16Array { +impl ArrayAccessor for &BFloat16Array { type Item = bf16; fn value(&self, index: usize) -> Self::Item { diff --git a/rust/lance-arrow/src/deepcopy.rs b/rust/lance-arrow/src/deepcopy.rs index 7a04fc1c9f0..f3c0c13fd01 100644 --- a/rust/lance-arrow/src/deepcopy.rs +++ b/rust/lance-arrow/src/deepcopy.rs @@ -4,22 +4,28 @@ use std::sync::Arc; use arrow_array::{make_array, Array, RecordBatch}; -use arrow_buffer::{Buffer, NullBuffer}; -use arrow_data::ArrayData; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; pub fn deep_copy_buffer(buffer: &Buffer) -> Buffer { - Buffer::from(Vec::from(buffer.as_slice())) + Buffer::from(buffer.as_slice()) } -fn deep_copy_nulls(nulls: &NullBuffer) -> Buffer { - deep_copy_buffer(nulls.inner().inner()) +fn deep_copy_nulls(nulls: Option<&NullBuffer>) -> Option { + let nulls = nulls?; + let bit_buffer = deep_copy_buffer(nulls.inner().inner()); + Some(unsafe { + NullBuffer::new_unchecked( + BooleanBuffer::new(bit_buffer, nulls.offset(), nulls.len()), + nulls.null_count(), + ) + }) } pub fn deep_copy_array_data(data: &ArrayData) -> ArrayData { let data_type = data.data_type().clone(); let len = data.len(); - let null_count = data.null_count(); - let null_bit_buffer = data.nulls().map(deep_copy_nulls); + let nulls = deep_copy_nulls(data.nulls()); let offset = data.offset(); let buffers = data .buffers() @@ -32,15 +38,13 @@ pub fn deep_copy_array_data(data: &ArrayData) -> ArrayData { .map(deep_copy_array_data) .collect::>(); unsafe { - ArrayData::new_unchecked( - data_type, - len, - Some(null_count), - null_bit_buffer, - offset, - buffers, - child_data, - ) + ArrayDataBuilder::new(data_type) + .len(len) + .nulls(nulls) + .offset(offset) + .buffers(buffers) + .child_data(child_data) + .build_unchecked() } } @@ -58,3 +62,25 @@ pub fn deep_copy_batch(batch: &RecordBatch) -> crate::Result { .collect::>(); RecordBatch::try_new(batch.schema(), arrays) } + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use arrow_array::{Array, Int32Array}; + + #[test] + fn test_deep_copy_sliced_array_with_nulls() { + let array = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + None, + Some(5), + ])); + let sliced_array = array.slice(1, 3); + let copied_array = super::deep_copy_array(&sliced_array); + assert_eq!(sliced_array.len(), copied_array.len()); + assert_eq!(sliced_array.nulls(), copied_array.nulls()); + } +} diff --git a/rust/lance-arrow/src/floats.rs b/rust/lance-arrow/src/floats.rs index 8f289804eed..498c1e46f26 100644 --- a/rust/lance-arrow/src/floats.rs +++ b/rust/lance-arrow/src/floats.rs @@ -5,6 +5,7 @@ use std::fmt::{Debug, Display}; use std::iter::Sum; +use std::sync::Arc; use std::{ fmt::Formatter, ops::{AddAssign, DivAssign}, @@ -202,16 +203,16 @@ impl FloatArray for Float64Array { } /// Convert a float32 array to another float array. -pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result> { +pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result> { match float_type { - FloatType::BFloat16 => Ok(Box::new(BFloat16Array::from_iter_values( + FloatType::BFloat16 => Ok(Arc::new(BFloat16Array::from_iter_values( input.values().iter().map(|v| bf16::from_f32(*v)), ))), - FloatType::Float16 => Ok(Box::new(Float16Array::from_iter_values( + FloatType::Float16 => Ok(Arc::new(Float16Array::from_iter_values( input.values().iter().map(|v| f16::from_f32(*v)), ))), - FloatType::Float32 => Ok(Box::new(input.clone())), - FloatType::Float64 => Ok(Box::new(Float64Array::from_iter_values( + FloatType::Float32 => Ok(Arc::new(input.clone())), + FloatType::Float64 => Ok(Arc::new(Float64Array::from_iter_values( input.values().iter().map(|v| *v as f64), ))), } diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index eafd9586593..6c055f9126a 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -5,14 +5,18 @@ //! //! To improve Arrow-RS ergonomic -use std::collections::HashMap; use std::sync::Arc; +use std::{collections::HashMap, ptr::NonNull}; use arrow_array::{ cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array, UInt8Array, }; +use arrow_array::{ + new_null_array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, +}; +use arrow_buffer::MutableBuffer; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema}; use arrow_select::{interleave::interleave, take::take}; @@ -25,6 +29,8 @@ pub mod bfloat16; pub mod floats; pub use floats::*; pub mod cast; +pub mod list; +pub mod memory; type Result = std::result::Result; @@ -233,6 +239,10 @@ pub trait FixedSizeListArrayExt { /// assert_eq!(sampled.values().len(), 160); /// ``` fn sample(&self, n: usize) -> Result; + + /// Ensure the [FixedSizeListArray] of Float16, Float32, Float64, + /// Int8, Int16, Int32, Int64, UInt8, UInt32 type to its closest floating point type. + fn convert_to_floating_point(&self) -> Result; } impl FixedSizeListArrayExt for FixedSizeListArray { @@ -251,6 +261,136 @@ impl FixedSizeListArrayExt for FixedSizeListArray { let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n); take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone()) } + + fn convert_to_floating_point(&self) -> Result { + match self.data_type() { + DataType::FixedSizeList(field, size) => match field.data_type() { + DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(self.clone()), + DataType::Int8 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float32, + field.is_nullable(), + )), + *size, + Arc::new(Float32Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f32)), + )), + self.nulls().cloned(), + )), + DataType::Int16 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float32, + field.is_nullable(), + )), + *size, + Arc::new(Float32Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f32)), + )), + self.nulls().cloned(), + )), + DataType::Int32 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float32, + field.is_nullable(), + )), + *size, + Arc::new(Float32Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f32)), + )), + self.nulls().cloned(), + )), + DataType::Int64 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float64, + field.is_nullable(), + )), + *size, + Arc::new(Float64Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f64)), + )), + self.nulls().cloned(), + )), + DataType::UInt8 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float64, + field.is_nullable(), + )), + *size, + Arc::new(Float64Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f64)), + )), + self.nulls().cloned(), + )), + DataType::UInt32 => Ok(Self::new( + Arc::new(arrow_schema::Field::new( + field.name(), + DataType::Float64, + field.is_nullable(), + )), + *size, + Arc::new(Float64Array::from_iter_values( + self.values() + .as_any() + .downcast_ref::() + .ok_or(ArrowError::ParseError( + "Fail to cast primitive array to Int8Type".to_string(), + ))? + .into_iter() + .filter_map(|x| x.map(|y| y as f64)), + )), + self.nulls().cloned(), + )), + data_type => Err(ArrowError::ParseError(format!( + "Expect either floating type or integer got {:?}", + data_type + ))), + }, + data_type => Err(ArrowError::ParseError(format!( + "Expect either FixedSizeList got {:?}", + data_type + ))), + } + } } /// Force downcast of an [`Array`], such as an [`ArrayRef`], to @@ -347,6 +487,17 @@ pub trait RecordBatchExt { /// Merge with another [`RecordBatch`] and returns a new one. /// + /// Fields are merged based on name. First we iterate the left columns. If a matching + /// name is found in the right then we merge the two columns. If there is no match then + /// we add the left column to the output. + /// + /// To merge two columns we consider the type. If both arrays are struct arrays we recurse. + /// Otherwise we use the left array. + /// + /// Afterwards we add all non-matching right columns to the output. + /// + /// Note: This method likely does not handle nested fields correctly and you may want to consider + /// using [`merge_with_schema`] instead. /// ``` /// use std::sync::Arc; /// use arrow_array::*; @@ -380,6 +531,17 @@ pub trait RecordBatchExt { /// TODO: add merge nested fields support. fn merge(&self, other: &RecordBatch) -> Result; + /// Create a batch by merging columns between two batches with a given schema. + /// + /// A reference schema is used to determine the proper ordering of nested fields. + /// + /// For each field in the reference schema we look for corresponding fields in + /// the left and right batches. If a field is found in both batches we recursively merge + /// it. + /// + /// If a field is only in the left or right batch we take it as it is. + fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result; + /// Drop one column specified with the name and return the new [`RecordBatch`]. /// /// If the named column does not exist, it returns a copy of this [`RecordBatch`]. @@ -388,6 +550,14 @@ pub trait RecordBatchExt { /// Replace a column (specified by name) and return the new [`RecordBatch`]. fn replace_column_by_name(&self, name: &str, column: Arc) -> Result; + /// Replace a column schema (specified by name) and return the new [`RecordBatch`]. + fn replace_column_schema_by_name( + &self, + name: &str, + new_data_type: DataType, + column: Arc, + ) -> Result; + /// Get (potentially nested) column by qualified name. fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>; @@ -448,6 +618,23 @@ impl RecordBatchExt for RecordBatch { self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array)) } + fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result { + if self.num_rows() != other.num_rows() { + return Err(ArrowError::InvalidArgumentError(format!( + "Attempt to merge two RecordBatch with different sizes: {} != {}", + self.num_rows(), + other.num_rows() + ))); + } + let left_struct_array: StructArray = self.clone().into(); + let right_struct_array: StructArray = other.clone().into(); + self.try_new_from_struct_array(merge_with_schema( + &left_struct_array, + &right_struct_array, + schema.fields(), + )) + } + fn drop_column(&self, name: &str) -> Result { let mut fields = vec![]; let mut columns = vec![]; @@ -478,6 +665,37 @@ impl RecordBatchExt for RecordBatch { Self::try_new(self.schema(), columns) } + fn replace_column_schema_by_name( + &self, + name: &str, + new_data_type: DataType, + column: Arc, + ) -> Result { + let fields = self + .schema() + .fields() + .iter() + .map(|x| { + if x.name() != name { + x.clone() + } else { + let new_field = Field::new(name, new_data_type.clone(), x.is_nullable()); + Arc::new(new_field) + } + }) + .collect::>(); + let schema = Schema::new_with_metadata(fields, self.schema().metadata.clone()); + let mut columns = self.columns().to_vec(); + let field_i = self + .schema() + .fields() + .iter() + .position(|f| f.name() == name) + .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?; + columns[field_i] = column; + Self::try_new(Arc::new(schema), columns) + } + fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> { let split = name.split('.').collect::>(); if split.is_empty() { @@ -540,7 +758,150 @@ fn project(struct_array: &StructArray, fields: &Fields) -> Result { StructArray::try_new(fields.clone(), columns, None) } -/// Merge the fields and columns of two RecordBatch's recursively +fn lists_have_same_offsets_helper(left: &dyn Array, right: &dyn Array) -> bool { + let left_list: &GenericListArray = left.as_list(); + let right_list: &GenericListArray = right.as_list(); + left_list.offsets().inner() == right_list.offsets().inner() +} + +fn merge_list_structs_helper( + left: &dyn Array, + right: &dyn Array, + items_field_name: impl Into, + items_nullable: bool, +) -> Arc { + let left_list: &GenericListArray = left.as_list(); + let right_list: &GenericListArray = right.as_list(); + let left_struct = left_list.values(); + let right_struct = right_list.values(); + let left_struct_arr = left_struct.as_struct(); + let right_struct_arr = right_struct.as_struct(); + let merged_items = Arc::new(merge(left_struct_arr, right_struct_arr)); + let items_field = Arc::new(Field::new( + items_field_name, + merged_items.data_type().clone(), + items_nullable, + )); + Arc::new(GenericListArray::::new( + items_field, + left_list.offsets().clone(), + merged_items, + left_list.nulls().cloned(), + )) +} + +fn merge_list_struct_null_helper( + left: &dyn Array, + right: &dyn Array, + not_null: &dyn Array, + items_field_name: impl Into, +) -> Arc { + let left_list: &GenericListArray = left.as_list::(); + let not_null_list = not_null.as_list::(); + let right_list = right.as_list::(); + + let left_struct = left_list.values().as_struct(); + let not_null_struct: &StructArray = not_null_list.values().as_struct(); + let right_struct = right_list.values().as_struct(); + + let values_len = not_null_list.values().len(); + let mut merged_fields = + Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns()); + let mut merged_columns = + Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns()); + + for (_, field) in left_struct.columns().iter().zip(left_struct.fields()) { + merged_fields.push(field.clone()); + if let Some(val) = not_null_struct.column_by_name(field.name()) { + merged_columns.push(val.clone()); + } else { + merged_columns.push(new_null_array(field.data_type(), values_len)) + } + } + for (_, field) in right_struct + .columns() + .iter() + .zip(right_struct.fields()) + .filter(|(_, field)| left_struct.column_by_name(field.name()).is_none()) + { + merged_fields.push(field.clone()); + if let Some(val) = not_null_struct.column_by_name(field.name()) { + merged_columns.push(val.clone()); + } else { + merged_columns.push(new_null_array(field.data_type(), values_len)); + } + } + + let merged_struct = Arc::new(StructArray::new( + Fields::from(merged_fields), + merged_columns, + not_null_struct.nulls().cloned(), + )); + let items_field = Arc::new(Field::new( + items_field_name, + merged_struct.data_type().clone(), + true, + )); + Arc::new(GenericListArray::::new( + items_field, + not_null_list.offsets().clone(), + merged_struct, + not_null_list.nulls().cloned(), + )) +} + +fn merge_list_struct_null( + left: &dyn Array, + right: &dyn Array, + not_null: &dyn Array, +) -> Arc { + match left.data_type() { + DataType::List(left_field) => { + merge_list_struct_null_helper::(left, right, not_null, left_field.name()) + } + DataType::LargeList(left_field) => { + merge_list_struct_null_helper::(left, right, not_null, left_field.name()) + } + _ => unreachable!(), + } +} + +fn merge_list_struct(left: &dyn Array, right: &dyn Array) -> Arc { + // Merging fields into a list> is tricky and can only succeed + // in two ways. First, if both lists have the same offsets. Second, if + // one of the lists is all-null + if left.null_count() == left.len() { + return merge_list_struct_null(left, right, right); + } else if right.null_count() == right.len() { + return merge_list_struct_null(left, right, left); + } + match (left.data_type(), right.data_type()) { + (DataType::List(left_field), DataType::List(_)) => { + if !lists_have_same_offsets_helper::(left, right) { + panic!("Attempt to merge list struct arrays which do not have same offsets"); + } + merge_list_structs_helper::( + left, + right, + left_field.name(), + left_field.is_nullable(), + ) + } + (DataType::LargeList(left_field), DataType::LargeList(_)) => { + if !lists_have_same_offsets_helper::(left, right) { + panic!("Attempt to merge list struct arrays which do not have same offsets"); + } + merge_list_structs_helper::( + left, + right, + left_field.name(), + left_field.is_nullable(), + ) + } + _ => unreachable!(), + } +} + fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray { let mut fields: Vec = vec![]; let mut columns: Vec = vec![]; @@ -574,6 +935,27 @@ fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> S )); columns.push(Arc::new(merged_sub_array) as ArrayRef); } + (DataType::List(left_list), DataType::List(right_list)) + if left_list.data_type().is_struct() + && right_list.data_type().is_struct() => + { + // If there is nothing to merge just use the left field + if left_list.data_type() == right_list.data_type() { + fields.push(left_field.as_ref().clone()); + columns.push(left_column.clone()); + } + // If we have two List and they have different sets of fields then + // we can merge them if the offsets arrays are the same. Otherwise, we + // have to consider it an error. + let merged_sub_array = merge_list_struct(&left_column, &right_column); + + fields.push(Field::new( + left_field.name(), + merged_sub_array.data_type().clone(), + left_field.is_nullable(), + )); + columns.push(merged_sub_array); + } // otherwise, just use the field on the left hand side _ => { // TODO handle list-of-struct and other types @@ -614,6 +996,77 @@ fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> S StructArray::from(zipped) } +fn merge_with_schema( + left_struct_array: &StructArray, + right_struct_array: &StructArray, + fields: &Fields, +) -> StructArray { + // Helper function that returns true if both types are struct or both are non-struct + fn same_type_kind(left: &DataType, right: &DataType) -> bool { + match (left, right) { + (DataType::Struct(_), DataType::Struct(_)) => true, + (DataType::Struct(_), _) => false, + (_, DataType::Struct(_)) => false, + _ => true, + } + } + + let mut output_fields: Vec = Vec::with_capacity(fields.len()); + let mut columns: Vec = Vec::with_capacity(fields.len()); + + let left_fields = left_struct_array.fields(); + let left_columns = left_struct_array.columns(); + let right_fields = right_struct_array.fields(); + let right_columns = right_struct_array.columns(); + + for field in fields { + let left_match_idx = left_fields.iter().position(|f| { + f.name() == field.name() && same_type_kind(f.data_type(), field.data_type()) + }); + let right_match_idx = right_fields.iter().position(|f| { + f.name() == field.name() && same_type_kind(f.data_type(), field.data_type()) + }); + + match (left_match_idx, right_match_idx) { + (None, Some(right_idx)) => { + output_fields.push(right_fields[right_idx].as_ref().clone()); + columns.push(right_columns[right_idx].clone()); + } + (Some(left_idx), None) => { + output_fields.push(left_fields[left_idx].as_ref().clone()); + columns.push(left_columns[left_idx].clone()); + } + (Some(left_idx), Some(right_idx)) => { + if let DataType::Struct(child_fields) = field.data_type() { + let left_sub_array = left_columns[left_idx].as_struct(); + let right_sub_array = right_columns[right_idx].as_struct(); + let merged_sub_array = + merge_with_schema(left_sub_array, right_sub_array, child_fields); + output_fields.push(Field::new( + field.name(), + merged_sub_array.data_type().clone(), + field.is_nullable(), + )); + columns.push(Arc::new(merged_sub_array) as ArrayRef); + } else { + output_fields.push(left_fields[left_idx].as_ref().clone()); + columns.push(left_columns[left_idx].clone()); + } + } + (None, None) => { + // The field will not be included in the output + } + } + } + + let zipped: Vec<(FieldRef, ArrayRef)> = output_fields + .into_iter() + .map(Arc::new) + .zip(columns) + .collect::>(); + StructArray::from(zipped) +} + fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> { if components.is_empty() { return Some(array); @@ -654,10 +1107,73 @@ pub fn interleave_batches( RecordBatch::try_new(schema, columns) } +pub trait BufferExt { + /// Create an `arrow_buffer::Buffer`` from a `bytes::Bytes` object + /// + /// The alignment must be specified (as `bytes_per_value`) since we want to make + /// sure we can safely reinterpret the buffer. + /// + /// If the buffer is properly aligned this will be zero-copy. If not, a copy + /// will be made and an owned buffer returned. + /// + /// If `bytes_per_value` is not a power of two, then we assume the buffer is + /// never going to be reinterpreted into another type and we can safely + /// ignore the alignment. + /// + /// Yes, the method name is odd. It's because there is already a `from_bytes` + /// which converts from `arrow_buffer::bytes::Bytes` (not `bytes::Bytes`) + fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self; + + /// Allocates a new properly aligned arrow buffer and copies `bytes` into it + /// + /// `size_bytes` can be larger than `bytes` and, if so, the trailing bytes will + /// be zeroed out. + /// + /// # Panics + /// + /// Panics if `size_bytes` is less than `bytes.len()` + fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self; +} + +fn is_pwr_two(n: u64) -> bool { + n & (n - 1) == 0 +} + +impl BufferExt for arrow_buffer::Buffer { + fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self { + if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0 + { + // The original buffer is not aligned, cannot zero-copy + let size_bytes = bytes.len(); + Self::copy_bytes_bytes(bytes, size_bytes) + } else { + // The original buffer is aligned, can zero-copy + // SAFETY: the alignment is correct we can make this conversion + unsafe { + Self::from_custom_allocation( + NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"), + bytes.len(), + Arc::new(bytes), + ) + } + } + } + + fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self { + assert!(size_bytes >= bytes.len()); + let mut buf = MutableBuffer::with_capacity(size_bytes); + let to_fill = size_bytes - bytes.len(); + buf.extend(bytes); + buf.extend(std::iter::repeat_n(0_u8, to_fill)); + Self::from(buf) + } +} + #[cfg(test)] mod tests { use super::*; - use arrow_array::{Int32Array, StringArray}; + use arrow_array::{new_empty_array, new_null_array, Int32Array, ListArray, StringArray}; + use arrow_buffer::OffsetBuffer; #[test] fn test_merge_recursive() { @@ -744,6 +1260,138 @@ mod tests { assert_eq!(result, merged_batch); } + #[test] + fn test_merge_with_schema() { + fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) { + let fields: Fields = names + .iter() + .zip(types) + .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false)) + .collect(); + let schema = Schema::new(vec![Field::new( + "struct", + DataType::Struct(fields.clone()), + false, + )]); + let children = types.iter().map(new_empty_array).collect::>(); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef], + ); + (schema, batch.unwrap()) + } + + let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]); + let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]); + let (output_schema, _) = test_batch( + &["b", "a", "c"], + &[DataType::Int64, DataType::Int32, DataType::Int32], + ); + + // If we use merge_with_schema the schema is respected + let merged = left_batch + .merge_with_schema(&right_batch, &output_schema) + .unwrap(); + assert_eq!(merged.schema().as_ref(), &output_schema); + + // If we use merge we get first-come first-serve based on the left batch + let (naive_schema, _) = test_batch( + &["a", "b", "c"], + &[DataType::Int32, DataType::Int64, DataType::Int32], + ); + let merged = left_batch.merge(&right_batch).unwrap(); + assert_eq!(merged.schema().as_ref(), &naive_schema); + } + + #[test] + fn test_merge_list_struct() { + let x_field = Arc::new(Field::new("x", DataType::Int32, true)); + let y_field = Arc::new(Field::new("y", DataType::Int32, true)); + let x_struct_field = Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![x_field.clone()])), + true, + )); + let y_struct_field = Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![y_field.clone()])), + true, + )); + let both_struct_field = Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![x_field.clone(), y_field.clone()])), + true, + )); + let left_schema = Schema::new(vec![Field::new( + "list_struct", + DataType::List(x_struct_field.clone()), + true, + )]); + let right_schema = Schema::new(vec![Field::new( + "list_struct", + DataType::List(y_struct_field.clone()), + true, + )]); + let both_schema = Schema::new(vec![Field::new( + "list_struct", + DataType::List(both_struct_field.clone()), + true, + )]); + + let x = Arc::new(Int32Array::from(vec![1])); + let y = Arc::new(Int32Array::from(vec![2])); + let x_struct = Arc::new(StructArray::new( + Fields::from(vec![x_field.clone()]), + vec![x.clone()], + None, + )); + let y_struct = Arc::new(StructArray::new( + Fields::from(vec![y_field.clone()]), + vec![y.clone()], + None, + )); + let both_struct = Arc::new(StructArray::new( + Fields::from(vec![x_field.clone(), y_field.clone()]), + vec![x.clone(), y], + None, + )); + let both_null_struct = Arc::new(StructArray::new( + Fields::from(vec![x_field, y_field]), + vec![x, Arc::new(new_null_array(&DataType::Int32, 1))], + None, + )); + let offsets = OffsetBuffer::from_lengths([1]); + let x_s_list = ListArray::new(x_struct_field, offsets.clone(), x_struct, None); + let y_s_list = ListArray::new(y_struct_field, offsets.clone(), y_struct, None); + let both_list = ListArray::new( + both_struct_field.clone(), + offsets.clone(), + both_struct, + None, + ); + let both_null_list = ListArray::new(both_struct_field, offsets, both_null_struct, None); + let x_batch = + RecordBatch::try_new(Arc::new(left_schema), vec![Arc::new(x_s_list)]).unwrap(); + let y_batch = RecordBatch::try_new( + Arc::new(right_schema.clone()), + vec![Arc::new(y_s_list.clone())], + ) + .unwrap(); + let merged = x_batch.merge(&y_batch).unwrap(); + let expected = + RecordBatch::try_new(Arc::new(both_schema.clone()), vec![Arc::new(both_list)]).unwrap(); + assert_eq!(merged, expected); + + let y_null_list = new_null_array(y_s_list.data_type(), 1); + let y_null_batch = + RecordBatch::try_new(Arc::new(right_schema), vec![Arc::new(y_null_list.clone())]) + .unwrap(); + let expected = + RecordBatch::try_new(Arc::new(both_schema), vec![Arc::new(both_null_list)]).unwrap(); + let merged = x_batch.merge(&y_null_batch).unwrap(); + assert_eq!(merged, expected); + } + #[test] fn test_take_record_batch() { let schema = Arc::new(Schema::new(vec![ diff --git a/rust/lance-arrow/src/list.rs b/rust/lance-arrow/src/list.rs new file mode 100644 index 00000000000..0c24fc579da --- /dev/null +++ b/rust/lance-arrow/src/list.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow_array::{Array, BooleanArray, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::{BooleanBufferBuilder, OffsetBuffer, ScalarBuffer}; +use arrow_schema::Field; + +pub trait ListArrayExt { + /// Filters out masked null items from the list array + /// + /// It is legal for a list array to have a null entry with a non-zero length. The + /// values inside the entry are "garbage" and should be ignored. This function + /// filters the values array to remove the garbage values. + /// + /// The output list will always have zero-length nulls. + fn filter_garbage_nulls(&self) -> Self; + /// Returns a copy of the list's values array that has been sliced to size + /// + /// It is legal for a list array's offsets to not start with zero. It's also legal + /// for a list array's offsets to not extend to the entire values array. This function + /// behaves similarly to `values()` except it slices the array so that it starts at + /// the first list offset and ends at the last list offset. + fn trimmed_values(&self) -> Arc; +} + +impl ListArrayExt for GenericListArray { + fn filter_garbage_nulls(&self) -> Self { + if self.is_empty() { + return self.clone(); + } + let Some(validity) = self.nulls().cloned() else { + return self.clone(); + }; + + let mut should_keep = BooleanBufferBuilder::new(self.values().len()); + + // Handle case where offsets do not start at 0 + let preamble_len = self.offsets().first().unwrap().to_usize().unwrap(); + should_keep.append_n(preamble_len, false); + + let mut new_offsets: Vec = Vec::with_capacity(self.len() + 1); + new_offsets.push(OffsetSize::zero()); + let mut cur_len = OffsetSize::zero(); + for (offset, is_valid) in self.offsets().windows(2).zip(validity.iter()) { + let len = offset[1] - offset[0]; + if is_valid { + cur_len += len; + should_keep.append_n(len.to_usize().unwrap(), true); + new_offsets.push(cur_len); + } else { + should_keep.append_n(len.to_usize().unwrap(), false); + new_offsets.push(cur_len); + } + } + + // Offsets may not reference entire values buffer + let trailer = self.values().len() - should_keep.len(); + should_keep.append_n(trailer, false); + + let should_keep = should_keep.finish(); + let should_keep = BooleanArray::new(should_keep, None); + let new_values = arrow_select::filter::filter(self.values(), &should_keep).unwrap(); + let new_offsets = ScalarBuffer::from(new_offsets); + let new_offsets = OffsetBuffer::new(new_offsets); + + Self::new( + Arc::new(Field::new( + "item", + self.value_type(), + self.values().is_nullable(), + )), + new_offsets, + new_values, + Some(validity), + ) + } + + fn trimmed_values(&self) -> Arc { + let first_value = self + .offsets() + .first() + .map(|v| v.to_usize().unwrap()) + .unwrap_or(0); + let last_value = self + .offsets() + .last() + .map(|v| v.to_usize().unwrap()) + .unwrap_or(0); + self.values().slice(first_value, last_value - first_value) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{ListArray, UInt64Array}; + use arrow_buffer::{BooleanBuffer, NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_schema::{DataType, Field}; + + use super::ListArrayExt; + + #[test] + fn test_filter_garbage_nulls() { + let items = UInt64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let offsets = ScalarBuffer::::from(vec![2, 5, 8, 9]); + let offsets = OffsetBuffer::new(offsets); + let list_validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let list_arr = ListArray::new( + Arc::new(Field::new("item", DataType::UInt64, true)), + offsets, + Arc::new(items), + Some(list_validity.clone()), + ); + + let filtered = list_arr.filter_garbage_nulls(); + + let expected_items = UInt64Array::from(vec![2, 3, 4, 8]); + let offsets = ScalarBuffer::::from(vec![0, 3, 3, 4]); + let expected = ListArray::new( + Arc::new(Field::new("item", DataType::UInt64, false)), + OffsetBuffer::new(offsets), + Arc::new(expected_items), + Some(list_validity), + ); + + assert_eq!(filtered, expected); + } + + #[test] + fn test_trim_values() { + let items = UInt64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let offsets = ScalarBuffer::::from(vec![2, 5, 6, 8, 9]); + let offsets = OffsetBuffer::new(offsets); + let list_arr = ListArray::new( + Arc::new(Field::new("item", DataType::UInt64, true)), + offsets, + Arc::new(items), + None, + ); + let list_arr = list_arr.slice(1, 2); + + let trimmed = list_arr.trimmed_values(); + + let expected_items = UInt64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let expected_items = expected_items.slice(5, 3); + + assert_eq!(trimmed.as_ref(), &expected_items); + } +} diff --git a/rust/lance-arrow/src/memory.rs b/rust/lance-arrow/src/memory.rs new file mode 100644 index 00000000000..6b8db9da769 --- /dev/null +++ b/rust/lance-arrow/src/memory.rs @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::HashSet; + +use arrow_array::{Array, RecordBatch}; +use arrow_data::ArrayData; + +/// Counts memory used by buffers of Arrow arrays and RecordBatches. +/// +/// This is meant to capture how much memory is being used by the Arrow data +/// structures as they are. It does not represent the memory used if the data +/// were to be serialized and then deserialized. In particular: +/// +/// * This does not double count memory used by buffers shared by multiple +/// arrays or batches. Round-tripped data may use more memory because of this. +/// * This counts the **total** size of the buffers, even if the array is a slice. +/// Round-tripped data may use less memory because of this. +#[derive(Default)] +pub struct MemoryAccumulator { + seen: HashSet, + total: usize, +} + +impl MemoryAccumulator { + pub fn record_array(&mut self, array: &dyn Array) { + let data = array.to_data(); + self.record_array_data(&data); + } + + fn record_array_data(&mut self, data: &ArrayData) { + for buffer in data.buffers() { + let ptr = buffer.as_ptr(); + if self.seen.insert(ptr as usize) { + self.total += buffer.capacity(); + } + } + + if let Some(nulls) = data.nulls() { + let null_buf = nulls.inner().inner(); + let ptr = null_buf.as_ptr(); + if self.seen.insert(ptr as usize) { + self.total += null_buf.capacity(); + } + } + + for child in data.child_data() { + self.record_array_data(child); + } + } + + pub fn record_batch(&mut self, batch: &RecordBatch) { + for array in batch.columns() { + self.record_array(array); + } + } + + pub fn total(&self) -> usize { + self.total + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field, Schema}; + + use super::*; + + #[test] + fn test_memory_accumulator() { + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let slice = batch.slice(1, 2); + + let mut acc = MemoryAccumulator::default(); + + // Should record whole buffer, not just slice + acc.record_batch(&slice); + assert_eq!(acc.total(), 3 * std::mem::size_of::()); + + // Should not double count + acc.record_batch(&slice); + assert_eq!(acc.total(), 3 * std::mem::size_of::()); + } +} diff --git a/rust/lance-arrow/src/schema.rs b/rust/lance-arrow/src/schema.rs index 73f1f969647..aa48ce9352d 100644 --- a/rust/lance-arrow/src/schema.rs +++ b/rust/lance-arrow/src/schema.rs @@ -32,6 +32,8 @@ pub trait FieldExt { /// /// This is intended for display purposes and not for serialization fn to_compact_string(&self, indent: Indentation) -> String; + + fn is_packed_struct(&self) -> bool; } impl FieldExt for Field { @@ -79,6 +81,15 @@ impl FieldExt for Field { } result } + + // Check if field has metadata `packed` set to true, this check is case insensitive. + fn is_packed_struct(&self) -> bool { + let field_metadata = self.metadata(); + field_metadata + .get("packed") + .map(|v| v.to_lowercase() == "true") + .unwrap_or(false) + } } /// Extends the functionality of [arrow_schema::Schema]. diff --git a/rust/lance-core/src/cache.rs b/rust/lance-core/src/cache.rs index 3cc8800f56f..cb9ac1536c4 100644 --- a/rust/lance-core/src/cache.rs +++ b/rust/lance-core/src/cache.rs @@ -6,7 +6,6 @@ use std::any::{Any, TypeId}; use std::sync::Arc; -use deepsize::{Context, DeepSizeOf}; use futures::Future; use moka::sync::Cache; use object_store::path::Path; @@ -14,6 +13,8 @@ use object_store::path::Path; use crate::utils::path::LancePathExt; use crate::Result; +pub use deepsize::{Context, DeepSizeOf}; + type ArcAny = Arc; #[derive(Clone)] @@ -121,6 +122,19 @@ impl FileMetadataCache { } } + pub fn approx_size(&self) -> usize { + if let Some(cache) = self.cache.as_ref() { + cache.entry_count() as usize + } else { + 0 + } + } + /// Fetch an item from the cache, using a str as the key + pub fn get_by_str(&self, path: &str) -> Option> { + self.get(&Path::parse(path).unwrap()) + } + + /// Fetch an item from the cache pub fn get(&self, path: &Path) -> Option> { let cache = self.cache.as_ref()?; let temp: Path; @@ -135,6 +149,7 @@ impl FileMetadataCache { .map(|metadata| metadata.record.clone().downcast::().unwrap()) } + /// Insert an item into the cache pub fn insert(&self, path: Path, metadata: Arc) { let Some(cache) = self.cache.as_ref() else { return; @@ -147,6 +162,15 @@ impl FileMetadataCache { cache.insert((path, TypeId::of::()), SizedRecord::new(metadata)); } + /// Insert an item into the cache, using a str as the key + pub fn insert_by_str( + &self, + key: &str, + metadata: Arc, + ) { + self.insert(Path::parse(key).unwrap(), metadata); + } + /// Get an item /// /// If it exists in the cache return that diff --git a/rust/lance-core/src/container.rs b/rust/lance-core/src/container.rs new file mode 100644 index 00000000000..f92893bf076 --- /dev/null +++ b/rust/lance-core/src/container.rs @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +pub mod list; diff --git a/rust/lance-core/src/container/list.rs b/rust/lance-core/src/container/list.rs new file mode 100644 index 00000000000..5b084502a2c --- /dev/null +++ b/rust/lance-core/src/container/list.rs @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::LinkedList; + +use deepsize::DeepSizeOf; + +/// A linked list that grows exponentially. It is used to store a large number of +/// elements in a memory-efficient way. The list grows by doubling the capacity of +/// the last element when it's full, the capacity can be limited by the `limit` +/// parameter. The default value is 0, which means no limit. +#[derive(Debug, Clone, Default)] +pub struct ExpLinkedList { + inner: LinkedList>, + len: usize, + // The maximum capacity of single node in the list. + // If the limit is 0, there is no limit. + // We use u16 to save memory because ExpLinkedList should not + // be used if the limit is that large. + limit: u16, +} + +impl ExpLinkedList { + /// Creates a new empty `ExpLinkedList`. + pub fn new() -> Self { + Self { + inner: LinkedList::new(), + len: 0, + limit: 0, + } + } + + pub fn with_capacity(capacity: usize) -> Self { + let mut inner = LinkedList::new(); + inner.push_back(Vec::with_capacity(capacity)); + Self { + inner, + len: 0, + limit: 0, + } + } + + /// Creates a new `ExpLinkedList` with a specified capacity limit. + /// The limit is the maximum capacity of a single node in the list. + /// If the limit is 0, there is no limit. + pub fn with_capacity_limit(limit: u16) -> Self { + Self { + inner: LinkedList::new(), + len: 0, + limit, + } + } + + /// Pushes a new element into the list. If the last element in the list + /// reaches its capacity, a new node is created with double capacity. + pub fn push(&mut self, v: T) { + match self.inner.back() { + Some(last) => { + if last.len() == last.capacity() { + let new_cap = if self.limit > 0 && last.capacity() * 2 >= self.limit as usize { + self.limit as usize + } else { + last.capacity() * 2 + }; + self.inner.push_back(Vec::with_capacity(new_cap)); + } + } + None => { + self.inner.push_back(Vec::with_capacity(1)); + } + } + self.do_push(v); + } + + fn do_push(&mut self, v: T) { + self.inner.back_mut().unwrap().push(v); + self.len += 1; + } + + /// Removes the last element from the list. + pub fn pop(&mut self) -> Option { + match self.inner.back_mut() { + Some(last) => { + if last.is_empty() { + self.inner.pop_back(); + self.pop() + } else { + self.len -= 1; + last.pop() + } + } + None => None, + } + } + + /// Clears the list, removing all elements. + /// This will free the memory used by the list. + pub fn clear(&mut self) { + self.inner.clear(); + self.len = 0; + } + + /// Returns the number of elements in the list. + pub fn len(&self) -> usize { + self.len + } + + /// Returns whether the list is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns the size of list, including the size of the elements and the + /// size of the list itself, and the unused space. + /// The element size is calculated using `std::mem::size_of::()`, + /// so it is not accurate for all types. + /// For example, for `String`, it will return the size of the pointer, + /// not the size of the string itself. For that you need to use `DeepSizeOf`. + pub fn size(&self) -> usize { + let unused_space = match self.inner.back() { + Some(last) => last.capacity() - last.len(), + None => 0, + }; + (self.len() + unused_space) * std::mem::size_of::() + + std::mem::size_of::() + + self.inner.len() * std::mem::size_of::>() + } + + /// Returns an iterator over the elements in the list. + pub fn iter(&self) -> ExpLinkedListIter<'_, T> { + ExpLinkedListIter::new(self) + } +} + +impl DeepSizeOf for ExpLinkedList { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.inner + .iter() + .map(|v| v.deep_size_of_children(context)) + .sum() + } +} + +impl FromIterator for ExpLinkedList { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let cap = if size_hint > 0 { size_hint } else { 1 }; + let mut list = Self::with_capacity(cap); + for item in iter { + list.push(item); + } + list + } +} + +pub struct ExpLinkedListIter<'a, T> { + inner: std::collections::linked_list::Iter<'a, Vec>, + inner_iter: Option>, +} + +impl<'a, T> ExpLinkedListIter<'a, T> { + pub fn new(inner: &'a ExpLinkedList) -> Self { + Self { + inner: inner.inner.iter(), + inner_iter: None, + } + } +} + +impl<'a, T> Iterator for ExpLinkedListIter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if let Some(inner_iter) = &mut self.inner_iter { + if let Some(v) = inner_iter.next() { + return Some(v); + } + } + if let Some(inner) = self.inner.next() { + self.inner_iter = Some(inner.iter()); + return self.next(); + } + None + } +} + +pub struct ExpLinkedListIntoIter { + inner: std::collections::linked_list::IntoIter>, + inner_iter: Option>, + len: usize, +} + +impl ExpLinkedListIntoIter { + pub fn new(list: ExpLinkedList) -> Self { + let len = list.len(); + Self { + inner: list.inner.into_iter(), + inner_iter: None, + len, + } + } +} + +impl Iterator for ExpLinkedListIntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if let Some(inner_iter) = &mut self.inner_iter { + if let Some(v) = inner_iter.next() { + return Some(v); + } + } + if let Some(inner) = self.inner.next() { + self.inner_iter = Some(inner.into_iter()); + return self.next(); + } + None + } + + fn size_hint(&self) -> (usize, Option) { + (self.len, Some(self.len)) + } +} + +impl IntoIterator for ExpLinkedList { + type Item = T; + type IntoIter = ExpLinkedListIntoIter; + + fn into_iter(self) -> Self::IntoIter { + ExpLinkedListIntoIter::new(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_exp_linked_list(list: &mut ExpLinkedList) { + assert_eq!(list.len(), 100); + assert!(!list.is_empty()); + + // removes the last 50 elements + for i in 0..50 { + assert_eq!(list.pop(), Some(99 - i)); + } + assert_eq!(list.len(), 50); + assert!(!list.is_empty()); + + // iterate over the list + for (i, v) in list.iter().enumerate() { + assert_eq!(*v, i); + } + + // clear the list + list.clear(); + assert_eq!(list.len(), 0); + assert!(list.is_empty()); + assert_eq!(list.pop(), None); + } + + #[test] + fn test_exp_linked_list_basic() { + let mut list = ExpLinkedList::new(); + for i in 0..100 { + list.push(i); + assert_eq!(list.len(), i + 1); + } + test_exp_linked_list(&mut list); + } + + #[test] + fn test_exp_linked_list_from() { + let mut list = (0..100).collect(); + test_exp_linked_list(&mut list); + } + + #[test] + fn test_exp_linked_list_with_capacity_limit() { + let mut list = ExpLinkedList::with_capacity_limit(10); + for i in 0..100 { + list.push(i); + assert_eq!(list.len(), i + 1); + } + assert_eq!(list.inner.back().unwrap().capacity(), 10); + test_exp_linked_list(&mut list); + } +} diff --git a/rust/lance-core/src/datatypes.rs b/rust/lance-core/src/datatypes.rs index e7d3f28a973..2199ea72003 100644 --- a/rust/lance-core/src/datatypes.rs +++ b/rust/lance-core/src/datatypes.rs @@ -12,23 +12,26 @@ use deepsize::DeepSizeOf; use lance_arrow::bfloat16::{ is_bfloat16_field, ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY, BFLOAT16_EXT_NAME, }; -use snafu::{location, Location}; +use snafu::location; mod field; mod schema; use crate::{Error, Result}; pub use field::{ - Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass, + Encoding, Field, NullabilityComparison, OnTypeMismatch, SchemaCompareOptions, StorageClass, LANCE_STORAGE_CLASS_SCHEMA_META_KEY, }; -pub use schema::Schema; +pub use schema::{OnMissing, Projectable, Projection, Schema}; pub const COMPRESSION_META_KEY: &str = "lance-encoding:compression"; pub const COMPRESSION_LEVEL_META_KEY: &str = "lance-encoding:compression-level"; pub const BLOB_META_KEY: &str = "lance-encoding:blob"; pub const PACKED_STRUCT_LEGACY_META_KEY: &str = "packed"; pub const PACKED_STRUCT_META_KEY: &str = "lance-encoding:packed"; +pub const STRUCTURAL_ENCODING_META_KEY: &str = "lance-encoding:structural-encoding"; +pub const STRUCTURAL_ENCODING_MINIBLOCK: &str = "miniblock"; +pub const STRUCTURAL_ENCODING_FULLZIP: &str = "fullzip"; lazy_static::lazy_static! { pub static ref BLOB_DESC_FIELDS: Fields = @@ -214,19 +217,26 @@ impl TryFrom<&LogicalType> for DataType { let splits = lt.0.split(':').collect::>(); match splits[0] { "fixed_size_list" => { - if splits.len() != 3 { + if splits.len() < 3 { return Err(Error::Schema { message: format!("Unsupported logical type: {}", lt), location: location!(), }); } - let size: i32 = splits[2].parse::().map_err(|e: _| Error::Schema { - message: e.to_string(), - location: location!(), - })?; + let size: i32 = + splits + .last() + .unwrap() + .parse::() + .map_err(|e: _| Error::Schema { + message: e.to_string(), + location: location!(), + })?; + + let inner_type = splits[1..splits.len() - 1].join(":"); - match splits[1] { + match inner_type.as_str() { BFLOAT16_EXT_NAME => { let field = ArrowField::new("item", Self::FixedSizeBinary(2), true) .with_metadata( diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs index 91eade2fa7c..47850789df8 100644 --- a/rust/lance-core/src/datatypes/field.rs +++ b/rust/lance-core/src/datatypes/field.rs @@ -4,8 +4,8 @@ //! Lance Schema Field use std::{ - cmp::max, - collections::HashMap, + cmp::{max, Ordering}, + collections::{HashMap, VecDeque}, fmt::{self, Display}, str::FromStr, sync::Arc, @@ -21,11 +21,11 @@ use arrow_array::{ use arrow_schema::{DataType, Field as ArrowField}; use deepsize::DeepSizeOf; use lance_arrow::{bfloat16::ARROW_EXT_NAME_KEY, *}; -use snafu::{location, Location}; +use snafu::location; use super::{ schema::{compare_fields, explain_fields_difference}, - Dictionary, LogicalType, + Dictionary, LogicalType, Projection, }; use crate::{Error, Result}; @@ -108,6 +108,13 @@ impl FromStr for StorageClass { } } +/// What to do on a merge operation if the types of the fields don't match +#[derive(Debug, Clone, Copy, PartialEq, Eq, DeepSizeOf)] +pub enum OnTypeMismatch { + TakeSelf, + Error, +} + /// Lance Schema Field /// #[derive(Debug, Clone, PartialEq, DeepSizeOf)] @@ -162,6 +169,106 @@ impl Field { self.storage_class } + /// Merge a field with another field using a reference field to ensure + /// the correct order of fields + /// + /// For each child in the reference field we look for a matching child + /// in self and other. + /// + /// If we find a match in both we recursively merge the children. + /// If we find a match in one but not the other we take the matching child. + /// + /// Primitive fields we simply clone self and return. + /// + /// Matches are determined using field names and so ids are not required. + pub fn merge_with_reference(&self, other: &Self, reference: &Self) -> Self { + let mut new_children = Vec::with_capacity(reference.children.len()); + let mut self_children_itr = self.children.iter().peekable(); + let mut other_children_itr = other.children.iter().peekable(); + for ref_child in &reference.children { + match (self_children_itr.peek(), other_children_itr.peek()) { + (Some(&only_child), None) => { + // other is exhausted so just check if self matches + if only_child.name == ref_child.name { + new_children.push(only_child.clone()); + self_children_itr.next(); + } + } + (None, Some(&only_child)) => { + // Self is exhausted so just check if other matches + if only_child.name == ref_child.name { + new_children.push(only_child.clone()); + other_children_itr.next(); + } + } + (Some(&self_child), Some(&other_child)) => { + // Both iterators have potential, see if any match + match ( + ref_child.name.cmp(&self_child.name), + ref_child.name.cmp(&other_child.name), + ) { + (Ordering::Equal, Ordering::Equal) => { + // Both match, recursively merge + new_children + .push(self_child.merge_with_reference(other_child, ref_child)); + self_children_itr.next(); + other_children_itr.next(); + } + (Ordering::Equal, _) => { + // Self matches, other doesn't, use self as-is + new_children.push(self_child.clone()); + self_children_itr.next(); + } + (_, Ordering::Equal) => { + // Other matches, self doesn't, use other as-is + new_children.push(other_child.clone()); + other_children_itr.next(); + } + _ => { + // Neither match, field is projected out + } + } + } + (None, None) => { + // Both iterators are exhausted, we can quit, all remaining fields projected out + break; + } + } + } + Self { + children: new_children, + ..self.clone() + } + } + + pub fn apply_projection(&self, projection: &Projection) -> Option { + let children = self + .children + .iter() + .filter_map(|c| c.apply_projection(projection)) + .collect::>(); + + // The following case is invalid: + // - This is a nested field (has children) + // - All children were projected away + // - Caller is asking for the parent field + assert!( + // One of the following must be true + !children.is_empty() // Some children were projected + || !projection.contains_field_id(self.id) // Caller is not asking for this field + || self.children.is_empty() // This isn't a nested field + ); + + if children.is_empty() && !projection.contains_field_id(self.id) { + None + } else { + Some(Self { + children, + ..self.clone() + }) + } + } + pub(crate) fn explain_differences( &self, expected: &Self, @@ -456,7 +563,7 @@ impl Field { /// Project by a field. /// - pub fn project_by_field(&self, other: &Self) -> Result { + pub fn project_by_field(&self, other: &Self, on_type_mismatch: OnTypeMismatch) -> Result { if self.name != other.name { return Err(Error::Schema { message: format!( @@ -496,7 +603,7 @@ impl Field { location: location!(), }); }; - fields.push(child.project_by_field(other_field)?); + fields.push(child.project_by_field(other_field, on_type_mismatch)?); } let mut cloned = self.clone(); cloned.children = fields; @@ -504,7 +611,8 @@ impl Field { } (DataType::List(_), DataType::List(_)) | (DataType::LargeList(_), DataType::LargeList(_)) => { - let projected = self.children[0].project_by_field(&other.children[0])?; + let projected = + self.children[0].project_by_field(&other.children[0], on_type_mismatch)?; let mut cloned = self.clone(); cloned.children = vec![projected]; Ok(cloned) @@ -524,13 +632,33 @@ impl Field { { Ok(self.clone()) } - _ => Err(Error::Schema { - message: format!( - "Attempt to project incompatible fields: {} and {}", - self, other - ), - location: location!(), - }), + _ => match on_type_mismatch { + OnTypeMismatch::Error => Err(Error::Schema { + message: format!( + "Attempt to project incompatible fields: {} and {}", + self, other + ), + location: location!(), + }), + OnTypeMismatch::TakeSelf => Ok(self.clone()), + }, + } + } + + pub(crate) fn resolve<'a>( + &'a self, + split: &mut VecDeque<&str>, + fields: &mut Vec<&'a Self>, + ) -> bool { + fields.push(self); + if split.is_empty() { + return true; + } + let first = split.pop_front().unwrap(); + if let Some(child) = self.children.iter().find(|c| c.name == first) { + child.resolve(split, fields) + } else { + false } } @@ -546,7 +674,11 @@ impl Field { } let self_type = self.data_type(); let other_type = other.data_type(); - if self_type.is_struct() && other_type.is_struct() { + + if matches!( + (&self_type, &other_type), + (DataType::Struct(_), DataType::Struct(_)) | (DataType::List(_), DataType::List(_)) + ) { let children = self .children .iter() @@ -705,6 +837,19 @@ impl Field { self.children.iter_mut().for_each(Self::reset_id); } + pub fn field_by_id_mut(&mut self, id: impl Into) -> Option<&mut Self> { + let id = id.into(); + for child in self.children.as_mut_slice() { + if child.id == id { + return Some(child); + } + if let Some(grandchild) = child.field_by_id_mut(id) { + return Some(grandchild); + } + } + None + } + pub fn field_by_id(&self, id: impl Into) -> Option<&Self> { let id = id.into(); for child in self.children.as_slice() { @@ -731,6 +876,15 @@ impl Field { } None } + + // Check if field has metadata `packed` set to true, this check is case insensitive. + pub fn is_packed_struct(&self) -> bool { + let field_metadata = &self.metadata; + field_metadata + .get("packed") + .map(|v| v.to_lowercase() == "true") + .unwrap_or(false) + } } impl fmt::Display for Field { @@ -948,19 +1102,19 @@ mod tests { let f2: Field = ArrowField::new("a", DataType::Null, true) .try_into() .unwrap(); - let p1 = f1.project_by_field(&f2).unwrap(); + let p1 = f1.project_by_field(&f2, OnTypeMismatch::Error).unwrap(); assert_eq!(p1, f1); let f3: Field = ArrowField::new("b", DataType::Null, true) .try_into() .unwrap(); - assert!(f1.project_by_field(&f3).is_err()); + assert!(f1.project_by_field(&f3, OnTypeMismatch::Error).is_err()); let f4: Field = ArrowField::new("a", DataType::Int32, true) .try_into() .unwrap(); - assert!(f1.project_by_field(&f4).is_err()); + assert!(f1.project_by_field(&f4, OnTypeMismatch::Error).is_err()); } #[test] diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs index 4d4589ee137..f6f64190283 100644 --- a/rust/lance-core/src/datatypes/schema.rs +++ b/rust/lance-core/src/datatypes/schema.rs @@ -4,18 +4,19 @@ //! Schema use std::{ - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, fmt::{self, Debug, Formatter}, + sync::Arc, }; use arrow_array::RecordBatch; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use deepsize::DeepSizeOf; use lance_arrow::*; -use snafu::{location, Location}; +use snafu::location; -use super::field::{Field, SchemaCompareOptions, StorageClass}; -use crate::{Error, Result}; +use super::field::{Field, OnTypeMismatch, SchemaCompareOptions, StorageClass}; +use crate::{Error, Result, ROW_ADDR, ROW_ID}; /// Lance Schema. #[derive(Default, Debug, Clone, DeepSizeOf)] @@ -152,6 +153,26 @@ impl Schema { ArrowSchema::from(self).to_compact_string(indent) } + /// Given a string column reference, resolve the path of fields + /// + /// For example, given a.b.c we will return the fields [a, b, c] + /// + /// Returns None if we can't find a segment at any point + pub fn resolve(&self, column: impl AsRef) -> Option> { + let mut split = column.as_ref().split('.').collect::>(); + let mut fields = Vec::with_capacity(split.len()); + let first = split.pop_front().unwrap(); + if let Some(field) = self.field(first) { + if field.resolve(&mut split, &mut fields) { + Some(fields) + } else { + None + } + } else { + None + } + } + fn do_project>(&self, columns: &[T], err_on_missing: bool) -> Result { let mut candidates: Vec = vec![]; for col in columns { @@ -304,13 +325,15 @@ impl Schema { pub fn project_by_schema>( &self, projection: S, + on_missing: OnMissing, + on_type_mismatch: OnTypeMismatch, ) -> Result { let projection = projection.try_into()?; let mut new_fields = vec![]; for field in projection.fields.iter() { if let Some(self_field) = self.field(&field.name) { - new_fields.push(self_field.project_by_field(field)?); - } else { + new_fields.push(self_field.project_by_field(field, on_type_mismatch)?); + } else if matches!(on_missing, OnMissing::Error) { return Err(Error::Schema { message: format!("Field {} not found", field.name), location: location!(), @@ -377,7 +400,19 @@ impl Schema { } /// Get field by its id. - // TODO: pub(crate) + pub fn field_by_id_mut(&mut self, id: impl Into) -> Option<&mut Field> { + let id = id.into(); + for field in self.fields.iter_mut() { + if field.id == id { + return Some(field); + } + if let Some(grandchild) = field.field_by_id_mut(id) { + return Some(grandchild); + } + } + None + } + pub fn field_by_id(&self, id: impl Into) -> Option<&Field> { let id = id.into(); for field in self.fields.iter() { @@ -527,6 +562,10 @@ impl Schema { }; Ok(schema) } + + pub fn all_fields_nullable(&self) -> bool { + SchemaFieldIterPreOrder::new(self).all(|f| f.nullable) + } } impl PartialEq for Schema { @@ -738,6 +777,263 @@ fn explain_metadata_difference( } } +/// What to do when a column is missing in the schema +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnMissing { + Error, + Ignore, +} + +/// A trait for something that we can project fields from. +pub trait Projectable: Debug + Send + Sync { + fn schema(&self) -> &Schema; +} + +impl Projectable for Schema { + fn schema(&self) -> &Schema { + self + } +} + +/// A projection is a selection of fields in a schema +/// +/// In addition we record whether the row_id or row_addr are +/// selected (these fields have no field id) +#[derive(Clone)] +pub struct Projection { + base: Arc, + pub field_ids: HashSet, + pub with_row_id: bool, + pub with_row_addr: bool, +} + +impl Debug for Projection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Projection") + .field("schema", &self.to_schema()) + .field("with_row_id", &self.with_row_id) + .field("with_row_addr", &self.with_row_addr) + .finish() + } +} + +impl Projection { + /// Create a new empty projection + pub fn empty(base: Arc) -> Self { + Self { + base, + field_ids: HashSet::new(), + with_row_id: false, + with_row_addr: false, + } + } + + pub fn with_row_id(mut self) -> Self { + self.with_row_id = true; + self + } + + pub fn with_row_addr(mut self) -> Self { + self.with_row_addr = true; + self + } + + /// Add a column (and any of its parents) to the projection from a string reference + pub fn union_column(mut self, column: impl AsRef, on_missing: OnMissing) -> Result { + let column = column.as_ref(); + if column == ROW_ID { + self.with_row_id = true; + return Ok(self); + } else if column == ROW_ADDR { + self.with_row_addr = true; + return Ok(self); + } + + if let Some(fields) = self.base.schema().resolve(column) { + self.field_ids.extend(fields.iter().map(|f| f.id)); + } else if matches!(on_missing, OnMissing::Error) { + return Err(Error::InvalidInput { + source: format!("Column {} does not exist", column).into(), + location: location!(), + }); + } + Ok(self) + } + + /// True if the projection selects the given field id + pub fn contains_field_id(&self, id: i32) -> bool { + self.field_ids.contains(&id) + } + + /// True if the projection selects fields other than the row id / addr + pub fn has_data_fields(&self) -> bool { + !self.field_ids.is_empty() + } + + /// Add multiple columns (and their parents) to the projection + pub fn union_columns( + mut self, + columns: impl IntoIterator>, + on_missing: OnMissing, + ) -> Result { + for column in columns { + self = self.union_column(column, on_missing)?; + } + Ok(self) + } + + /// Adds all fields from the base schema satisfying a predicate + pub fn union_predicate(mut self, predicate: impl Fn(&Field) -> bool) -> Self { + for field in self.base.schema().fields_pre_order() { + if predicate(field) { + self.field_ids.insert(field.id); + } + } + self + } + + /// Removes all fields in the base schema satisfying a predicate + pub fn subtract_predicate(mut self, predicate: impl Fn(&Field) -> bool) -> Self { + for field in self.base.schema().fields_pre_order() { + if predicate(field) { + self.field_ids.remove(&field.id); + } + } + self + } + + /// Creates a new projection that is the intersection of this projection and another + pub fn intersect(mut self, other: &Self) -> Self { + self.field_ids = HashSet::from_iter(self.field_ids.intersection(&other.field_ids).copied()); + self.with_row_id = self.with_row_id && other.with_row_id; + self.with_row_addr = self.with_row_addr && other.with_row_addr; + self + } + + /// Adds all fields from the provided schema to the projection + /// + /// Fields are only added if they exist in the base schema, otherwise they + /// are ignored. + /// + /// Will panic if a field in the given schema has a non-negative id and is not in the base schema. + pub fn union_schema(mut self, other: &Schema) -> Self { + for field in other.fields_pre_order() { + if field.id >= 0 { + self.field_ids.insert(field.id); + } else if field.name == ROW_ID { + self.with_row_id = true; + } else if field.name == ROW_ADDR { + self.with_row_addr = true; + } else { + // If a field is not in our schema then it should probably have an id of -1. If it isn't -1 + // that probably implies some kind of weird schema mixing is going on and we should panic. + debug_assert_eq!(field.id, -1); + } + } + self + } + + /// Creates a new projection that is the union of this projection and another + pub fn union_projection(mut self, other: &Self) -> Self { + self.field_ids.extend(&other.field_ids); + self.with_row_id = self.with_row_id || other.with_row_id; + self.with_row_addr = self.with_row_addr || other.with_row_addr; + self + } + + /// Adds all fields from the given schema to the projection + /// + /// on_missing controls what happen to fields that are not in the base schema + /// + /// Name based matching is used to determine if a field is in the base schema. + pub fn union_arrow_schema( + mut self, + other: &ArrowSchema, + on_missing: OnMissing, + ) -> Result { + self.with_row_id |= other.fields().iter().any(|f| f.name() == ROW_ID); + self.with_row_addr |= other.fields().iter().any(|f| f.name() == ROW_ADDR); + let other = + self.base + .schema() + .project_by_schema(other, on_missing, OnTypeMismatch::TakeSelf)?; + Ok(self.union_schema(&other)) + } + + /// Removes all fields from the projection that are in the given schema + /// + /// on_missing controls what happen to fields that are not in the base schema + /// + /// Name based matching is used to determine if a field is in the base schema. + pub fn subtract_arrow_schema( + mut self, + other: &ArrowSchema, + on_missing: OnMissing, + ) -> Result { + self.with_row_id &= !other.fields().iter().any(|f| f.name() == ROW_ID); + self.with_row_addr &= !other.fields().iter().any(|f| f.name() == ROW_ADDR); + let other = + self.base + .schema() + .project_by_schema(other, on_missing, OnTypeMismatch::TakeSelf)?; + Ok(self.subtract_schema(&other)) + } + + /// Removes all fields from this projection that are present in the given projection + pub fn subtract_projection(mut self, other: &Self) -> Self { + self.field_ids = self + .field_ids + .difference(&other.field_ids) + .copied() + .collect(); + self.with_row_addr = self.with_row_addr && !other.with_row_addr; + self.with_row_id = self.with_row_id && !other.with_row_id; + self + } + + /// Removes all fields from the projection that are in the given schema + /// + /// Fields are only removed if they exist in the base schema, otherwise they + /// are ignored. + /// + /// Will panic if a field in the given schema has a non-negative id and is not in the base schema. + pub fn subtract_schema(mut self, other: &Schema) -> Self { + for field in other.fields_pre_order() { + if field.id >= 0 { + self.field_ids.remove(&field.id); + } else if field.name == ROW_ID { + self.with_row_id = false; + } else if field.name == ROW_ADDR { + self.with_row_addr = false; + } else { + debug_assert_eq!(field.id, -1); + } + } + self + } + + /// True if the projection does not select any fields + pub fn is_empty(&self) -> bool { + self.field_ids.is_empty() + } + + /// Convert the projection to a schema + pub fn to_schema(&self) -> Schema { + let field_ids = self.field_ids.iter().copied().collect::>(); + self.base.schema().project_by_ids(&field_ids, false) + } + + /// Convert the projection to a schema + pub fn into_schema(self) -> Schema { + self.to_schema() + } + + /// Convert the projection to a schema reference + pub fn into_schema_ref(self) -> Arc { + Arc::new(self.into_schema()) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -909,7 +1205,9 @@ mod tests { false, ), ]); - let projected = schema.project_by_schema(&projection).unwrap(); + let projected = schema + .project_by_schema(&projection, OnMissing::Error, OnTypeMismatch::TakeSelf) + .unwrap(); assert_eq!(ArrowSchema::from(&projected), projection); } @@ -1011,6 +1309,30 @@ mod tests { ArrowField::new("c", DataType::Float64, false), ]); assert_eq!(actual, expected); + + let schema_with_list_struct = ArrowSchema::new(vec![ArrowField::new( + "struct_list", + DataType::List(Arc::new(ArrowField::new( + "item", + DataType::Struct(ArrowFields::from(vec![ + ArrowField::new("f1", DataType::Utf8, true), + ArrowField::new("f2", DataType::Boolean, false), + ])), + true, + ))), + true, + )]); + let schema_with_list_struct = Schema::try_from(&schema_with_list_struct).unwrap(); + + let with_missing_field = schema_with_list_struct.project_by_ids(&[1, 3], false); + let intersection = schema_with_list_struct + .intersection_ignore_types(&with_missing_field) + .unwrap(); + assert_eq!(intersection, with_missing_field); + let intersection = with_missing_field + .intersection_ignore_types(&schema_with_list_struct) + .unwrap(); + assert_eq!(intersection, with_missing_field); } #[test] @@ -1325,4 +1647,78 @@ mod tests { let res = out_of_order.explain_difference(&expected, &options); assert!(res.is_none(), "Expected None, got {:?}", res); } + + #[test] + pub fn test_all_fields_nullable() { + let test_cases = vec![ + ( + vec![], // empty schema + true, + ), + ( + vec![ + Field::new_arrow("a", DataType::Int32, true).unwrap(), + Field::new_arrow("b", DataType::Utf8, true).unwrap(), + ], // basic case + true, + ), + ( + vec![ + Field::new_arrow("a", DataType::Int32, false).unwrap(), + Field::new_arrow("b", DataType::Utf8, true).unwrap(), + ], + false, + ), + ( + // check nested schema, parent is nullable + vec![Field::new_arrow( + "struct", + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "a", + DataType::Int32, + false, + )])), + true, + ) + .unwrap()], + false, + ), + ( + // check nested schema, child is nullable + vec![Field::new_arrow( + "struct", + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])), + false, + ) + .unwrap()], + false, + ), + ( + // check nested schema, all is nullable + vec![Field::new_arrow( + "struct", + DataType::Struct(ArrowFields::from(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])), + true, + ) + .unwrap()], + true, + ), + ]; + + for (fields, expected) in test_cases { + let schema = Schema { + fields, + metadata: Default::default(), + }; + assert_eq!(schema.all_fields_nullable(), expected); + } + } } diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index c186d77c37b..6eed53b010e 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -51,6 +51,14 @@ pub enum Error { source: BoxedError, location: Location, }, + #[snafu(display("Retryable commit conflict for version {version}: {source}, {location}"))] + RetryableCommitConflict { + version: u64, + source: BoxedError, + location: Location, + }, + #[snafu(display("Too many concurrent writers. {message}, {location}"))] + TooMuchWriteContention { message: String, location: Location }, #[snafu(display("Encountered internal error. Please file a bug report at https://github.com/lancedb/lance/issues. {message}, {location}"))] Internal { message: String, location: Location }, #[snafu(display("A prerequisite task failed: {message}, {location}"))] @@ -151,6 +159,24 @@ impl Error { } } +pub trait LanceOptionExt { + /// Unwraps an option, returning an internal error if the option is None. + /// + /// Can be used when an option is expected to have a value. + fn expect_ok(self) -> Result; +} + +impl LanceOptionExt for Option { + #[track_caller] + fn expect_ok(self) -> Result { + let location = std::panic::Location::caller().to_snafu_location(); + self.ok_or_else(|| Error::Internal { + message: "Expected option to have value".to_string(), + location, + }) + } +} + trait ToSnafuLocation { fn to_snafu_location(&'static self) -> snafu::Location; } @@ -226,6 +252,16 @@ impl From for Error { } } +impl From for Error { + #[track_caller] + fn from(e: prost::UnknownEnumValue) -> Self { + Self::IO { + source: box_error(e), + location: std::panic::Location::caller().to_snafu_location(), + } + } +} + impl From for Error { #[track_caller] fn from(e: tokio::task::JoinError) -> Self { diff --git a/rust/lance-core/src/lib.rs b/rust/lance-core/src/lib.rs index 9ab18540768..ed16d10c3e2 100644 --- a/rust/lance-core/src/lib.rs +++ b/rust/lance-core/src/lib.rs @@ -4,6 +4,7 @@ use arrow_schema::{DataType, Field as ArrowField}; pub mod cache; +pub mod container; pub mod datatypes; pub mod error; pub mod traits; diff --git a/rust/lance-core/src/utils.rs b/rust/lance-core/src/utils.rs index acf215caeb4..f04ca305f93 100644 --- a/rust/lance-core/src/utils.rs +++ b/rust/lance-core/src/utils.rs @@ -2,12 +2,14 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors pub mod address; +pub mod backoff; pub mod bit; pub mod cpu; pub mod deletion; pub mod futures; pub mod hash; pub mod mask; +pub mod parse; pub mod path; pub mod testing; pub mod tokio; diff --git a/rust/lance-core/src/utils/backoff.rs b/rust/lance-core/src/utils/backoff.rs new file mode 100644 index 00000000000..d2093cb5b6d --- /dev/null +++ b/rust/lance-core/src/utils/backoff.rs @@ -0,0 +1,92 @@ +use rand::Rng; +use std::time::Duration; + +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +/// Computes backoff as +/// +/// ```text +/// backoff = base^attempt * unit + jitter +/// ``` +/// +/// The defaults are base=2, unit=50ms, jitter=50ms, min=0ms, max=5s. This gives +/// a backoff of 50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 5s, (not including jitter). +/// +/// You can have non-exponential backoff by setting base=1. +pub struct Backoff { + base: u32, + unit: u32, + jitter: i32, + min: u32, + max: u32, + attempt: u32, +} + +impl Default for Backoff { + fn default() -> Self { + Self { + base: 2, + unit: 50, + jitter: 50, + min: 0, + max: 5000, + attempt: 0, + } + } +} + +impl Backoff { + pub fn with_base(self, base: u32) -> Self { + Self { base, ..self } + } + + pub fn with_jitter(self, jitter: i32) -> Self { + Self { jitter, ..self } + } + + pub fn with_min(self, min: u32) -> Self { + Self { min, ..self } + } + + pub fn with_max(self, max: u32) -> Self { + Self { max, ..self } + } + + pub fn next_backoff(&mut self) -> Duration { + let backoff = self + .base + .saturating_pow(self.attempt) + .saturating_mul(self.unit); + let jitter = rand::thread_rng().gen_range(-self.jitter..=self.jitter); + let backoff = (backoff.saturating_add_signed(jitter)).clamp(self.min, self.max); + self.attempt += 1; + Duration::from_millis(backoff as u64) + } + + pub fn attempt(&self) -> u32 { + self.attempt + } + + pub fn reset(&mut self) { + self.attempt = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_backoff() { + let mut backoff = Backoff::default().with_jitter(0); + assert_eq!(backoff.next_backoff().as_millis(), 50); + assert_eq!(backoff.attempt(), 1); + assert_eq!(backoff.next_backoff().as_millis(), 100); + assert_eq!(backoff.attempt(), 2); + assert_eq!(backoff.next_backoff().as_millis(), 200); + assert_eq!(backoff.attempt(), 3); + assert_eq!(backoff.next_backoff().as_millis(), 400); + assert_eq!(backoff.attempt(), 4); + } +} diff --git a/rust/lance-core/src/utils/bit.rs b/rust/lance-core/src/utils/bit.rs index 75a13e783aa..7d69fee8da0 100644 --- a/rust/lance-core/src/utils/bit.rs +++ b/rust/lance-core/src/utils/bit.rs @@ -60,7 +60,6 @@ pub fn log_2_ceil(val: u32) -> u32 { } #[cfg(test)] - pub mod tests { use crate::utils::bit::log_2_ceil; diff --git a/rust/lance-core/src/utils/cpu.rs b/rust/lance-core/src/utils/cpu.rs index 4427922dd18..be60e11984c 100644 --- a/rust/lance-core/src/utils/cpu.rs +++ b/rust/lance-core/src/utils/cpu.rs @@ -113,3 +113,10 @@ mod loongarch64 { flags & libc::HWCAP_LOONGARCH_LASX != 0 } } + +#[cfg(all(target_arch = "aarch64", target_os = "android"))] +mod aarch64 { + pub fn has_neon_f16_support() -> bool { + false + } +} diff --git a/rust/lance-core/src/utils/deletion.rs b/rust/lance-core/src/utils/deletion.rs index 1735f90b8cd..9ad2acd036a 100644 --- a/rust/lance-core/src/utils/deletion.rs +++ b/rust/lance-core/src/utils/deletion.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{collections::HashSet, ops::Range}; +use std::{collections::HashSet, ops::Range, sync::Arc}; use arrow_array::BooleanArray; use deepsize::{Context, DeepSizeOf}; @@ -60,6 +60,14 @@ impl DeletionVector { } } + fn range_cardinality(&self, range: Range) -> u64 { + match self { + Self::NoDeletions => 0, + Self::Set(set) => range.fold(0, |acc, i| acc + set.contains(&i) as u64), + Self::Bitmap(bitmap) => bitmap.range_cardinality(range), + } + } + pub fn iter(&self) -> Box + Send + '_> { match self { Self::NoDeletions => Box::new(std::iter::empty()), @@ -97,15 +105,15 @@ impl DeletionVector { // Note: deletion vectors are based on 32-bit offsets. However, this function works // even when given 64-bit row addresses. That is because `id as u32` returns the lower // 32 bits (the row offset) and the upper 32 bits are ignored. - pub fn build_predicate(&self, row_ids: std::slice::Iter) -> Option { + pub fn build_predicate(&self, row_addrs: std::slice::Iter) -> Option { match self { Self::Bitmap(bitmap) => Some( - row_ids + row_addrs .map(|&id| !bitmap.contains(id as u32)) .collect::>(), ), Self::Set(set) => Some( - row_ids + row_addrs .map(|&id| !set.contains(&(id as u32))) .collect::>(), ), @@ -115,6 +123,64 @@ impl DeletionVector { } } +/// Maps a naive offset into a fragment to the local row offset that is +/// not deleted. +/// +/// For example, if the deletion vector is [0, 1, 2], then the mapping +/// would be: +/// +/// - 0 -> 3 +/// - 1 -> 4 +/// - 2 -> 5 +/// +/// and so on. +/// +/// This expects a monotonically increasing sequence of input offsets. State +/// is re-used between calls to `map_offset` to make the mapping more efficient. +pub struct OffsetMapper { + dv: Arc, + left: u32, + last_diff: u32, +} + +impl OffsetMapper { + pub fn new(dv: Arc) -> Self { + Self { + dv, + left: 0, + last_diff: 0, + } + } + + pub fn map_offset(&mut self, offset: u32) -> u32 { + // The best initial guess is the offset + last diff. That's the right + // answer if there are no deletions in the range between the last + // offset and the current one. + let mut mid = offset + self.last_diff; + let mut right = offset + self.dv.len() as u32; + loop { + let deleted_in_range = self.dv.range_cardinality(0..(mid + 1)) as u32; + match mid.cmp(&(offset + deleted_in_range)) { + std::cmp::Ordering::Equal if !self.dv.contains(mid) => { + self.last_diff = mid - offset; + return mid; + } + std::cmp::Ordering::Less => { + assert_ne!(self.left, mid + 1); + self.left = mid + 1; + mid = self.left + (right - self.left) / 2; + } + // There are cases where the mid is deleted but also equal in + // comparison. For those we need to find a lower value. + std::cmp::Ordering::Greater | std::cmp::Ordering::Equal => { + right = mid; + mid = self.left + (right - self.left) / 2; + } + } + } + } +} + impl Default for DeletionVector { fn default() -> Self { Self::NoDeletions @@ -194,7 +260,6 @@ impl Extend for DeletionVector { /// pub fn get(i: u32) -> bool { ... } /// } /// impl BitAnd for DeletionVector { ... } - impl IntoIterator for DeletionVector { type IntoIter = Box + Send>; type Item = u32; @@ -242,4 +307,28 @@ mod test { let dv = DeletionVector::from_iter(0..(BITMAP_THRESDHOLD as u32)); assert!(matches!(dv, DeletionVector::Bitmap(_))); } + + #[test] + fn test_map_offsets() { + let dv = DeletionVector::from_iter(vec![3, 5]); + let mut mapper = OffsetMapper::new(Arc::new(dv)); + + let offsets = [0, 1, 2, 3, 4, 5, 6]; + let mut output = Vec::new(); + for offset in offsets.iter() { + output.push(mapper.map_offset(*offset)); + } + assert_eq!(output, vec![0, 1, 2, 4, 6, 7, 8]); + + let dv = DeletionVector::from_iter(vec![0, 1, 2]); + let mut mapper = OffsetMapper::new(Arc::new(dv)); + + let offsets = [0, 1, 2, 3, 4, 5, 6]; + + let mut output = Vec::new(); + for offset in offsets.iter() { + output.push(mapper.map_offset(*offset)); + } + assert_eq!(output, vec![3, 4, 5, 6, 7, 8, 9]); + } } diff --git a/rust/lance-core/src/utils/futures.rs b/rust/lance-core/src/utils/futures.rs index 9acce93ce27..2293874c91e 100644 --- a/rust/lance-core/src/utils/futures.rs +++ b/rust/lance-core/src/utils/futures.rs @@ -8,6 +8,7 @@ use std::{ }; use futures::{stream::BoxStream, Stream, StreamExt}; +use pin_project::pin_project; use tokio::sync::Semaphore; use tokio_util::sync::PollSemaphore; @@ -74,7 +75,7 @@ impl<'a, T: Clone> SharedStream<'a, T> { } } -impl<'a, T: Clone> Stream for SharedStream<'a, T> { +impl Stream for SharedStream<'_, T> { type Item = T; fn poll_next( @@ -216,6 +217,53 @@ impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> { } } +#[pin_project] +pub struct FinallyStream { + #[pin] + stream: S, + f: Option, +} + +impl FinallyStream { + pub fn new(stream: S, f: F) -> Self { + Self { stream, f: Some(f) } + } +} + +impl Stream for FinallyStream { + type Item = S::Item; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, std::task::Poll::Ready(None)) { + // It's possible that None is polled multiple times, but we only call the function once + if let Some(f) = this.f.take() { + f(); + } + } + res + } +} + +pub trait FinallyStreamExt: Stream + Sized { + fn finally(self, f: F) -> FinallyStream { + FinallyStream { + stream: self, + f: Some(f), + } + } +} + +impl FinallyStreamExt for S { + fn finally(self, f: F) -> FinallyStream { + FinallyStream::new(self, f) + } +} + #[cfg(test)] mod tests { diff --git a/rust/lance-core/src/utils/hash.rs b/rust/lance-core/src/utils/hash.rs index 58e6fd47bfb..14ef805a58f 100644 --- a/rust/lance-core/src/utils/hash.rs +++ b/rust/lance-core/src/utils/hash.rs @@ -7,13 +7,13 @@ use std::hash::Hasher; // the equality for this `U8SliceKey` means that the &[u8] contents are equal. #[derive(Eq)] pub struct U8SliceKey<'a>(pub &'a [u8]); -impl<'a> PartialEq for U8SliceKey<'a> { +impl PartialEq for U8SliceKey<'_> { fn eq(&self, other: &Self) -> bool { self.0 == other.0 } } -impl<'a> std::hash::Hash for U8SliceKey<'a> { +impl std::hash::Hash for U8SliceKey<'_> { fn hash(&self, state: &mut H) { self.0.hash(state); } diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index edbf754375b..b0c941d71ad 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -10,7 +10,7 @@ use arrow_array::{Array, BinaryArray, GenericBinaryArray}; use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer}; use byteorder::{ReadBytesExt, WriteBytesExt}; use deepsize::DeepSizeOf; -use roaring::{MultiOps, RoaringBitmap}; +use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; use crate::Result; @@ -517,7 +517,8 @@ impl RowIdTreeMap { /// for each entry: /// * u32: fragment_id /// * u32: bitmap size - /// * [u8]: bitmap + /// * \[u8\]: bitmap + /// /// If bitmap size is zero then the entire fragment is selected. pub fn serialize_into(&self, mut writer: W) -> Result<()> { writer.write_u32::(self.inner.len() as u32)?; @@ -706,6 +707,16 @@ impl<'a> FromIterator<&'a u64> for RowIdTreeMap { } } +impl From for RowIdTreeMap { + fn from(roaring: RoaringTreemap) -> Self { + let mut inner = BTreeMap::new(); + for (fragment, set) in roaring.bitmaps() { + inner.insert(fragment, RowIdSelection::Partial(set.clone())); + } + Self { inner } + } +} + impl Extend for RowIdTreeMap { fn extend>(&mut self, iter: T) { for row_id in iter { diff --git a/rust/lance-core/src/utils/parse.rs b/rust/lance-core/src/utils/parse.rs new file mode 100644 index 00000000000..7efea7cfc72 --- /dev/null +++ b/rust/lance-core/src/utils/parse.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +/// Parse a string into a boolean value. +pub fn str_is_truthy(val: &str) -> bool { + val.eq_ignore_ascii_case("1") + | val.eq_ignore_ascii_case("true") + | val.eq_ignore_ascii_case("on") + | val.eq_ignore_ascii_case("yes") + | val.eq_ignore_ascii_case("y") +} diff --git a/rust/lance-core/src/utils/path.rs b/rust/lance-core/src/utils/path.rs index 72d7311894f..fb4ec56eb47 100644 --- a/rust/lance-core/src/utils/path.rs +++ b/rust/lance-core/src/utils/path.rs @@ -11,7 +11,7 @@ impl LancePathExt for Path { fn child_path(&self, path: &Path) -> Path { let mut new_path = self.clone(); for part in path.parts() { - new_path = path.child(part); + new_path = new_path.child(part); } new_path } diff --git a/rust/lance-core/src/utils/testing.rs b/rust/lance-core/src/utils/testing.rs index 9746787f715..f1112364863 100644 --- a/rust/lance-core/src/utils/testing.rs +++ b/rust/lance-core/src/utils/testing.rs @@ -218,7 +218,7 @@ impl Default for MockClock<'_> { } } -impl<'a> MockClock<'a> { +impl MockClock<'_> { pub fn new() -> Self { Default::default() } @@ -228,7 +228,7 @@ impl<'a> MockClock<'a> { } } -impl<'a> Drop for MockClock<'a> { +impl Drop for MockClock<'_> { fn drop(&mut self) { // Reset the clock to the epoch mock_instant::MockClock::set_system_time(TimeDelta::try_days(0).unwrap().to_std().unwrap()); diff --git a/rust/lance-core/src/utils/tokio.rs b/rust/lance-core/src/utils/tokio.rs index 4db88cdc0bb..857666d114e 100644 --- a/rust/lance-core/src/utils/tokio.rs +++ b/rust/lance-core/src/utils/tokio.rs @@ -17,8 +17,9 @@ pub fn get_num_compute_intensive_cpus() -> usize { let cpus = num_cpus::get(); if cpus <= *IO_CORE_RESERVATION { - // on systems with only 1 CPU there is no point in warning - if cpus > 1 { + // If the user is not setting a custom value for LANCE_IO_CORE_RESERVATION then we don't emit + // a warning because they're just on a small machine and there isn't much they can do about it. + if cpus > 2 { log::warn!( "Number of CPUs is less than or equal to the number of IO core reservations. \ This is not a supported configuration. using 1 CPU for compute intensive tasks." diff --git a/rust/lance-core/src/utils/tracing.rs b/rust/lance-core/src/utils/tracing.rs index 6505358cf83..067a72d7157 100644 --- a/rust/lance-core/src/utils/tracing.rs +++ b/rust/lance-core/src/utils/tracing.rs @@ -47,3 +47,20 @@ impl StreamTracingExt for S { } } } + +pub const TRACE_FILE_AUDIT: &str = "lance::file_audit"; +pub const AUDIT_MODE_CREATE: &str = "create"; +pub const AUDIT_MODE_DELETE: &str = "delete"; +pub const AUDIT_MODE_DELETE_UNVERIFIED: &str = "delete_unverified"; +pub const AUDIT_TYPE_DELETION: &str = "deletion"; +pub const AUDIT_TYPE_MANIFEST: &str = "manifest"; +pub const AUDIT_TYPE_INDEX: &str = "index"; +pub const AUDIT_TYPE_DATA: &str = "data"; +pub const TRACE_FILE_CREATE: &str = "create"; +pub const TRACE_IO_EVENTS: &str = "lance::io_events"; +pub const IO_TYPE_OPEN_SCALAR: &str = "open_scalar_index"; +pub const IO_TYPE_OPEN_VECTOR: &str = "open_vector_index"; +pub const IO_TYPE_LOAD_VECTOR_PART: &str = "load_vector_part"; +pub const IO_TYPE_LOAD_SCALAR_PART: &str = "load_scalar_part"; +pub const TRACE_EXECUTION: &str = "lance::execution"; +pub const EXECUTION_PLAN_RUN: &str = "plan_run"; diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml index 41af9afb284..3d60207422f 100644 --- a/rust/lance-datafusion/Cargo.toml +++ b/rust/lance-datafusion/Cargo.toml @@ -10,8 +10,8 @@ categories.workspace = true description = "Internal utilities used by other lance modules to simplify working with datafusion" [dependencies] -arrow.workspace = true -arrow-array.workspace = true +arrow = { workspace = true, features = ["ffi"] } +arrow-array = { workspace = true, features = ["ffi"] } arrow-buffer.workspace = true arrow-schema.workspace = true arrow-select.workspace = true @@ -21,18 +21,22 @@ datafusion.workspace = true datafusion-common.workspace = true datafusion-functions.workspace = true datafusion-physical-expr.workspace = true -datafusion-substrait = { version = "41.0", optional = true } +datafusion-substrait = { version = "46.0", optional = true } futures.workspace = true lance-arrow.workspace = true lance-core = { workspace = true, features = ["datafusion"] } +lance-datagen.workspace = true lazy_static.workspace = true log.workspace = true +pin-project.workspace = true prost.workspace = true snafu.workspace = true +tempfile.workspace = true tokio.workspace = true +tracing.workspace = true [dev-dependencies] -substrait-expr = { version = "0.2.1" } +substrait-expr = { version = "0.2.3" } lance-datagen.workspace = true [features] diff --git a/rust/lance-datafusion/src/datagen.rs b/rust/lance-datafusion/src/datagen.rs new file mode 100644 index 00000000000..70b07b9a20b --- /dev/null +++ b/rust/lance-datafusion/src/datagen.rs @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use datafusion::{ + execution::SendableRecordBatchStream, + physical_plan::{stream::RecordBatchStreamAdapter, ExecutionPlan}, +}; +use datafusion_common::DataFusionError; +use futures::TryStreamExt; +use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount}; + +use crate::exec::OneShotExec; + +pub trait DatafusionDatagenExt { + fn into_df_stream( + self, + batch_size: RowCount, + num_batches: BatchCount, + ) -> SendableRecordBatchStream; + + fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc; +} + +impl DatafusionDatagenExt for BatchGeneratorBuilder { + fn into_df_stream( + self, + batch_size: RowCount, + num_batches: BatchCount, + ) -> SendableRecordBatchStream { + let (stream, schema) = self.into_reader_stream(batch_size, num_batches); + let stream = stream.map_err(DataFusionError::from); + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) + } + + fn into_df_exec(self, batch_size: RowCount, num_batches: BatchCount) -> Arc { + let stream = self.into_df_stream(batch_size, num_batches); + Arc::new(OneShotExec::new(stream)) + } +} diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index c3f64e1bec6..30fba77af4f 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -8,29 +8,44 @@ use std::sync::{Arc, Mutex}; use arrow_array::RecordBatch; use arrow_schema::Schema as ArrowSchema; use datafusion::{ + catalog::streaming::StreamingTable, dataframe::DataFrame, - datasource::streaming::StreamingTable, execution::{ context::{SessionConfig, SessionContext}, disk_manager::DiskManagerConfig, memory_pool::FairSpillPool, - runtime_env::{RuntimeConfig, RuntimeEnv}, + runtime_env::RuntimeEnvBuilder, TaskContext, }, physical_plan::{ - display::DisplayableExecutionPlan, stream::RecordBatchStreamAdapter, - streaming::PartitionStream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, - SendableRecordBatchStream, + analyze::AnalyzeExec, + display::DisplayableExecutionPlan, + execution_plan::{Boundedness, EmissionType}, + stream::RecordBatchStreamAdapter, + streaming::PartitionStream, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; use datafusion_common::{DataFusionError, Statistics}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use lazy_static::lazy_static; -use futures::stream; +use futures::{stream, StreamExt}; use lance_arrow::SchemaExt; -use lance_core::Result; +use lance_core::{ + utils::{ + futures::FinallyStreamExt, + tracing::{EXECUTION_PLAN_RUN, TRACE_EXECUTION}, + }, + Error, Result, +}; use log::{debug, info, warn}; +use snafu::location; + +use crate::utils::{ + MetricsExt, BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, IOPS_METRIC, + PARTS_LOADED_METRIC, REQUESTS_METRIC, +}; /// An source execution node created from an existing stream /// @@ -57,7 +72,8 @@ impl OneShotExec { properties: PlanProperties::new( EquivalenceProperties::new(schema), Partitioning::RoundRobinBatch(1), - datafusion::physical_plan::ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ), } } @@ -161,10 +177,31 @@ impl ExecutionPlan for OneShotExec { } } -#[derive(Debug, Default, Clone)] +/// Callback for reporting statistics after a scan +pub type ExecutionStatsCallback = Arc; + +#[derive(Default, Clone)] pub struct LanceExecutionOptions { pub use_spilling: bool, pub mem_pool_size: Option, + pub batch_size: Option, + pub target_partition: Option, + pub execution_stats_callback: Option, +} + +impl std::fmt::Debug for LanceExecutionOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LanceExecutionOptions") + .field("use_spilling", &self.use_spilling) + .field("mem_pool_size", &self.mem_pool_size) + .field("batch_size", &self.batch_size) + .field("target_partition", &self.target_partition) + .field( + "execution_stats_callback", + &self.execution_stats_callback.is_some(), + ) + .finish() + } } const DEFAULT_LANCE_MEM_POOL_SIZE: u64 = 100 * 1024 * 1024; @@ -197,42 +234,121 @@ impl LanceExecutionOptions { } } -pub fn new_session_context(options: LanceExecutionOptions) -> SessionContext { - let session_config = SessionConfig::new(); - let mut runtime_config = RuntimeConfig::new(); +pub fn new_session_context(options: &LanceExecutionOptions) -> SessionContext { + let mut session_config = SessionConfig::new(); + let mut runtime_env_builder = RuntimeEnvBuilder::new(); + if let Some(target_partition) = options.target_partition { + session_config = session_config.with_target_partitions(target_partition); + } if options.use_spilling() { - runtime_config.disk_manager = DiskManagerConfig::NewOs; - runtime_config.memory_pool = Some(Arc::new(FairSpillPool::new( - options.mem_pool_size() as usize - ))); + runtime_env_builder = runtime_env_builder + .with_disk_manager(DiskManagerConfig::new()) + .with_memory_pool(Arc::new(FairSpillPool::new( + options.mem_pool_size() as usize + ))); } - let runtime_env = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let runtime_env = runtime_env_builder.build_arc().unwrap(); SessionContext::new_with_config_rt(session_config, runtime_env) } lazy_static! { static ref DEFAULT_SESSION_CONTEXT: SessionContext = - new_session_context(LanceExecutionOptions::default()); + new_session_context(&LanceExecutionOptions::default()); static ref DEFAULT_SESSION_CONTEXT_WITH_SPILLING: SessionContext = { - new_session_context(LanceExecutionOptions { + new_session_context(&LanceExecutionOptions { use_spilling: true, ..Default::default() }) }; } -pub fn get_session_context(options: LanceExecutionOptions) -> SessionContext { - let session_ctx: SessionContext; - if options.mem_pool_size() == DEFAULT_LANCE_MEM_POOL_SIZE { - if options.use_spilling() { - session_ctx = DEFAULT_SESSION_CONTEXT_WITH_SPILLING.clone(); +pub fn get_session_context(options: &LanceExecutionOptions) -> SessionContext { + if options.mem_pool_size() == DEFAULT_LANCE_MEM_POOL_SIZE && options.target_partition.is_none() + { + return if options.use_spilling() { + DEFAULT_SESSION_CONTEXT_WITH_SPILLING.clone() } else { - session_ctx = DEFAULT_SESSION_CONTEXT.clone(); - } - } else { - session_ctx = new_session_context(options) + DEFAULT_SESSION_CONTEXT.clone() + }; + } + new_session_context(options) +} + +fn get_task_context( + session_ctx: &SessionContext, + options: &LanceExecutionOptions, +) -> Arc { + let mut state = session_ctx.state(); + if let Some(batch_size) = options.batch_size.as_ref() { + state.config_mut().options_mut().execution.batch_size = *batch_size; + } + + state.task_ctx() +} + +#[derive(Default)] +pub struct ExecutionSummaryCounts { + pub iops: usize, + pub requests: usize, + pub bytes_read: usize, + pub indices_loaded: usize, + pub parts_loaded: usize, + pub index_comparisons: usize, +} + +fn visit_node(node: &dyn ExecutionPlan, counts: &mut ExecutionSummaryCounts) { + if let Some(metrics) = node.metrics() { + counts.iops += metrics + .find_count(IOPS_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + counts.requests += metrics + .find_count(REQUESTS_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + counts.bytes_read += metrics + .find_count(BYTES_READ_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + counts.indices_loaded += metrics + .find_count(INDICES_LOADED_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + counts.parts_loaded += metrics + .find_count(PARTS_LOADED_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + counts.index_comparisons += metrics + .find_count(INDEX_COMPARISONS_METRIC) + .map(|c| c.value()) + .unwrap_or(0); + } + for child in node.children() { + visit_node(child.as_ref(), counts); + } +} + +fn report_plan_summary_metrics(plan: &dyn ExecutionPlan, options: &LanceExecutionOptions) { + let output_rows = plan + .metrics() + .map(|m| m.output_rows().unwrap_or(0)) + .unwrap_or(0); + let mut counts = ExecutionSummaryCounts::default(); + visit_node(plan, &mut counts); + tracing::info!( + target: TRACE_EXECUTION, + type = EXECUTION_PLAN_RUN, + output_rows, + iops = counts.iops, + requests = counts.requests, + bytes_read = counts.bytes_read, + indices_loaded = counts.indices_loaded, + parts_loaded = counts.parts_loaded, + index_comparisons = counts.index_comparisons, + ); + if let Some(callback) = options.execution_stats_callback.as_ref() { + callback(&counts); } - session_ctx } /// Executes a plan using default session & runtime configuration @@ -247,12 +363,43 @@ pub fn execute_plan( DisplayableExecutionPlan::new(plan.as_ref()).indent(true) ); - let session_ctx = get_session_context(options); + let session_ctx = get_session_context(&options); // NOTE: we are only executing the first partition here. Therefore, if // the plan has more than one partition, we will be missing data. assert_eq!(plan.properties().partitioning.partition_count(), 1); - Ok(plan.execute(0, session_ctx.task_ctx())?) + let stream = plan.execute(0, get_task_context(&session_ctx, &options))?; + + let schema = stream.schema(); + let stream = stream.finally(move || { + report_plan_summary_metrics(plan.as_ref(), &options); + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) +} + +pub async fn analyze_plan( + plan: Arc, + options: LanceExecutionOptions, +) -> Result { + let schema = plan.schema(); + let analyze = Arc::new(AnalyzeExec::new(true, true, plan, schema)); + + let session_ctx = get_session_context(&options); + assert_eq!(analyze.properties().partitioning.partition_count(), 1); + let mut stream = analyze + .execute(0, get_task_context(&session_ctx, &options)) + .map_err(|err| { + Error::io( + format!("Failed to execute analyze plan: {}", err), + location!(), + ) + })?; + + // fully execute the plan + while (stream.next().await).is_some() {} + + let display = DisplayableExecutionPlan::with_metrics(analyze.as_ref()); + Ok(format!("{}", display.indent(true))) } pub trait SessionContextExt { @@ -270,6 +417,16 @@ struct OneShotPartitionStream { schema: Arc, } +impl std::fmt::Debug for OneShotPartitionStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let data = self.data.lock().unwrap(); + f.debug_struct("OneShotPartitionStream") + .field("exhausted", &data.is_none()) + .field("schema", self.schema.as_ref()) + .finish() + } +} + impl OneShotPartitionStream { fn new(data: SendableRecordBatchStream) -> Self { let schema = data.schema(); diff --git a/rust/lance-datafusion/src/expr.rs b/rust/lance-datafusion/src/expr.rs index dbc450b654e..9e80a3d8a4f 100644 --- a/rust/lance-datafusion/src/expr.rs +++ b/rust/lance-datafusion/src/expr.rs @@ -393,6 +393,32 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option None, } } + ScalarValue::FixedSizeBinary(len, value) => match ty { + DataType::FixedSizeBinary(len2) => { + if len == len2 { + Some(ScalarValue::FixedSizeBinary(*len, value.clone())) + } else { + None + } + } + DataType::Binary => Some(ScalarValue::Binary(value.clone())), + _ => None, + }, + ScalarValue::Binary(value) => match ty { + DataType::Binary => Some(ScalarValue::Binary(value.clone())), + DataType::FixedSizeBinary(len) => { + if let Some(value) = value { + if value.len() == *len as usize { + Some(ScalarValue::FixedSizeBinary(*len, Some(value.clone()))) + } else { + None + } + } else { + None + } + } + _ => None, + }, _ => None, } } diff --git a/rust/lance-datafusion/src/lib.rs b/rust/lance-datafusion/src/lib.rs index a2a3c9ee342..a99afbbbe08 100644 --- a/rust/lance-datafusion/src/lib.rs +++ b/rust/lance-datafusion/src/lib.rs @@ -3,11 +3,13 @@ pub mod chunker; pub mod dataframe; +pub mod datagen; pub mod exec; pub mod expr; pub mod logical_expr; pub mod planner; pub mod projection; +pub mod spill; pub mod sql; #[cfg(feature = "substrait")] pub mod substrait; diff --git a/rust/lance-datafusion/src/logical_expr.rs b/rust/lance-datafusion/src/logical_expr.rs index ebfb73ea03f..c49526a1798 100644 --- a/rust/lance-datafusion/src/logical_expr.rs +++ b/rust/lance-datafusion/src/logical_expr.rs @@ -9,7 +9,7 @@ use arrow_schema::DataType; use crate::expr::safe_coerce_scalar; use datafusion::logical_expr::{expr::ScalarFunction, BinaryExpr, Operator}; -use datafusion::logical_expr::{ScalarUDF, ScalarUDFImpl}; +use datafusion::logical_expr::{Between, ScalarUDF, ScalarUDFImpl}; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_functions::core::getfield::GetFieldFunc; @@ -17,7 +17,7 @@ use lance_arrow::DataTypeExt; use lance_core::datatypes::Schema; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; /// Resolve a Value fn resolve_value(expr: &Expr, data_type: &DataType) -> Result { match expr { @@ -91,6 +91,23 @@ pub fn resolve_column_type(expr: &Expr, schema: &Schema) -> Option { /// - *schema*: lance schema. pub fn resolve_expr(expr: &Expr, schema: &Schema) -> Result { match expr { + Expr::Between(Between { + expr: inner_expr, + low, + high, + negated, + }) => { + if let Some(inner_expr_type) = resolve_column_type(inner_expr.as_ref(), schema) { + Ok(Expr::Between(Between { + expr: inner_expr.clone(), + low: Box::new(coerce_expr(low.as_ref(), &inner_expr_type)?), + high: Box::new(coerce_expr(high.as_ref(), &inner_expr_type)?), + negated: *negated, + })) + } else { + Ok(expr.clone()) + } + } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { if matches!(op, Operator::And | Operator::Or) { Ok(Expr::BinaryExpr(BinaryExpr { diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index a8d985d82a8..13194db916f 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors //! Exec plan planner @@ -23,7 +21,7 @@ use datafusion::config::ConfigOptions; use datafusion::error::Result as DFResult; use datafusion::execution::config::SessionConfig; use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::expr::ScalarFunction; use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawFieldAccessExpr}; @@ -34,13 +32,13 @@ use datafusion::logical_expr::{ use datafusion::optimizer::simplify_expressions::SimplifyContext; use datafusion::sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; use datafusion::sql::sqlparser::ast::{ - Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, Expr as SQLExpr, - Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript, TimezoneInfo, - UnaryOperator, Value, + AccessExpr, Array as SQLArray, BinaryOperator, DataType as SQLDataType, ExactNumberInfo, + Expr as SQLExpr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Subscript, + TimezoneInfo, UnaryOperator, Value, }; use datafusion::{ common::Column, - logical_expr::{col, BinaryExpr, Like, Operator}, + logical_expr::{col, Between, BinaryExpr, Like, Operator}, physical_expr::execution_props::ExecutionProps, physical_plan::PhysicalExpr, prelude::Expr, @@ -49,7 +47,7 @@ use datafusion::{ use datafusion_functions::core::getfield::GetFieldFunc; use lance_arrow::cast::cast_with_options; use lance_core::datatypes::Schema; -use snafu::{location, Location}; +use snafu::location; use lance_core::{Error, Result}; @@ -162,8 +160,7 @@ struct LanceContextProvider { impl Default for LanceContextProvider { fn default() -> Self { let config = SessionConfig::new(); - let runtime_config = RuntimeConfig::new(); - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let runtime = RuntimeEnvBuilder::new().build_arc().unwrap(); let mut state_builder = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) @@ -415,6 +412,7 @@ impl Planner { enable_ident_normalization: false, support_varchar_with_length: false, enable_options_value_normalization: false, + collect_spans: false, }, ); @@ -443,7 +441,7 @@ impl Planner { SQLDataType::String(_) => Ok(ArrowDataType::Utf8), SQLDataType::Binary(_) => Ok(ArrowDataType::Binary), SQLDataType::Float(_) => Ok(ArrowDataType::Float32), - SQLDataType::Double => Ok(ArrowDataType::Float64), + SQLDataType::Double(_) => Ok(ArrowDataType::Float64), SQLDataType::Boolean => Ok(ArrowDataType::Boolean), SQLDataType::TinyInt(_) => Ok(ArrowDataType::Int8), SQLDataType::SmallInt(_) => Ok(ArrowDataType::Int16), @@ -636,7 +634,7 @@ impl Planner { })) } SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.parse_sql_expr(expr)?))), - SQLExpr::IsNotFalse(_) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))), + SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.parse_sql_expr(expr)?))), SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.parse_sql_expr(expr)?))), SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.parse_sql_expr(expr)?))), SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(self.parse_sql_expr(expr)?))), @@ -660,6 +658,7 @@ impl Planner { expr, pattern, escape_char, + any: _, } => Ok(Expr::Like(Like::new( *negated, Box::new(self.parse_sql_expr(expr)?), @@ -672,6 +671,7 @@ impl Planner { expr, pattern, escape_char, + any: _, } => Ok(Expr::Like(Like::new( *negated, Box::new(self.parse_sql_expr(expr)?), @@ -685,66 +685,69 @@ impl Planner { expr: Box::new(self.parse_sql_expr(expr)?), data_type: self.parse_type(data_type)?, })), - SQLExpr::MapAccess { column, keys } => { - let mut expr = self.parse_sql_expr(column)?; - - for key in keys { - let field_access = match &key.key { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => GetFieldAccess::NamedStructField { + SQLExpr::JsonAccess { .. } => Err(Error::invalid_input( + "JSON access is not supported", + location!(), + )), + SQLExpr::CompoundFieldAccess { root, access_chain } => { + let mut expr = self.parse_sql_expr(root)?; + + for access in access_chain { + let field_access = match access { + // x.y or x['y'] + AccessExpr::Dot(SQLExpr::Identifier(Ident { value: s, .. })) + | AccessExpr::Subscript(Subscript::Index { + index: + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ), + }) => GetFieldAccess::NamedStructField { name: ScalarValue::from(s.as_str()), }, - SQLExpr::JsonAccess { .. } => { + AccessExpr::Subscript(Subscript::Index { index }) => { + let key = Box::new(self.parse_sql_expr(index)?); + GetFieldAccess::ListIndex { key } + } + AccessExpr::Subscript(Subscript::Slice { .. }) => { return Err(Error::invalid_input( - "JSON access is not supported", + "Slice subscript is not supported", location!(), )); } - key => { - let key = Box::new(self.parse_sql_expr(key)?); - GetFieldAccess::ListIndex { key } + _ => { + // Handle other cases like JSON access + // Note: JSON access is not supported in lance + return Err(Error::invalid_input( + "Only dot notation or index access is supported for field access", + location!(), + )); } }; let field_access_expr = RawFieldAccessExpr { expr, field_access }; - expr = self.plan_field_access(field_access_expr)?; } Ok(expr) } - SQLExpr::Subscript { expr, subscript } => { + SQLExpr::Between { + expr, + negated, + low, + high, + } => { + // Parse the main expression and bounds let expr = self.parse_sql_expr(expr)?; - - let field_access = match subscript.as_ref() { - Subscript::Index { index } => match index { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => GetFieldAccess::NamedStructField { - name: ScalarValue::from(s.as_str()), - }, - SQLExpr::JsonAccess { .. } => { - return Err(Error::invalid_input( - "JSON access is not supported", - location!(), - )); - } - _ => { - let key = Box::new(self.parse_sql_expr(index)?); - GetFieldAccess::ListIndex { key } - } - }, - Subscript::Slice { .. } => { - return Err(Error::invalid_input( - "Slice subscript is not supported", - location!(), - )); - } - }; - - let field_access_expr = RawFieldAccessExpr { expr, field_access }; - self.plan_field_access(field_access_expr) + let low = self.parse_sql_expr(low)?; + let high = self.parse_sql_expr(high)?; + + let between = Expr::Between(Between::new( + Box::new(expr), + *negated, + Box::new(low), + Box::new(high), + )); + Ok(between) } _ => Err(Error::invalid_input( format!("Expression '{expr}' is not supported SQL in lance"), @@ -785,7 +788,7 @@ impl Planner { /// TODO: use SqlToRel from Datafusion directly? fn try_decode_hex_literal(s: &str) -> Option> { let hex_bytes = s.as_bytes(); - let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2); + let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2)); let start_idx = hex_bytes.len() % 2; if start_idx > 0 { @@ -796,7 +799,7 @@ impl Planner { for i in (start_idx..hex_bytes.len()).step_by(2) { let high = Self::try_decode_hex_char(hex_bytes[i])?; let low = Self::try_decode_hex_char(hex_bytes[i + 1])?; - decoded_bytes.push(high << 4 | low); + decoded_bytes.push((high << 4) | low); } Some(decoded_bytes) @@ -1025,20 +1028,14 @@ mod tests { } } - let expected = Expr::Column(Column { - relation: None, - name: "s0".to_string(), - }); + let expected = Expr::Column(Column::new_unqualified("s0")); assert_column_eq(&planner, "s0", &expected); assert_column_eq(&planner, "`s0`", &expected); let expected = Expr::ScalarFunction(ScalarFunction { func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ - Expr::Column(Column { - relation: None, - name: "st".to_string(), - }), + Expr::Column(Column::new_unqualified("st")), Expr::Literal(ScalarValue::Utf8(Some("s1".to_string()))), ], }); @@ -1052,10 +1049,7 @@ mod tests { Expr::ScalarFunction(ScalarFunction { func: Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())), args: vec![ - Expr::Column(Column { - relation: None, - name: "st".to_string(), - }), + Expr::Column(Column::new_unqualified("st")), Expr::Literal(ScalarValue::Utf8(Some("st".to_string()))), ], }), @@ -1463,6 +1457,98 @@ mod tests { } } + #[test] + fn test_sql_between() { + use arrow_array::{Float64Array, Int32Array, TimestampMicrosecondArray}; + use arrow_schema::{DataType, Field, Schema, TimeUnit}; + use std::sync::Arc; + + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Float64, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + ])); + + let planner = Planner::new(schema.clone()); + + // Test integer BETWEEN + let expr = planner + .parse_filter("x BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + // Create timestamp array with values representing: + // 2024-01-01 00:00:00 to 2024-01-01 00:00:09 (in microseconds) + let base_ts = 1704067200000000_i64; // 2024-01-01 00:00:00 + let ts_array = TimestampMicrosecondArray::from_iter_values( + (0..10).map(|i| base_ts + i * 1_000_000), // Each value is 1 second apart + ); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from_iter_values(0..10)) as ArrayRef, + Arc::new(Float64Array::from_iter_values((0..10).map(|v| v as f64))), + Arc::new(ts_array), + ], + ) + .unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + + // Test NOT BETWEEN + let expr = planner + .parse_filter("x NOT BETWEEN CAST(3 AS INT) AND CAST(7 AS INT)") + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + true, true, true, false, false, false, false, false, true, true + ]) + ); + + // Test floating point BETWEEN + let expr = planner.parse_filter("y BETWEEN 2.5 AND 6.5").unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, false, false, false + ]) + ); + + // Test timestamp BETWEEN + let expr = planner + .parse_filter( + "ts BETWEEN timestamp '2024-01-01 00:00:03' AND timestamp '2024-01-01 00:00:07'", + ) + .unwrap(); + let physical_expr = planner.create_physical_expr(&expr).unwrap(); + + let predicates = physical_expr.evaluate(&batch).unwrap(); + assert_eq!( + predicates.into_array(0).unwrap().as_ref(), + &BooleanArray::from(vec![ + false, false, false, true, true, true, true, true, false, false + ]) + ); + } + #[test] fn test_sql_comparison() { // Create a batch with all data types @@ -1497,9 +1583,7 @@ mod tests { ]; let expected: ArrayRef = Arc::new(BooleanArray::from_iter( - std::iter::repeat(Some(false)) - .take(5) - .chain(std::iter::repeat(Some(true)).take(5)), + std::iter::repeat_n(Some(false), 5).chain(std::iter::repeat_n(Some(true), 5)), )); for expression in expressions { // convert to physical expression diff --git a/rust/lance-datafusion/src/projection.rs b/rust/lance-datafusion/src/projection.rs index d3abca23b75..2d0266b3133 100644 --- a/rust/lance-datafusion/src/projection.rs +++ b/rust/lance-datafusion/src/projection.rs @@ -11,7 +11,7 @@ use datafusion::{ use datafusion_common::DFSchema; use datafusion_physical_expr::{expressions, PhysicalExpr}; use futures::TryStreamExt; -use snafu::{location, Location}; +use snafu::location; use std::{ collections::{HashMap, HashSet}, sync::Arc, diff --git a/rust/lance-datafusion/src/spill.rs b/rust/lance-datafusion/src/spill.rs new file mode 100644 index 00000000000..cb60669a5af --- /dev/null +++ b/rust/lance-datafusion/src/spill.rs @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + io::{BufReader, BufWriter}, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, Schema}; +use datafusion::{ + execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter, +}; +use datafusion_common::DataFusionError; +use lance_arrow::memory::MemoryAccumulator; +use lance_core::error::LanceOptionExt; + +/// Start a spill of Arrow data to a file that can be read later multiple times. +/// +/// Up to `memory_limit` bytes of data can be buffered in memory before a spill +/// is created. If the memory limit is never reached before [`SpillSender::finish()`] +/// is called, then the data will simply be kept in memory and no spill will be +/// created. +/// +/// `path` is the path to the file that may be created. It should not already +/// exist. It is the responsibility of the caller to delete the file after it is +/// no longer needed. +/// +/// The [`SpillSender`] allows you to write batches to the spill. +/// +/// The [`SpillReceiver`] can open a [`SendableRecordBatchStream`] that reads +/// batches from the spill. This can be opened before, during, or after batches +/// have been written to the spill. +/// +/// Once [`SpillSender`] is dropped, the temporary file is deleted. This will +/// cause the [`SpillReceiver`] to return an error if it is still open. +pub fn create_replay_spill( + path: std::path::PathBuf, + schema: Arc, + memory_limit: usize, +) -> (SpillSender, SpillReceiver) { + let initial_status = WriteStatus::default(); + let (status_sender, status_receiver) = tokio::sync::watch::channel(initial_status); + let sender = SpillSender { + memory_limit, + path: path.clone(), + schema: schema.clone(), + state: SpillState::default(), + status_sender, + }; + + let receiver = SpillReceiver { + status_receiver, + path, + schema, + }; + + (sender, receiver) +} + +#[derive(Clone)] +pub struct SpillReceiver { + status_receiver: tokio::sync::watch::Receiver, + path: PathBuf, + schema: Arc, +} + +impl SpillReceiver { + /// Returns a stream of batches from the spill. The stream will emit + /// batches as they are written to the spill. If the spill has already + /// been finished, the stream will emit all batches in the spill. + /// + /// The stream will not complete until [`Self::finish()`] is called. + /// + /// If the spill has been dropped, an error will be returned. + pub fn read(&self) -> SendableRecordBatchStream { + let rx = self.status_receiver.clone(); + let reader = SpillReader::new(rx, self.path.clone()); + + let stream = futures::stream::try_unfold(reader, move |mut reader| async move { + match reader.read().await { + Ok(None) => Ok(None), + Ok(Some(batch)) => Ok(Some((batch, reader))), + Err(err) => Err(err), + } + }); + + Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)) + } +} + +struct SpillReader { + pub batches_read: usize, + receiver: tokio::sync::watch::Receiver, + state: SpillReaderState, +} + +enum SpillReaderState { + Buffered { spill_path: PathBuf }, + Reader { reader: AsyncStreamReader }, +} + +impl SpillReader { + fn new(receiver: tokio::sync::watch::Receiver, spill_path: PathBuf) -> Self { + Self { + batches_read: 0, + receiver, + state: SpillReaderState::Buffered { spill_path }, + } + } + + async fn wait_for_more_data(&mut self) -> Result>, DataFusionError> { + let status = self + .receiver + .wait_for(|status| { + status.error.is_some() + || status.finished + || status.batches_written() > self.batches_read + }) + .await + .map_err(|_| { + DataFusionError::Execution( + "Spill has been dropped before reader has finish.".into(), + ) + })?; + + if let Some(error) = &status.error { + let mut guard = error.lock().ok().expect_ok()?; + return Err(DataFusionError::from(&mut (*guard))); + } + + if let DataLocation::Buffered { batches } = &status.data_location { + Ok(Some(batches.clone())) + } else { + Ok(None) + } + } + + async fn get_reader(&mut self) -> Result<&AsyncStreamReader, ArrowError> { + if let SpillReaderState::Buffered { spill_path } = &self.state { + let reader = AsyncStreamReader::open(spill_path.clone()).await?; + // Skip batches we've already read before the writer started spilling. + // The read batches were spilled to the file for the benefit of + // future readers, as the spill is replay-able. + for _ in 0..self.batches_read { + reader.read().await?; + } + self.state = SpillReaderState::Reader { reader }; + } + + if let SpillReaderState::Reader { reader } = &mut self.state { + Ok(reader) + } else { + unreachable!() + } + } + + async fn read(&mut self) -> Result, DataFusionError> { + let maybe_data = self.wait_for_more_data().await?; + + if let Some(batches) = maybe_data { + if self.batches_read < batches.len() { + let batch = batches[self.batches_read].clone(); + self.batches_read += 1; + Ok(Some(batch)) + } else { + Ok(None) + } + } else { + let reader = self.get_reader().await?; + let batch = reader.read().await?; + if batch.is_some() { + self.batches_read += 1; + } + Ok(batch) + } + } +} + +/// The sender side of the spill. This is used to write batches to the spill. +/// +/// Note: this must be kept alive until after the readers are done reading the +/// spill. Otherwise, they will return an error. +pub struct SpillSender { + memory_limit: usize, + schema: Arc, + path: PathBuf, + state: SpillState, + status_sender: tokio::sync::watch::Sender, +} + +enum SpillState { + Buffering { + batches: Vec, + memory_accumulator: MemoryAccumulator, + }, + Spilling { + writer: AsyncStreamWriter, + batches_written: usize, + }, + Finished { + batches: Option>, + batches_written: usize, + }, + Errored { + error: Arc>, + }, +} + +impl Default for SpillState { + fn default() -> Self { + Self::Buffering { + batches: Vec::new(), + memory_accumulator: MemoryAccumulator::default(), + } + } +} + +#[derive(Clone, Debug, Default)] +struct WriteStatus { + error: Option>>, + finished: bool, + data_location: DataLocation, +} + +impl WriteStatus { + fn batches_written(&self) -> usize { + match &self.data_location { + DataLocation::Buffered { batches } => batches.len(), + DataLocation::Spilled { + batches_written, .. + } => *batches_written, + } + } +} + +#[derive(Clone, Debug)] +enum DataLocation { + Buffered { batches: Arc<[RecordBatch]> }, + Spilled { batches_written: usize }, +} + +impl Default for DataLocation { + fn default() -> Self { + Self::Buffered { + batches: Arc::new([]), + } + } +} + +/// A DataFusion error that be be emitted multiple times. We provide the +/// Original error first, and subsequent conversions provide a copy with a +/// string representation of the original error. +#[derive(Debug)] +enum SpillError { + Original(DataFusionError), + Copy(DataFusionError), +} + +impl From for SpillError { + fn from(err: DataFusionError) -> Self { + Self::Original(err) + } +} + +impl From<&mut SpillError> for DataFusionError { + fn from(err: &mut SpillError) -> Self { + match err { + SpillError::Original(inner) => { + let copy = Self::Execution(inner.to_string()); + let original = std::mem::replace(err, SpillError::Copy(copy)); + if let SpillError::Original(inner) = original { + inner + } else { + unreachable!() + } + } + SpillError::Copy(Self::Execution(message)) => Self::Execution(message.clone()), + _ => unreachable!(), + } + } +} + +impl From<&SpillState> for WriteStatus { + fn from(state: &SpillState) -> Self { + match state { + SpillState::Buffering { batches, .. } => Self { + finished: false, + data_location: DataLocation::Buffered { + batches: batches.clone().into(), + }, + error: None, + }, + SpillState::Spilling { + batches_written, .. + } => Self { + finished: false, + data_location: DataLocation::Spilled { + batches_written: *batches_written, + }, + error: None, + }, + SpillState::Finished { + batches_written, + batches, + } => { + let data_location = if let Some(batches) = batches { + DataLocation::Buffered { + batches: batches.clone(), + } + } else { + DataLocation::Spilled { + batches_written: *batches_written, + } + }; + Self { + finished: true, + data_location, + error: None, + } + } + SpillState::Errored { error } => Self { + finished: true, + data_location: DataLocation::default(), // Doesn't matter. + error: Some(error.clone()), + }, + } + } +} + +impl SpillSender { + /// Write a batch to the spill. + /// + /// If there is room in the `memory_limit` then the batch is queued. + /// If `memory_limit` is first encountered then all queued batches, and this one, + /// will be written to disk as part of this call. + /// If we are already spilling then the batch will be written to disk as part of this + /// call. + pub async fn write(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> { + if let SpillState::Finished { .. } = self.state { + return Err(DataFusionError::Execution( + "Spill has already been finished".to_string(), + )); + } + + if let SpillState::Errored { .. } = &self.state { + return Err(DataFusionError::Execution( + "Spill has sent an error".to_string(), + )); + } + + let (writer, batches_written) = match &mut self.state { + SpillState::Buffering { + batches, + ref mut memory_accumulator, + } => { + memory_accumulator.record_batch(&batch); + + if memory_accumulator.total() > self.memory_limit { + let writer = + AsyncStreamWriter::open(self.path.clone(), self.schema.clone()).await?; + let batches_written = batches.len(); + for batch in batches.drain(..) { + writer.write(batch).await?; + } + self.state = SpillState::Spilling { + writer, + batches_written, + }; + if let SpillState::Spilling { + writer, + batches_written, + } = &mut self.state + { + (writer, batches_written) + } else { + unreachable!() + } + } else { + batches.push(batch); + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + return Ok(()); + } + } + SpillState::Spilling { + writer, + batches_written, + } => (writer, batches_written), + _ => unreachable!(), + }; + + writer.write(batch).await?; + *batches_written += 1; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + + Ok(()) + } + + /// Send an error to the spill. This will be sent to all readers of the + /// spill. + pub fn send_error(&mut self, err: DataFusionError) { + let error = Arc::new(Mutex::new(err.into())); + self.state = SpillState::Errored { error }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + + /// Complete the spill write. This will finalize the Arrow IPC stream file. + /// The file will remain available for reading until [`Self::shutdown()`] + /// or until the spill is dropped. + pub async fn finish(&mut self) -> Result<(), DataFusionError> { + // We create a temporary state to get an owned copy of current state. + // Since we hold an exclusive reference to `self`, no one should be + // able to see this temporary state. + let tmp_state = SpillState::Finished { + batches_written: 0, + batches: None, + }; + match std::mem::replace(&mut self.state, tmp_state) { + SpillState::Buffering { batches, .. } => { + let batches_written = batches.len(); + self.state = SpillState::Finished { + batches_written, + batches: Some(batches.into()), + }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + SpillState::Spilling { + writer, + batches_written, + } => { + writer.finish().await?; + self.state = SpillState::Finished { + batches_written, + batches: None, + }; + self.status_sender + .send_replace(WriteStatus::from(&self.state)); + } + SpillState::Finished { .. } => { + return Err(DataFusionError::Execution( + "Spill has already been finished".to_string(), + )); + } + SpillState::Errored { .. } => { + return Err(DataFusionError::Execution( + "Spill has sent an error".to_string(), + )); + } + }; + + Ok(()) + } +} + +/// An async wrapper around [`StreamWriter`]. Each call uses [`tokio::task::spawn_blocking`] +/// to spawn a blocking task to write the batch. +struct AsyncStreamWriter { + writer: Arc>>>, +} + +impl AsyncStreamWriter { + pub async fn open(path: PathBuf, schema: Arc) -> Result { + let writer = tokio::task::spawn_blocking(move || { + let file = std::fs::File::create(&path).map_err(ArrowError::from)?; + let writer = BufWriter::new(file); + StreamWriter::try_new(writer, &schema) + }) + .await + .unwrap()?; + let writer = Arc::new(Mutex::new(writer)); + Ok(Self { writer }) + } + + pub async fn write(&self, batch: RecordBatch) -> Result<(), ArrowError> { + let writer = self.writer.clone(); + tokio::task::spawn_blocking(move || { + let mut writer = writer.lock().unwrap(); + writer.write(&batch)?; + writer.flush() + }) + .await + .unwrap() + } + + pub async fn finish(self) -> Result<(), ArrowError> { + let writer = self.writer.clone(); + tokio::task::spawn_blocking(move || { + let mut writer = writer.lock().unwrap(); + writer.finish() + }) + .await + .unwrap() + } +} + +struct AsyncStreamReader { + reader: Arc>>>, +} + +impl AsyncStreamReader { + pub async fn open(path: PathBuf) -> Result { + let reader = tokio::task::spawn_blocking(move || { + let file = std::fs::File::open(&path).map_err(ArrowError::from)?; + let reader = BufReader::new(file); + StreamReader::try_new(reader, None) + }) + .await + .unwrap()?; + let reader = Arc::new(Mutex::new(reader)); + Ok(Self { reader }) + } + + pub async fn read(&self) -> Result, ArrowError> { + let reader = self.reader.clone(); + tokio::task::spawn_blocking(move || { + let mut reader = reader.lock().unwrap(); + reader.next() + }) + .await + .unwrap() + .transpose() + } +} + +#[cfg(test)] +mod tests { + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field}; + use futures::{poll, StreamExt, TryStreamExt}; + + use super::*; + + #[tokio::test] + async fn test_spill() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batches = [ + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![4, 5, 6]))], + ) + .unwrap(), + ]; + + // Create a stream + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), 0); + + // We can open a reader prior to writing any data. No batches will be ready. + let mut stream_before = receiver.read(); + let mut stream_before_next = stream_before.next(); + let poll_res = poll!(&mut stream_before_next); + assert!(poll_res.is_pending()); + + // If we write a batch, the existing reader can now receive it. + spill.write(batches[0].clone()).await.unwrap(); + let stream_before_batch1 = stream_before_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_before_batch1, &batches[0]); + let mut stream_before_next = stream_before.next(); + let poll_res = poll!(&mut stream_before_next); + assert!(poll_res.is_pending()); + + // We can also open a ready while the spill is being written to. We can + // retrieve batches written so far immediately. + let mut stream_during = receiver.read(); + let stream_during_batch1 = stream_during + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_during_batch1, &batches[0]); + let mut stream_during_next = stream_during.next(); + let poll_res = poll!(&mut stream_during_next); + assert!(poll_res.is_pending()); + + // Once we finish the spill, readers can get remaining batches and will + // reach the end of the stream. + spill.write(batches[1].clone()).await.unwrap(); + spill.finish().await.unwrap(); + + let stream_before_batch2 = stream_before_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_before_batch2, &batches[1]); + assert!(stream_before.next().await.is_none()); + + let stream_during_batch2 = stream_during_next + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_during_batch2, &batches[1]); + assert!(stream_during.next().await.is_none()); + + // Can also start a reader after finishing. + let stream_after = receiver.read(); + let stream_after_batches = stream_after.try_collect::>().await.unwrap(); + assert_eq!(&stream_after_batches, &batches); + + std::fs::remove_file(path).unwrap(); + } + + #[tokio::test] + async fn test_spill_error() { + // Create a spill + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), 0); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + spill.write(batch.clone()).await.unwrap(); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + + spill.send_error(DataFusionError::ResourcesExhausted("🥱".into())); + let stream_error = stream + .next() + .await + .expect("Expected an error") + .expect_err("Expected an error"); + assert!(matches!( + stream_error, + DataFusionError::ResourcesExhausted(message) if message == "🥱" + )); + + // If we try to write after sending an error, it should return an error. + let err = spill.write(batch).await; + assert!(matches!( + err, + Err(DataFusionError::Execution(message)) if message == "Spill has sent an error" + )); + + // If we try to finish after sending an error, it should return an error. + let err = spill.finish().await; + assert!(matches!( + err, + Err(DataFusionError::Execution(message)) if message == "Spill has sent an error" + )); + + // If we try to read after sending an error, it should return an error. + let mut stream = receiver.read(); + let stream_error = stream + .next() + .await + .expect("Expected an error") + .expect_err("Expected an error"); + assert!(matches!( + stream_error, + DataFusionError::Execution(message) if message.contains("🥱") + )); + + std::fs::remove_file(path).unwrap(); + } + + #[tokio::test] + async fn test_spill_buffered() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let memory_limit = 1024 * 1024; // 1 MiB + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit); + + // 0.5 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + spill.finish().await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + + assert!(!std::fs::exists(&path).unwrap()); + } + + #[tokio::test] + async fn test_spill_buffered_transition() { + // Starts as buffered, then spills, then finished. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().join("spill.arrows"); + let memory_limit = 1024 * 1024; // 1 MiB + let (mut spill, receiver) = create_replay_spill(path.clone(), schema.clone(), memory_limit); + + // 0.7 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (768 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(!std::fs::exists(&path).unwrap()); + + let mut stream = receiver.read(); + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + assert!(!std::fs::exists(&path).unwrap()); + + // 0.5 MB batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1; (512 * 1024) / 4]))], + ) + .unwrap(); + spill.write(batch.clone()).await.unwrap(); + assert!(std::fs::exists(&path).unwrap()); + + let stream_batch = stream + .next() + .await + .expect("Expected a batch") + .expect("Expected no error"); + assert_eq!(&stream_batch, &batch); + assert!(std::fs::exists(&path).unwrap()); + + spill.finish().await.unwrap(); + + assert!(stream.next().await.is_none()); + + std::fs::remove_file(path).unwrap(); + } +} diff --git a/rust/lance-datafusion/src/sql.rs b/rust/lance-datafusion/src/sql.rs index 88b4415eda7..0ba166a3011 100644 --- a/rust/lance-datafusion/src/sql.rs +++ b/rust/lance-datafusion/src/sql.rs @@ -3,6 +3,8 @@ //! SQL Parser utility +use std::any::TypeId; + use datafusion::sql::sqlparser::{ ast::{Expr, SelectItem, SetExpr, Statement}, dialect::{Dialect, GenericDialect}, @@ -11,7 +13,7 @@ use datafusion::sql::sqlparser::{ }; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; #[derive(Debug, Default)] struct LanceDialect(GenericDialect); @@ -22,6 +24,10 @@ impl LanceDialect { } impl Dialect for LanceDialect { + fn dialect(&self) -> TypeId { + self.0.dialect() + } + fn is_identifier_start(&self, ch: char) -> bool { self.0.is_identifier_start(ch) } @@ -129,7 +135,8 @@ mod tests { negated: false, expr: Box::new(Expr::Identifier(Ident::new("a"))), pattern: Box::new(Expr::Value(Value::SingleQuotedString("abc%".to_string()))), - escape_char: None + escape_char: None, + any: false, }, expr ); diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs index 57cffb1261d..84ceb595cf6 100644 --- a/rust/lance-datafusion/src/substrait.rs +++ b/rust/lance-datafusion/src/substrait.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use arrow_schema::Schema; +use arrow_schema::Schema as ArrowSchema; use datafusion::{ datasource::empty::EmptyTable, execution::context::SessionContext, logical_expr::Expr, }; @@ -20,16 +20,16 @@ use datafusion_substrait::substrait::proto::{ r#type::{Kind, Struct}, read_rel::{NamedTable, ReadType}, rel, Expression, ExtendedExpression, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, - RelRoot, + RelRoot, Type, }; use lance_core::{Error, Result}; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use std::collections::HashMap; use std::sync::Arc; /// Convert a DF Expr into a Substrait ExtendedExpressions message -pub fn encode_substrait(expr: Expr, schema: Arc) -> Result> { +pub fn encode_substrait(expr: Expr, schema: Arc) -> Result> { use datafusion::logical_expr::{builder::LogicalTableSource, logical_plan, LogicalPlan}; use datafusion_substrait::substrait::proto::{plan_rel, ExpressionReference, NamedStruct}; @@ -50,8 +50,10 @@ pub fn encode_substrait(expr: Expr, schema: Arc) -> Result> { let session_context = SessionContext::new(); - let substrait_plan = - datafusion_substrait::logical_plan::producer::to_substrait_plan(&plan, &session_context)?; + let substrait_plan = datafusion_substrait::logical_plan::producer::to_substrait_plan( + &plan, + &session_context.state(), + )?; if let Some(plan_rel::RelType::Root(root)) = &substrait_plan.relations[0].rel_type { if let Some(rel::RelType::Filter(filt)) = &root.input.as_ref().unwrap().rel_type { @@ -81,10 +83,17 @@ pub fn encode_substrait(expr: Expr, schema: Arc) -> Result> { } } +fn count_fields(dtype: &Type) -> usize { + match dtype.kind.as_ref().unwrap() { + Kind::Struct(struct_type) => struct_type.types.iter().map(count_fields).sum::() + 1, + _ => 1, + } +} + fn remove_extension_types( substrait_schema: &NamedStruct, - arrow_schema: Arc, -) -> Result<(NamedStruct, Arc, HashMap)> { + arrow_schema: Arc, +) -> Result<(NamedStruct, Arc, HashMap)> { let fields = substrait_schema.r#struct.as_ref().unwrap(); if fields.types.len() != arrow_schema.fields.len() { return Err(Error::InvalidInput { @@ -96,25 +105,35 @@ fn remove_extension_types( let mut kept_arrow_fields = Vec::with_capacity(arrow_schema.fields.len()); let mut index_mapping = HashMap::with_capacity(arrow_schema.fields.len()); let mut field_counter = 0; - for (field_index, (substrait_field, arrow_field)) in fields - .types - .iter() - .zip(arrow_schema.fields.iter()) - .enumerate() - { - if !matches!( - substrait_field.kind.as_ref().unwrap(), - Kind::UserDefined(_) | Kind::UserDefinedTypeReference(_) - ) { + let mut field_index = 0; + // TODO: this logic doesn't catch user defined fields inside of struct fields + for (substrait_field, arrow_field) in fields.types.iter().zip(arrow_schema.fields.iter()) { + let num_fields = count_fields(substrait_field); + + if !substrait_schema.names[field_index].starts_with("__unlikely_name_placeholder") + && !matches!( + substrait_field.kind.as_ref().unwrap(), + Kind::UserDefined(_) | Kind::UserDefinedTypeReference(_) + ) + { kept_substrait_fields.push(substrait_field.clone()); kept_arrow_fields.push(arrow_field.clone()); - index_mapping.insert(field_index, field_counter); - field_counter += 1; + for i in 0..num_fields { + index_mapping.insert(field_index + i, field_counter + i); + } + field_counter += num_fields; } + field_index += num_fields; } - let new_arrow_schema = Arc::new(Schema::new(kept_arrow_fields)); + let mut names = vec![String::new(); index_mapping.len()]; + for (old_idx, old_name) in substrait_schema.names.iter().enumerate() { + if let Some(new_idx) = index_mapping.get(&old_idx) { + names[*new_idx] = old_name.clone(); + } + } + let new_arrow_schema = Arc::new(ArrowSchema::new(kept_arrow_fields)); let new_substrait_schema = NamedStruct { - names: vec![], + names, r#struct: Some(Struct { nullability: fields.nullability, type_variation_reference: fields.type_variation_reference, @@ -241,7 +260,7 @@ fn remap_expr_references(expr: &mut Expression, mapping: &HashMap) /// Convert a Substrait ExtendedExpressions message into a DF Expr /// /// The ExtendedExpressions message must contain a single scalar expression -pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result { +pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result { let envelope = ExtendedExpression::decode(expr)?; if envelope.referred_expr.is_empty() { return Err(Error::InvalidInput { @@ -342,9 +361,11 @@ pub async fn parse_substrait(expr: &[u8], input_schema: Arc) -> Result) -> Result(iter: I) -> impl Stream -where - I::Item: Send, -{ - stream::unfold(iter, |mut iter| { - spawn_blocking(|| iter.next().map(|val| (val, iter))) - .unwrap_or_else(|err| panic!("{}", err)) - }) - .fuse() -} +pub mod background_iterator; /// A trait for [BatchRecord] iterators, readers and streams /// that can be converted to a concrete stream type [SendableRecordBatchStream]. @@ -145,7 +143,53 @@ pub fn reader_to_stream(batches: Box) -> SendableR let arrow_schema = batches.arrow_schema(); let stream = RecordBatchStreamAdapter::new( arrow_schema, - background_iterator(batches).map_err(DataFusionError::from), + BackgroundIterator::new(batches) + .fuse() + .map_err(DataFusionError::from), ); Box::pin(stream) } + +pub trait MetricsExt { + fn find_count(&self, name: &str) -> Option; +} + +impl MetricsExt for MetricsSet { + fn find_count(&self, metric_name: &str) -> Option { + self.iter().find_map(|m| match m.value() { + MetricValue::Count { name, count } => { + if name == metric_name { + Some(count.clone()) + } else { + None + } + } + _ => None, + }) + } +} + +pub trait ExecutionPlanMetricsSetExt { + fn new_count(&self, name: &'static str, partition: usize) -> Count; +} + +impl ExecutionPlanMetricsSetExt for ExecutionPlanMetricsSet { + fn new_count(&self, name: &'static str, partition: usize) -> Count { + let count = Count::new(); + MetricBuilder::new(self) + .with_partition(partition) + .build(MetricValue::Count { + name: Cow::Borrowed(name), + count: count.clone(), + }); + count + } +} + +// Common metrics +pub const IOPS_METRIC: &str = "iops"; +pub const REQUESTS_METRIC: &str = "requests"; +pub const BYTES_READ_METRIC: &str = "bytes_read"; +pub const INDICES_LOADED_METRIC: &str = "indices_loaded"; +pub const PARTS_LOADED_METRIC: &str = "parts_loaded"; +pub const INDEX_COMPARISONS_METRIC: &str = "index_comparisons"; diff --git a/rust/lance-datafusion/src/utils/background_iterator.rs b/rust/lance-datafusion/src/utils/background_iterator.rs new file mode 100644 index 00000000000..d9f0458718e --- /dev/null +++ b/rust/lance-datafusion/src/utils/background_iterator.rs @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use futures::ready; +use futures::Stream; +use std::{ + future::Future, + panic, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::task::JoinHandle; + +/// Wrap an iterator as a stream that executes the iterator in a background +/// blocking thread. +/// +/// The size hint is preserved, but the stream is not fused. +#[pin_project::pin_project] +pub struct BackgroundIterator { + #[pin] + state: BackgroundIterState, +} + +impl BackgroundIterator { + pub fn new(iter: I) -> Self { + Self { + state: BackgroundIterState::Current { iter }, + } + } +} + +impl Stream for BackgroundIterator +where + I::Item: Send + 'static, +{ + type Item = I::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if let Some(mut iter) = this.state.as_mut().take_iter() { + this.state.set(BackgroundIterState::Running { + size_hint: iter.size_hint(), + task: tokio::task::spawn_blocking(move || { + let next = iter.next(); + next.map(|next| (iter, next)) + }), + }); + } + + let step = match this.state.as_mut().project_future() { + Some(task) => ready!(task.poll(cx)), + None => panic!( + "BackgroundIterator must not be polled after it returned `Poll::Ready(None)`" + ), + }; + + match step { + Ok(Some((iter, next))) => { + this.state.set(BackgroundIterState::Current { iter }); + Poll::Ready(Some(next)) + } + Ok(None) => { + this.state.set(BackgroundIterState::Empty); + Poll::Ready(None) + } + Err(err) => { + if err.is_panic() { + // Resume the panic on the main task + panic::resume_unwind(err.into_panic()); + } else { + panic!("Background task failed: {:?}", err); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.state { + BackgroundIterState::Current { iter } => iter.size_hint(), + BackgroundIterState::Running { size_hint, .. } => *size_hint, + BackgroundIterState::Empty => (0, Some(0)), + } + } +} + +// Inspired by Unfold implementation: https://github.com/rust-lang/futures-rs/blob/master/futures-util/src/unfold_state.rs#L22 +#[pin_project::pin_project(project = StateProj, project_replace = StateReplace)] +enum BackgroundIterState { + Current { + iter: I, + }, + Running { + size_hint: (usize, Option), + #[pin] + task: NextHandle, + }, + Empty, +} + +type NextHandle = JoinHandle>; + +impl BackgroundIterState { + fn project_future(self: Pin<&mut Self>) -> Option>> { + match self.project() { + StateProj::Running { task, .. } => Some(task), + _ => None, + } + } + + fn take_iter(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Current { .. } => match self.project_replace(Self::Empty) { + StateReplace::Current { iter } => Some(iter), + _ => None, + }, + _ => None, + } + } +} diff --git a/rust/lance-datagen/benches/array_gen.rs b/rust/lance-datagen/benches/array_gen.rs index fdc19797581..6483c0149b5 100644 --- a/rust/lance-datagen/benches/array_gen.rs +++ b/rust/lance-datagen/benches/array_gen.rs @@ -119,7 +119,7 @@ fn bench_rand_gen(c: &mut Criterion) { lance_datagen::array::rand::() }); bench_gen(&mut group, "rand_varbin", || { - lance_datagen::array::rand_varbin(ByteCount::from(12), false) + lance_datagen::array::rand_fixedbin(ByteCount::from(12), false) }); bench_gen(&mut group, "rand_utf8", || { lance_datagen::array::rand_utf8(ByteCount::from(12), false) diff --git a/rust/lance-datagen/src/generator.rs b/rust/lance-datagen/src/generator.rs index 3d8f4d8012e..db6fdf09097 100644 --- a/rust/lance-datagen/src/generator.rs +++ b/rust/lance-datagen/src/generator.rs @@ -6,13 +6,16 @@ use std::{collections::HashMap, iter, marker::PhantomData, sync::Arc}; use arrow::{ array::{ArrayData, AsArray}, buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer}, - datatypes::{ArrowPrimitiveType, Int32Type, Int64Type, IntervalDayTime, IntervalMonthDayNano}, + datatypes::{ + ArrowPrimitiveType, Int32Type, Int64Type, IntervalDayTime, IntervalMonthDayNano, UInt32Type, + }, }; use arrow_array::{ make_array, types::{ArrowDictionaryKeyType, BinaryType, ByteArrayType, Utf8Type}, - Array, FixedSizeBinaryArray, FixedSizeListArray, LargeListArray, ListArray, NullArray, - PrimitiveArray, RecordBatch, RecordBatchOptions, RecordBatchReader, StringArray, StructArray, + Array, BinaryArray, FixedSizeBinaryArray, FixedSizeListArray, LargeListArray, ListArray, + NullArray, PrimitiveArray, RecordBatch, RecordBatchOptions, RecordBatchReader, StringArray, + StructArray, }; use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef}; use futures::{stream::BoxStream, StreamExt}; @@ -54,7 +57,7 @@ impl From for Dimension { } /// A trait for anything that can generate arrays of data -pub trait ArrayGenerator: Send + Sync { +pub trait ArrayGenerator: Send + Sync + std::fmt::Debug { /// Generate an array of the given length /// /// # Arguments @@ -91,6 +94,7 @@ pub trait ArrayGenerator: Send + Sync { fn element_size_bytes(&self) -> Option; } +#[derive(Debug)] pub struct CycleNullGenerator { generator: Box, validity: Vec, @@ -138,6 +142,7 @@ impl ArrayGenerator for CycleNullGenerator { } } +#[derive(Debug)] pub struct MetadataGenerator { generator: Box, metadata: HashMap, @@ -165,6 +170,7 @@ impl ArrayGenerator for MetadataGenerator { } } +#[derive(Debug)] pub struct NullGenerator { generator: Box, null_probability: f64, @@ -201,7 +207,7 @@ impl ArrayGenerator for NullGenerator { } } else { let array_len = array.len(); - let num_validity_bytes = (array_len + 7) / 8; + let num_validity_bytes = array_len.div_ceil(8); let mut null_count = 0; // Sampling the RNG once per bit is kind of slow so we do this to sample once // per byte. We only get 8 bits of RNG resolution but that should be good enough. @@ -244,6 +250,10 @@ impl ArrayGenerator for NullGenerator { } } + fn metadata(&self) -> Option> { + self.generator.metadata() + } + fn data_type(&self) -> &DataType { self.generator.data_type() } @@ -348,6 +358,23 @@ where element_size_bytes: Option, } +impl T> std::fmt::Debug + for FnGen +where + T: Copy + Default, + ArrayType: arrow_array::Array + From>, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FnGen") + .field("data_type", &self.data_type) + .field("array_type", &self.array_type) + .field("repeat", &self.repeat) + .field("leftover_count", &self.leftover_count) + .field("element_size_bytes", &self.element_size_bytes) + .finish() + } +} + impl T> FnGen where T: Copy + Default, @@ -421,6 +448,7 @@ impl From for Seed { } } +#[derive(Debug)] pub struct CycleVectorGenerator { underlying_gen: Box, dimension: Dimension, @@ -469,7 +497,65 @@ impl ArrayGenerator for CycleVectorGenerator { } } -#[derive(Default)] +#[derive(Debug)] +pub struct CycleListGenerator { + underlying_gen: Box, + lengths_gen: Box, + data_type: DataType, +} + +impl CycleListGenerator { + pub fn new( + underlying_gen: Box, + min_list_size: Dimension, + max_list_size: Dimension, + ) -> Self { + let data_type = DataType::List(Arc::new(Field::new( + "item", + underlying_gen.data_type().clone(), + true, + ))); + let lengths_dist = Uniform::new(min_list_size.0, max_list_size.0); + let lengths_gen = rand_with_distribution::>(lengths_dist); + Self { + underlying_gen, + lengths_gen, + data_type, + } + } +} + +impl ArrayGenerator for CycleListGenerator { + fn generate( + &mut self, + length: RowCount, + rng: &mut rand_xoshiro::Xoshiro256PlusPlus, + ) -> Result, ArrowError> { + let lengths = self.lengths_gen.generate(length, rng)?; + let lengths = lengths.as_primitive::(); + let total_length = lengths.values().iter().map(|i| *i as u64).sum::(); + let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize)); + let values = self + .underlying_gen + .generate(RowCount::from(total_length), rng)?; + let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + let values = Arc::new(values); + + let array = ListArray::try_new(field, offsets, values, None)?; + + Ok(Arc::new(array)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn element_size_bytes(&self) -> Option { + None + } +} + +#[derive(Debug, Default)] pub struct PseudoUuidGenerator {} impl ArrayGenerator for PseudoUuidGenerator { @@ -496,7 +582,7 @@ impl ArrayGenerator for PseudoUuidGenerator { } } -#[derive(Default)] +#[derive(Debug, Default)] pub struct PseudoUuidHexGenerator {} impl ArrayGenerator for PseudoUuidHexGenerator { @@ -523,7 +609,7 @@ impl ArrayGenerator for PseudoUuidHexGenerator { } } -#[derive(Default)] +#[derive(Debug, Default)] pub struct RandomBooleanGenerator {} impl ArrayGenerator for RandomBooleanGenerator { @@ -532,7 +618,7 @@ impl ArrayGenerator for RandomBooleanGenerator { length: RowCount, rng: &mut rand_xoshiro::Xoshiro256PlusPlus, ) -> Result, ArrowError> { - let num_bytes = (length.0 + 7) / 8; + let num_bytes = length.0.div_ceil(8); let mut bytes = vec![0; num_bytes as usize]; rng.fill_bytes(&mut bytes); let bytes = BooleanBuffer::new(Buffer::from(bytes), 0, length.0 as usize); @@ -557,6 +643,14 @@ pub struct RandomBytesGenerator { data_type: DataType, } +impl std::fmt::Debug for RandomBytesGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RandomBytesGenerator") + .field("data_type", &self.data_type) + .finish() + } +} + impl RandomBytesGenerator { fn new(data_type: DataType) -> Self { Self { @@ -596,6 +690,7 @@ impl ArrayGenerator for RandomBytesGenerato // This is pretty much the same thing as RandomBinaryGenerator but we can't use that // because there is no ArrowPrimitiveType for FixedSizeBinary +#[derive(Debug)] pub struct RandomFixedSizeBinaryGenerator { data_type: DataType, size: i32, @@ -635,6 +730,7 @@ impl ArrayGenerator for RandomFixedSizeBinaryGenerator { } } +#[derive(Debug)] pub struct RandomIntervalGenerator { unit: IntervalUnit, data_type: DataType, @@ -687,6 +783,7 @@ impl ArrayGenerator for RandomIntervalGenerator { Some(ByteCount::from(12)) } } +#[derive(Debug)] pub struct RandomBinaryGenerator { bytes_per_element: ByteCount, scale_to_utf8: bool, @@ -725,9 +822,10 @@ impl ArrayGenerator for RandomBinaryGenerator { } let bytes = Buffer::from(bytes); if self.is_large { - let offsets = OffsetBuffer::from_lengths( - iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize), - ); + let offsets = OffsetBuffer::from_lengths(iter::repeat_n( + self.bytes_per_element.0 as usize, + length.0 as usize, + )); if self.scale_to_utf8 { // This is safe because we are only using printable characters unsafe { @@ -743,9 +841,10 @@ impl ArrayGenerator for RandomBinaryGenerator { } } } else { - let offsets = OffsetBuffer::from_lengths( - iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize), - ); + let offsets = OffsetBuffer::from_lengths(iter::repeat_n( + self.bytes_per_element.0 as usize, + length.0 as usize, + )); if self.scale_to_utf8 { // This is safe because we are only using printable characters unsafe { @@ -775,6 +874,52 @@ impl ArrayGenerator for RandomBinaryGenerator { } } +#[derive(Debug)] +pub struct VariableRandomBinaryGenerator { + lengths_gen: Box, + data_type: DataType, +} + +impl VariableRandomBinaryGenerator { + pub fn new(min_bytes_per_element: ByteCount, max_bytes_per_element: ByteCount) -> Self { + let lengths_dist = Uniform::new_inclusive( + min_bytes_per_element.0 as i32, + max_bytes_per_element.0 as i32, + ); + let lengths_gen = rand_with_distribution::>(lengths_dist); + + Self { + lengths_gen, + data_type: DataType::Binary, + } + } +} + +impl ArrayGenerator for VariableRandomBinaryGenerator { + fn generate( + &mut self, + length: RowCount, + rng: &mut rand_xoshiro::Xoshiro256PlusPlus, + ) -> Result, ArrowError> { + let lengths = self.lengths_gen.generate(length, rng)?; + let lengths = lengths.as_primitive::(); + let total_length = lengths.values().iter().map(|i| *i as usize).sum::(); + let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize)); + let mut bytes = vec![0; total_length]; + rng.fill_bytes(&mut bytes); + let bytes = Buffer::from(bytes); + Ok(Arc::new(BinaryArray::try_new(offsets, bytes, None)?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn element_size_bytes(&self) -> Option { + None + } +} + pub struct CycleBinaryGenerator { values: Vec, lengths: Vec, @@ -784,6 +929,18 @@ pub struct CycleBinaryGenerator { idx: usize, } +impl std::fmt::Debug for CycleBinaryGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CycleBinaryGenerator") + .field("values", &self.values) + .field("lengths", &self.lengths) + .field("data_type", &self.data_type) + .field("width", &self.width) + .field("idx", &self.idx) + .finish() + } +} + impl CycleBinaryGenerator { pub fn from_strings(values: &[&str]) -> Self { if values.is_empty() { @@ -859,6 +1016,15 @@ pub struct FixedBinaryGenerator { array_type: PhantomData, } +impl std::fmt::Debug for FixedBinaryGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FixedBinaryGenerator") + .field("value", &self.value) + .field("data_type", &self.data_type) + .finish() + } +} + impl FixedBinaryGenerator { pub fn new(value: Vec) -> Self { Self { @@ -883,7 +1049,7 @@ impl ArrayGenerator for FixedBinaryGenerator { .copied(), )); let offsets = - OffsetBuffer::from_lengths(iter::repeat(self.value.len()).take(length.0 as usize)); + OffsetBuffer::from_lengths(iter::repeat_n(self.value.len(), length.0 as usize)); Ok(Arc::new(arrow_array::GenericByteArray::::new( offsets, bytes, None, ))) @@ -908,6 +1074,16 @@ pub struct DictionaryGenerator { key_width: u64, } +impl std::fmt::Debug for DictionaryGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DictionaryGenerator") + .field("generator", &self.generator) + .field("data_type", &self.data_type) + .field("key_width", &self.key_width) + .finish() + } +} + impl DictionaryGenerator { fn new(generator: Box) -> Self { let key_type = Box::new(K::DATA_TYPE); @@ -947,6 +1123,7 @@ impl ArrayGenerator for DictionaryGener } } +#[derive(Debug)] struct RandomListGenerator { field: Arc, child_field: Arc, @@ -1023,6 +1200,7 @@ impl ArrayGenerator for RandomListGenerator { } } +#[derive(Debug)] struct NullArrayGenerator {} impl ArrayGenerator for NullArrayGenerator { @@ -1043,6 +1221,7 @@ impl ArrayGenerator for NullArrayGenerator { } } +#[derive(Debug)] struct RandomStructGenerator { fields: Fields, data_type: DataType, @@ -1066,6 +1245,12 @@ impl ArrayGenerator for RandomStructGenerator { length: RowCount, rng: &mut rand_xoshiro::Xoshiro256PlusPlus, ) -> Result, ArrowError> { + if self.child_gens.is_empty() { + // Have to create empty struct arrays specially to ensure they have the correct + // row count + let struct_arr = StructArray::new_empty_fields(length.0 as usize, None); + return Ok(Arc::new(struct_arr)); + } let child_arrays = self .child_gens .iter_mut() @@ -1255,12 +1440,15 @@ impl BatchGeneratorBuilder { self, batch_size: RowCount, num_batches: BatchCount, - ) -> BoxStream<'static, Result> { + ) -> ( + BoxStream<'static, Result>, + Arc, + ) { // TODO: this is pretty lazy and could be optimized - let batches = self - .into_reader_rows(batch_size, num_batches) - .collect::>(); - futures::stream::iter(batches).boxed() + let reader = self.into_reader_rows(batch_size, num_batches); + let schema = reader.schema(); + let batches = reader.collect::>(); + (futures::stream::iter(batches).boxed(), schema) } /// Create a RecordBatchReader that generates batches of the given size (in bytes) @@ -1307,6 +1495,38 @@ impl BatchGeneratorBuilder { } } +/// Factory for creating a single random array +pub struct ArrayGeneratorBuilder { + generator: Box, + seed: Option, +} + +impl ArrayGeneratorBuilder { + fn new(generator: Box) -> Self { + Self { + generator, + seed: None, + } + } + + /// Use the given seed for the generator + pub fn with_seed(mut self, seed: Seed) -> Self { + self.seed = Some(seed); + self + } + + /// Generate a single array with the given length + pub fn into_array_rows( + mut self, + length: RowCount, + ) -> Result, ArrowError> { + let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64( + self.seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0), + ); + self.generator.generate(length, &mut rng) + } +} + const MS_PER_DAY: i64 = 86400000; pub mod array { @@ -1339,6 +1559,22 @@ pub mod array { Box::new(CycleVectorGenerator::new(generator, dimension)) } + /// Create a generator of list vectors by continuously calling the given generator + /// + /// The lists will have lengths uniformly distributed between `min_list_size` (inclusive) and + /// `max_list_size` (exclusive). + pub fn cycle_vec_var( + generator: Box, + min_list_size: Dimension, + max_list_size: Dimension, + ) -> Box { + Box::new(CycleListGenerator::new( + generator, + min_list_size, + max_list_size, + )) + } + /// Create a generator from a vector of values /// /// If more rows are requested than the length of values then it will restart @@ -1395,7 +1631,7 @@ pub mod array { pub fn blob() -> Box { let mut blob_meta = HashMap::new(); blob_meta.insert("lance-encoding:blob".to_string(), "true".to_string()); - rand_varbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta) + rand_fixedbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta) } /// Create a generator that starts at a given value and increments by a given step for each element @@ -1737,8 +1973,8 @@ pub mod array { )) } - /// Create a generator of random binary values - pub fn rand_varbin(bytes_per_element: ByteCount, is_large: bool) -> Box { + /// Create a generator of random binary values where each value has a fixed number of bytes + pub fn rand_fixedbin(bytes_per_element: ByteCount, is_large: bool) -> Box { Box::new(RandomBinaryGenerator::new( bytes_per_element, false, @@ -1746,6 +1982,19 @@ pub mod array { )) } + /// Create a generator of random binary values where each value has a variable number of bytes + /// + /// The number of bytes per element will be randomly sampled from the given (inclusive) range + pub fn rand_varbin( + min_bytes_per_element: ByteCount, + max_bytes_per_element: ByteCount, + ) -> Box { + Box::new(VariableRandomBinaryGenerator::new( + min_bytes_per_element, + max_bytes_per_element, + )) + } + /// Create a generator of random strings /// /// All strings will consist entirely of printable ASCII characters @@ -1767,6 +2016,13 @@ pub mod array { Box::new(RandomListGenerator::new(child_gen, is_large)) } + pub fn rand_list_any( + item_gen: Box, + is_large: bool, + ) -> Box { + Box::new(RandomListGenerator::new(item_gen, is_large)) + } + pub fn rand_struct(fields: Fields) -> Box { let child_gens = fields .iter() @@ -1798,8 +2054,8 @@ pub mod array { DataType::Decimal256(_, _) => rand_primitive::(data_type.clone()), DataType::Utf8 => rand_utf8(ByteCount::from(12), false), DataType::LargeUtf8 => rand_utf8(ByteCount::from(12), true), - DataType::Binary => rand_varbin(ByteCount::from(12), false), - DataType::LargeBinary => rand_varbin(ByteCount::from(12), true), + DataType::Binary => rand_fixedbin(ByteCount::from(12), false), + DataType::LargeBinary => rand_fixedbin(ByteCount::from(12), true), DataType::Dictionary(key_type, value_type) => { dict_type(rand_type(value_type), key_type) } @@ -1858,11 +2114,16 @@ pub mod array { } } -/// Create a BatchGeneratorBuilder to start generating data +/// Create a BatchGeneratorBuilder to start generating batch data pub fn gen() -> BatchGeneratorBuilder { BatchGeneratorBuilder::default() } +/// Create an ArrayGeneratorBuilder to start generating array data +pub fn gen_array(gen: Box) -> ArrayGeneratorBuilder { + ArrayGeneratorBuilder::new(gen) +} + /// Create a BatchGeneratorBuilder with the given schema /// /// You can add more columns or convert this into a reader immediately @@ -1978,7 +2239,7 @@ mod tests { Int32Array::from_iter([-797553329, 1369325940, -69174021]) ); - let mut gen = array::rand_varbin(ByteCount::from(3), false); + let mut gen = array::rand_fixedbin(ByteCount::from(3), false); assert_eq!( *gen.generate(RowCount::from(3), &mut rng).unwrap(), arrow_array::BinaryArray::from_iter_values([ @@ -2009,6 +2270,16 @@ mod tests { // Sanity check to ensure we're getting at least some rng assert!(bools.false_count() > 100); assert!(bools.true_count() > 100); + + let mut gen = array::rand_varbin(ByteCount::from(2), ByteCount::from(4)); + assert_eq!( + *gen.generate(RowCount::from(3), &mut rng).unwrap(), + arrow_array::BinaryArray::from_iter_values([ + vec![56, 122, 157, 34], + vec![58, 51], + vec![41, 184, 125] + ]) + ); } #[test] diff --git a/rust/lance-encoding-datafusion/Cargo.toml b/rust/lance-encoding-datafusion/Cargo.toml index 3cccdc9ec68..ffc608ed301 100644 --- a/rust/lance-encoding-datafusion/Cargo.toml +++ b/rust/lance-encoding-datafusion/Cargo.toml @@ -42,9 +42,17 @@ lance-datagen.workspace = true [build-dependencies] prost-build.workspace = true +protobuf-src = { version = "2.1", optional = true } [target.'cfg(target_os = "linux")'.dev-dependencies] pprof = { workspace = true } +[features] +protoc = ["dep:protobuf-src"] + +[package.metadata.docs.rs] +# docs.rs uses an older version of Ubuntu that does not have the necessary protoc version +features = ["protoc"] + [lints] workspace = true diff --git a/rust/lance-encoding-datafusion/build.rs b/rust/lance-encoding-datafusion/build.rs index 8d89a39ac37..9d0206e2016 100644 --- a/rust/lance-encoding-datafusion/build.rs +++ b/rust/lance-encoding-datafusion/build.rs @@ -6,6 +6,10 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + #[cfg(feature = "protoc")] + // Use vendored protobuf compiler if requested. + std::env::set_var("PROTOC", protobuf_src::protoc()); + let mut prost_build = prost_build::Config::new(); prost_build.extern_path(".lance.encodings", "::lance_encoding::format::pb"); prost_build.protoc_arg("--experimental_allow_proto3_optional"); diff --git a/rust/lance-encoding-datafusion/src/zone.rs b/rust/lance-encoding-datafusion/src/zone.rs index 03b8e5278a1..25ea744a77d 100644 --- a/rust/lance-encoding-datafusion/src/zone.rs +++ b/rust/lance-encoding-datafusion/src/zone.rs @@ -43,7 +43,7 @@ use lance_file::{ v2::{reader::EncodedBatchReaderExt, writer::EncodedBatchWriteExt}, version::LanceFileVersion, }; -use snafu::{location, Location}; +use snafu::location; use crate::substrait::FilterExpressionExt; @@ -146,7 +146,7 @@ pub(crate) fn extract_zone_info( let mut zone_index = zone_index.clone(); let inner = zone_index.inner.take().unwrap(); let rows_per_zone = zone_index.rows_per_zone; - let zone_map_buffer = zone_index.zone_map_buffer.as_ref().unwrap().clone(); + let zone_map_buffer = *zone_index.zone_map_buffer.as_ref().unwrap(); assert_eq!( zone_map_buffer.buffer_type, i32::from(pb::buffer::BufferType::Column) @@ -611,6 +611,7 @@ impl FieldEncoder for ZoneMapsFieldEncoder { external_buffers: &mut OutOfLineBuffers, repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { // TODO: If we do the zone map calculation as part of the encoding task then we can // parallelize statistics gathering. Could be faster too since the encoding task is @@ -619,7 +620,7 @@ impl FieldEncoder for ZoneMapsFieldEncoder { // to improve write speed. self.update(&array)?; self.items_encoder - .maybe_encode(array, external_buffers, repdef, row_number) + .maybe_encode(array, external_buffers, repdef, row_number, num_rows) } fn flush( diff --git a/rust/lance-encoding/Cargo.toml b/rust/lance-encoding/Cargo.toml index e43a7c634ed..0b19be5d127 100644 --- a/rust/lance-encoding/Cargo.toml +++ b/rust/lance-encoding/Cargo.toml @@ -38,11 +38,12 @@ snafu.workspace = true tokio.workspace = true tracing.workspace = true zstd.workspace = true -bytemuck = "=1.18.0" +bytemuck = "1.14" arrayref = "0.3.7" paste = "1.0.15" seq-macro = "0.3.5" byteorder.workspace = true +lz4 = "1.28.1" [dev-dependencies] lance-testing.workspace = true @@ -56,10 +57,18 @@ rand_xoshiro = "0.6.0" [build-dependencies] prost-build.workspace = true +protobuf-src = { version = "2.1", optional = true } [target.'cfg(target_os = "linux")'.dev-dependencies] pprof = { workspace = true } +[features] +protoc = ["dep:protobuf-src"] + +[package.metadata.docs.rs] +# docs.rs uses an older version of Ubuntu that does not have the necessary protoc version +features = ["protoc"] + [[bench]] name = "decoder" harness = false diff --git a/rust/lance-encoding/benches/decoder.rs b/rust/lance-encoding/benches/decoder.rs index c6a80538a86..500274fa34d 100644 --- a/rust/lance-encoding/benches/decoder.rs +++ b/rust/lance-encoding/benches/decoder.rs @@ -299,6 +299,7 @@ fn bench_decode_packed_struct(c: &mut Criterion) { }); } +#[allow(dead_code)] fn bench_decode_str_with_fixed_size_binary_encoding(c: &mut Criterion) { let rt = tokio::runtime::Runtime::new().unwrap(); let mut group = c.benchmark_group("decode_primitive"); diff --git a/rust/lance-encoding/build.rs b/rust/lance-encoding/build.rs index 1f030d6d7fd..37efdcbc9d4 100644 --- a/rust/lance-encoding/build.rs +++ b/rust/lance-encoding/build.rs @@ -6,9 +6,14 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + #[cfg(feature = "protoc")] + // Use vendored protobuf compiler if requested. + std::env::set_var("PROTOC", protobuf_src::protoc()); + let mut prost_build = prost_build::Config::new(); prost_build.protoc_arg("--experimental_allow_proto3_optional"); prost_build.enable_type_names(); + prost_build.bytes(["."]); // Enable Bytes type for all messages to avoid Vec clones. prost_build.compile_protos(&["./protos/encodings.proto"], &["./protos"])?; Ok(()) diff --git a/rust/lance-encoding/src/buffer.rs b/rust/lance-encoding/src/buffer.rs index f62f447a6b4..7ca0cac0b2c 100644 --- a/rust/lance-encoding/src/buffer.rs +++ b/rust/lance-encoding/src/buffer.rs @@ -3,10 +3,11 @@ //! Utilities for byte arrays -use std::{ops::Deref, ptr::NonNull, sync::Arc}; +use std::{ops::Deref, panic::RefUnwindSafe, ptr::NonNull, sync::Arc}; use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, ScalarBuffer}; -use snafu::{location, Location}; +use itertools::Either; +use snafu::location; use lance_core::{utils::bit::is_pwr_two, Error, Result}; @@ -104,6 +105,18 @@ impl LanceBuffer { hex::encode_upper(self) } + /// Combine multiple buffers into a single buffer + /// + /// This does involve a data copy (and allocation of a new buffer) + pub fn concat(buffers: &[Self]) -> Self { + let total_len = buffers.iter().map(|b| b.len()).sum(); + let mut data = Vec::with_capacity(total_len); + for buffer in buffers { + data.extend_from_slice(buffer.as_ref()); + } + Self::Owned(data) + } + /// Converts the buffer into a hex string, inserting a space /// between words pub fn as_spaced_hex(&self, bytes_per_word: u32) -> String { @@ -151,6 +164,14 @@ impl LanceBuffer { } } + /// Convert a buffer into a bytes::Bytes object + pub fn into_bytes(self) -> bytes::Bytes { + match self { + Self::Owned(buf) => buf.into(), + Self::Borrowed(buf) => buf.into_vec::().unwrap().into(), + } + } + /// Convert into a borrowed buffer, this is a zero-copy operation /// /// This is often called before cloning the buffer @@ -219,6 +240,20 @@ impl LanceBuffer { Self::Borrowed(Buffer::from_vec(vec)) } + /// Reinterprets Arc<[T]> as a LanceBuffer + /// + /// This is similar to [`Self::reinterpret_vec`] but for Arc<[T]> instead of Vec + /// + /// The same alignment constraints apply + pub fn reinterpret_slice(arc: Arc<[T]>) -> Self { + let slice = arc.as_ref(); + let data = NonNull::new(slice.as_ptr() as _).unwrap_or(NonNull::dangling()); + let len = std::mem::size_of_val(slice); + // SAFETY: the ptr will be valid for len items if the Arc<[T]> is valid + let buffer = unsafe { Buffer::from_custom_allocation(data, len, Arc::new(arc)) }; + Self::Borrowed(buffer) + } + /// Reinterprets a LanceBuffer into a Vec /// /// If the underlying buffer is not properly aligned, this will involve a copy of the data @@ -227,7 +262,7 @@ impl LanceBuffer { /// of the data. Lance does not support big-endian machines so this is safe. However, if we end /// up supporting big-endian machines in the future, then any use of this method will need to be /// carefully reviewed. - pub fn borrow_to_typed_slice(&mut self) -> impl AsRef<[T]> { + pub fn borrow_to_typed_slice(&mut self) -> ScalarBuffer { let align = std::mem::align_of::(); let is_aligned = self.as_ptr().align_offset(align) == 0; if self.len() % std::mem::size_of::() != 0 { @@ -343,6 +378,42 @@ impl LanceBuffer { Self::Owned(buffer) => Self::Owned(buffer[offset..offset + length].to_vec()), } } + + // Backport of https://github.com/apache/arrow-rs/pull/6707 + fn arrow_bit_slice( + buf: &arrow_buffer::Buffer, + offset: usize, + len: usize, + ) -> arrow_buffer::Buffer { + if offset % 8 == 0 { + return buf.slice_with_length(offset / 8, len.div_ceil(8)); + } + + arrow_buffer::bitwise_unary_op_helper(buf, offset, len, |a| a) + } + + /// Returns a new [LanceBuffer] that is a slice of this buffer starting at bit `offset` + /// with `length` bits. + /// + /// Unlike `slice_with_length`, this method allows for slicing at a bit level but always + /// requires a copy of the data (unless offset is byte-aligned) + /// + /// This method also converts to a borrowed buffer for convenience, but that could be optimized + /// away in the future if needed. + /// + /// This method performs the bit slice using the Arrow convention of *bitwise* little-endian + /// + /// This means, given the bit buffer 0bABCDEFGH_HIJKLMNOP and the slice starting at bit 3 and + /// with length 8, the result will be 0bNOPABCDE + pub fn bit_slice_le_with_length(&mut self, offset: usize, length: usize) -> Self { + let Self::Borrowed(borrowed) = self.borrow_and_clone() else { + unreachable!() + }; + // Use this and remove backport once we upgrade to arrow-rs 54 + // let sliced = borrowed.bit_slice(offset, length); + let sliced = Self::arrow_bit_slice(&borrowed, offset, length); + Self::Borrowed(sliced) + } } impl AsRef<[u8]> for LanceBuffer { @@ -376,6 +447,40 @@ impl From for LanceBuffer { } } +// An iterator that keeps a clone of a borrowed LanceBuffer so we +// can have a 'static lifetime +pub struct BorrowedBufferIter { + buffer: arrow_buffer::Buffer, + index: usize, +} + +impl Iterator for BorrowedBufferIter { + type Item = u8; + + fn next(&mut self) -> Option { + if self.index >= self.buffer.len() { + None + } else { + // SAFETY: we just checked that index is in bounds + let byte = unsafe { self.buffer.get_unchecked(self.index) }; + self.index += 1; + Some(*byte) + } + } +} + +impl IntoIterator for LanceBuffer { + type Item = u8; + type IntoIter = Either, BorrowedBufferIter>; + + fn into_iter(self) -> Self::IntoIter { + match self { + Self::Borrowed(buffer) => Either::Right(BorrowedBufferIter { buffer, index: 0 }), + Self::Owned(buffer) => Either::Left(buffer.into_iter()), + } + } +} + #[cfg(test)] mod tests { use arrow_buffer::Buffer; @@ -486,4 +591,17 @@ mod tests { assert_ne!(view_ptr, view_ptr2); } + + #[test] + fn test_bit_slice_le() { + let mut buf = LanceBuffer::Owned(vec![0x0F, 0x0B]); + + // Keep in mind that validity buffers are *bitwise* little-endian + assert_eq!(buf.bit_slice_le_with_length(0, 4).as_ref(), &[0x0F]); + assert_eq!(buf.bit_slice_le_with_length(4, 4).as_ref(), &[0x00]); + assert_eq!(buf.bit_slice_le_with_length(3, 8).as_ref(), &[0x61]); + assert_eq!(buf.bit_slice_le_with_length(0, 8).as_ref(), &[0x0F]); + assert_eq!(buf.bit_slice_le_with_length(4, 8).as_ref(), &[0xB0]); + assert_eq!(buf.bit_slice_le_with_length(4, 12).as_ref(), &[0xB0, 0x00]); + } } diff --git a/rust/lance-encoding/src/compression_algo/fsst/src/fsst.rs b/rust/lance-encoding/src/compression_algo/fsst/src/fsst.rs index d59c1aff722..f57a115a1e3 100644 --- a/rust/lance-encoding/src/compression_algo/fsst/src/fsst.rs +++ b/rust/lance-encoding/src/compression_algo/fsst/src/fsst.rs @@ -38,7 +38,7 @@ const FSST_HASH_PRIME: u64 = 2971215073; const FSST_SHIFT: usize = 15; #[inline] fn fsst_hash(w: u64) -> u64 { - w.wrapping_mul(FSST_HASH_PRIME) ^ (w.wrapping_mul(FSST_HASH_PRIME)) >> FSST_SHIFT + w.wrapping_mul(FSST_HASH_PRIME) ^ ((w.wrapping_mul(FSST_HASH_PRIME)) >> FSST_SHIFT) } const MAX_SYMBOL_LENGTH: usize = 8; @@ -119,7 +119,7 @@ impl Symbol { Self { val: c as u64, // in a symbol which represents a single character, 56 bits(7 bytes) are ignored, code length is 1 - icl: (1 << CODE_LEN_SHIFT_IN_ICL) | (code as u64) << CODE_SHIFT_IN_ICL | 56, + icl: (1 << CODE_LEN_SHIFT_IN_ICL) | ((code as u64) << CODE_SHIFT_IN_ICL) | 56, } } @@ -368,7 +368,7 @@ impl SymbolTable { return self.byte_codes[input[0] as usize] & FSST_CODE_MASK; } if len == 2 { - let short_code = (input[1] as usize) << 8 | input[0] as usize; + let short_code = ((input[1] as usize) << 8) | input[0] as usize; if self.short_codes[short_code] >= FSST_CODE_BASE { return self.short_codes[short_code] & FSST_CODE_MASK; } else { @@ -1053,9 +1053,9 @@ impl FsstEncoder { let st = &self.symbol_table; let st_info: u64 = FSST_MAGIC - | (self.encoder_switch as u64) << 24 - | ((st.suffix_lim & 255) as u64) << 16 - | ((st.terminator & 255) as u64) << 8 + | ((self.encoder_switch as u64) << 24) + | (((st.suffix_lim & 255) as u64) << 16) + | (((st.terminator & 255) as u64) << 8) | ((st.n_symbols & 255) as u64); let st_info_bytes = st_info.to_ne_bytes(); diff --git a/rust/lance-encoding/src/data.rs b/rust/lance-encoding/src/data.rs index f85efccd43f..b706c6c2e45 100644 --- a/rust/lance-encoding/src/data.rs +++ b/rust/lance-encoding/src/data.rs @@ -25,7 +25,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBuffer, BooleanBufferBuilder, NullBuf use arrow_schema::DataType; use bytemuck::try_cast_slice; use lance_arrow::DataTypeExt; -use snafu::{location, Location}; +use snafu::location; use lance_core::{Error, Result}; @@ -251,6 +251,7 @@ impl FixedWidthDataBlock { } } +#[derive(Debug)] pub struct VariableWidthDataBlockBuilder { offsets: Vec, bytes: Vec, @@ -304,6 +305,42 @@ impl DataBlockBuilderImpl for VariableWidthDataBlockBuilder { } } +#[derive(Debug)] +struct BitmapDataBlockBuilder { + values: BooleanBufferBuilder, +} + +impl BitmapDataBlockBuilder { + fn new(estimated_size_bytes: u64) -> Self { + Self { + values: BooleanBufferBuilder::new(estimated_size_bytes as usize * 8), + } + } +} + +impl DataBlockBuilderImpl for BitmapDataBlockBuilder { + fn append(&mut self, data_block: &DataBlock, selection: Range) { + let bitmap_blk = data_block.as_fixed_width_ref().unwrap(); + self.values.append_packed_range( + selection.start as usize..selection.end as usize, + &bitmap_blk.data, + ); + } + + fn finish(mut self: Box) -> DataBlock { + let bool_buf = self.values.finish(); + let num_values = bool_buf.len() as u64; + let bits_buf = bool_buf.into_inner(); + DataBlock::FixedWidth(FixedWidthDataBlock { + data: LanceBuffer::from(bits_buf), + bits_per_value: 1, + num_values, + block_info: BlockInfo::new(), + }) + } +} + +#[derive(Debug)] struct FixedWidthDataBlockBuilder { bits_per_value: u64, bytes_per_value: u64, @@ -341,6 +378,53 @@ impl DataBlockBuilderImpl for FixedWidthDataBlockBuilder { } } +#[derive(Debug)] +struct StructDataBlockBuilder { + children: Vec>, +} + +impl StructDataBlockBuilder { + // Currently only Struct with fixed-width fields are supported. + // And the assumption that all fields have `bits_per_value % 8 == 0` is made here. + fn new(bits_per_values: Vec, estimated_size_bytes: u64) -> Self { + let mut children = vec![]; + + debug_assert!(bits_per_values.iter().all(|bpv| bpv % 8 == 0)); + + let bytes_per_row: u32 = bits_per_values.iter().sum::() / 8; + let bytes_per_row = bytes_per_row as u64; + + for bits_per_value in bits_per_values.iter() { + let this_estimated_size_bytes = + estimated_size_bytes / bytes_per_row * (*bits_per_value as u64) / 8; + let child = + FixedWidthDataBlockBuilder::new(*bits_per_value as u64, this_estimated_size_bytes); + children.push(Box::new(child) as Box); + } + Self { children } + } +} + +impl DataBlockBuilderImpl for StructDataBlockBuilder { + fn append(&mut self, data_block: &DataBlock, selection: Range) { + let data_block = data_block.as_struct_ref().unwrap(); + for i in 0..self.children.len() { + self.children[i].append(&data_block.children[i], selection.clone()); + } + } + + fn finish(self: Box) -> DataBlock { + let mut children_data_block = Vec::new(); + for child in self.children { + let child_data_block = child.finish(); + children_data_block.push(child_data_block); + } + DataBlock::Struct(StructDataBlock { + children: children_data_block, + block_info: BlockInfo::new(), + }) + } +} /// A data block to represent a fixed size list #[derive(Debug)] pub struct FixedSizeListBlock { @@ -365,14 +449,7 @@ impl FixedSizeListBlock { }) } - fn remove_validity(self) -> Self { - Self { - child: Box::new(self.child.remove_validity()), - dimension: self.dimension, - } - } - - fn num_values(&self) -> u64 { + pub fn num_values(&self) -> u64 { self.child.num_values() / self.dimension } @@ -401,6 +478,14 @@ impl FixedSizeListBlock { } } + pub fn flatten_as_fixed(&mut self) -> FixedWidthDataBlock { + match self.child.as_mut() { + DataBlock::FixedSizeList(fsl) => fsl.flatten_as_fixed(), + DataBlock::FixedWidth(fw) => fw.borrow_and_clone(), + _ => panic!("Expected FixedSizeList or FixedWidth data block"), + } + } + /// Convert a flattened values block into a FixedSizeListBlock pub fn from_flat(data: FixedWidthDataBlock, data_type: &DataType) -> DataBlock { match data_type { @@ -449,6 +534,7 @@ impl FixedSizeListBlock { } } +#[derive(Debug)] struct FixedSizeListBlockBuilder { inner: Box, dimension: u64, @@ -476,6 +562,43 @@ impl DataBlockBuilderImpl for FixedSizeListBlockBuilder { } } +#[derive(Debug)] +struct NullableDataBlockBuilder { + inner: Box, + validity: BooleanBufferBuilder, +} + +impl NullableDataBlockBuilder { + fn new(inner: Box, estimated_size_bytes: usize) -> Self { + Self { + inner, + validity: BooleanBufferBuilder::new(estimated_size_bytes * 8), + } + } +} + +impl DataBlockBuilderImpl for NullableDataBlockBuilder { + fn append(&mut self, data_block: &DataBlock, selection: Range) { + let nullable = data_block.as_nullable_ref().unwrap(); + let bool_buf = BooleanBuffer::new( + nullable.nulls.try_clone().unwrap().into_buffer(), + selection.start as usize, + (selection.end - selection.start) as usize, + ); + self.validity.append_buffer(&bool_buf); + self.inner.append(nullable.data.as_ref(), selection); + } + + fn finish(mut self: Box) -> DataBlock { + let inner_block = self.inner.finish(); + DataBlock::Nullable(NullableDataBlock { + data: Box::new(inner_block), + nulls: LanceBuffer::Borrowed(self.validity.finish().into_inner()), + block_info: BlockInfo::new(), + }) + } +} + /// A data block with no regular structure. There is no available spot to attach /// validity / repdef information and it cannot be converted to Arrow without being /// decoded @@ -573,6 +696,16 @@ impl VariableWidthBlock { }) } + pub fn offsets_as_block(&mut self) -> DataBlock { + let offsets = self.offsets.borrow_and_clone(); + DataBlock::FixedWidth(FixedWidthDataBlock { + data: offsets, + bits_per_value: self.bits_per_offset as u64, + num_values: self.num_values + 1, + block_info: BlockInfo::new(), + }) + } + pub fn data_size(&self) -> u64 { (self.data.len() + self.offsets.len()) as u64 } @@ -583,6 +716,7 @@ impl VariableWidthBlock { pub struct StructDataBlock { /// The child arrays pub children: Vec, + pub block_info: BlockInfo, } impl StructDataBlock { @@ -609,13 +743,14 @@ impl StructDataBlock { } } - fn remove_validity(self) -> Self { + fn remove_outer_validity(self) -> Self { Self { children: self .children .into_iter() - .map(|c| c.remove_validity()) + .map(|c| c.remove_outer_validity()) .collect(), + block_info: self.block_info, } } @@ -633,6 +768,7 @@ impl StructDataBlock { .iter_mut() .map(|c| c.borrow_and_clone()) .collect(), + block_info: self.block_info.clone(), } } @@ -643,8 +779,16 @@ impl StructDataBlock { .iter() .map(|c| c.try_clone()) .collect::>()?, + block_info: self.block_info.clone(), }) } + + pub fn data_size(&self) -> u64 { + self.children + .iter() + .map(|data_block| data_block.data_size()) + .sum() + } } /// A data block for dictionary encoded data @@ -823,6 +967,41 @@ impl DataBlock { } } + pub fn is_variable(&self) -> bool { + match self { + Self::Constant(_) => false, + Self::Empty() => false, + Self::AllNull(_) => false, + Self::Nullable(nullable) => nullable.data.is_variable(), + Self::FixedWidth(_) => false, + Self::FixedSizeList(fsl) => fsl.child.is_variable(), + Self::VariableWidth(_) => true, + Self::Struct(strct) => strct.children.iter().any(|c| c.is_variable()), + Self::Dictionary(_) => { + todo!("is_variable for DictionaryDataBlock is not implemented yet") + } + Self::Opaque(_) => panic!("Does not make sense to ask if an Opaque block is variable"), + } + } + + pub fn is_nullable(&self) -> bool { + match self { + Self::AllNull(_) => true, + Self::Nullable(_) => true, + Self::FixedSizeList(fsl) => fsl.child.is_nullable(), + Self::Struct(strct) => strct.children.iter().any(|c| c.is_nullable()), + Self::Dictionary(_) => { + todo!("is_nullable for DictionaryDataBlock is not implemented yet") + } + Self::Opaque(_) => panic!("Does not make sense to ask if an Opaque block is nullable"), + _ => false, + } + } + + /// The number of values in the block + /// + /// This function does not recurse into child blocks. If this is a FSL then it will + /// be the number of lists and not the number of items. pub fn num_values(&self) -> u64 { match self { Self::Empty() => 0, @@ -838,6 +1017,25 @@ impl DataBlock { } } + /// The number of items in a single row + /// + /// This is always 1 unless there are layers of FSL + pub fn items_per_row(&self) -> u64 { + match self { + Self::Empty() => todo!(), // Leave undefined until needed + Self::Constant(_) => todo!(), // Leave undefined until needed + Self::AllNull(_) => todo!(), // Leave undefined until needed + Self::Nullable(nullable) => nullable.data.items_per_row(), + Self::FixedWidth(_) => 1, + Self::FixedSizeList(fsl) => fsl.dimension * fsl.child.items_per_row(), + Self::VariableWidth(_) => 1, + Self::Struct(_) => todo!(), // Leave undefined until needed + Self::Dictionary(_) => 1, + Self::Opaque(_) => 1, + } + } + + /// The number of bytes in the data block (including any child blocks) pub fn data_size(&self) -> u64 { match self { Self::Empty() => 0, @@ -862,27 +1060,29 @@ impl DataBlock { /// This does not filter the block (e.g. remove rows). It only removes /// the validity bitmaps (if present). Any garbage masked by null bits /// will now appear as proper values. - pub fn remove_validity(self) -> Self { + /// + /// If `recurse` is true, then this will also remove validity from any child blocks. + pub fn remove_outer_validity(self) -> Self { match self { - Self::Empty() => Self::Empty(), - Self::Constant(inner) => Self::Constant(inner), Self::AllNull(_) => panic!("Cannot remove validity on all-null data"), Self::Nullable(inner) => *inner.data, - Self::FixedWidth(inner) => Self::FixedWidth(inner), - Self::FixedSizeList(inner) => Self::FixedSizeList(inner.remove_validity()), - Self::VariableWidth(inner) => Self::VariableWidth(inner), - Self::Struct(inner) => Self::Struct(inner.remove_validity()), - Self::Dictionary(inner) => Self::FixedWidth(inner.indices), - Self::Opaque(inner) => Self::Opaque(inner), + Self::Struct(inner) => Self::Struct(inner.remove_outer_validity()), + other => other, } } pub fn make_builder(&self, estimated_size_bytes: u64) -> Box { match self { - Self::FixedWidth(inner) => Box::new(FixedWidthDataBlockBuilder::new( - inner.bits_per_value, - estimated_size_bytes, - )), + Self::FixedWidth(inner) => { + if inner.bits_per_value == 1 { + Box::new(BitmapDataBlockBuilder::new(estimated_size_bytes)) + } else { + Box::new(FixedWidthDataBlockBuilder::new( + inner.bits_per_value, + estimated_size_bytes, + )) + } + } Self::VariableWidth(inner) => { if inner.bits_per_offset == 32 { Box::new(VariableWidthDataBlockBuilder::new(estimated_size_bytes)) @@ -897,7 +1097,31 @@ impl DataBlock { inner.dimension, )) } - _ => todo!(), + Self::Nullable(nullable) => { + // There's no easy way to know what percentage of the data is in the valiidty buffer + // but 1/16th seems like a reasonable guess. + let estimated_validity_size_bytes = estimated_size_bytes / 16; + let inner_builder = nullable + .data + .make_builder(estimated_size_bytes - estimated_validity_size_bytes); + Box::new(NullableDataBlockBuilder::new( + inner_builder, + estimated_validity_size_bytes as usize, + )) + } + Self::Struct(struct_data_block) => { + let mut bits_per_values = vec![]; + for child in struct_data_block.children.iter() { + let child = child.as_fixed_width_ref(). + expect("Currently StructDataBlockBuilder is only used in packed-struct encoding, and currently in packed-struct encoding, only fixed-width fields are supported."); + bits_per_values.push(child.bits_per_value as u32); + } + Box::new(StructDataBlockBuilder::new( + bits_per_values, + estimated_size_bytes, + )) + } + _ => todo!("make_builder for {:?}", self), } } } @@ -951,17 +1175,17 @@ impl DataBlock { as_type_ref!(as_variable_width_ref, VariableWidth, VariableWidthBlock); as_type_ref!(as_struct_ref, Struct, StructDataBlock); as_type_ref!(as_dictionary_ref, Dictionary, DictionaryDataBlock); - as_type_ref_mut!(as_all_null_mut_ref, AllNull, AllNullDataBlock); - as_type_ref_mut!(as_nullable_mut_ref, Nullable, NullableDataBlock); - as_type_ref_mut!(as_fixed_width_mut_ref, FixedWidth, FixedWidthDataBlock); + as_type_ref_mut!(as_all_null_ref_mut, AllNull, AllNullDataBlock); + as_type_ref_mut!(as_nullable_ref_mut, Nullable, NullableDataBlock); + as_type_ref_mut!(as_fixed_width_ref_mut, FixedWidth, FixedWidthDataBlock); as_type_ref_mut!( - as_fixed_size_list_mut_ref, + as_fixed_size_list_ref_mut, FixedSizeList, FixedSizeListBlock ); - as_type_ref_mut!(as_variable_width_mut_ref, VariableWidth, VariableWidthBlock); - as_type_ref_mut!(as_struct_mut_ref, Struct, StructDataBlock); - as_type_ref_mut!(as_dictionary_mut_ref, Dictionary, DictionaryDataBlock); + as_type_ref_mut!(as_variable_width_ref_mut, VariableWidth, VariableWidthBlock); + as_type_ref_mut!(as_struct_ref_mut, Struct, StructDataBlock); + as_type_ref_mut!(as_dictionary_ref_mut, Dictionary, DictionaryDataBlock); } // Methods to convert from Arrow -> DataBlock @@ -1097,7 +1321,7 @@ fn concat_dict_arrays(arrays: &[ArrayRef]) -> ArrayRef { let array_refs = arrays.iter().map(|arr| arr.as_ref()).collect::>(); match arrow_select::concat::concat(&array_refs) { Ok(array) => array, - Err(arrow_schema::ArrowError::DictionaryKeyOverflowError { .. }) => { + Err(arrow_schema::ArrowError::DictionaryKeyOverflowError) => { // Slow, but hopefully a corner case. Optimize later let upscaled = array_refs .iter() @@ -1110,7 +1334,7 @@ fn concat_dict_arrays(arrays: &[ArrayRef]) -> ArrayRef { ), ) { Ok(arr) => arr, - Err(arrow_schema::ArrowError::DictionaryKeyOverflowError { .. }) => { + Err(arrow_schema::ArrowError::DictionaryKeyOverflowError) => { // Technically I think this means the input type was u64 already unimplemented!("Dictionary arrays with more than 2^32 unique values") } @@ -1122,7 +1346,7 @@ fn concat_dict_arrays(arrays: &[ArrayRef]) -> ArrayRef { // Can still fail if concat pushes over u32 boundary match arrow_select::concat::concat(&array_refs) { Ok(array) => array, - Err(arrow_schema::ArrowError::DictionaryKeyOverflowError { .. }) => { + Err(arrow_schema::ArrowError::DictionaryKeyOverflowError) => { unimplemented!("Dictionary arrays with more than 2^32 unique values") } err => err.unwrap(), @@ -1356,7 +1580,10 @@ impl DataBlock { .collect::>(); children.push(Self::from_arrays(&child_vec, num_values)); } - Self::Struct(StructDataBlock { children }) + Self::Struct(StructDataBlock { + children, + block_info: BlockInfo::default(), + }) } DataType::FixedSizeList(_, dim) => { let children = arrays @@ -1415,11 +1642,12 @@ impl From for DataBlock { } } -pub trait DataBlockBuilderImpl { +pub trait DataBlockBuilderImpl: std::fmt::Debug { fn append(&mut self, data_block: &DataBlock, selection: Range); fn finish(self: Box) -> DataBlock; } +#[derive(Debug)] pub struct DataBlockBuilder { estimated_size_bytes: u64, builder: Option>, @@ -1457,7 +1685,7 @@ mod tests { use arrow::datatypes::{Int32Type, Int8Type}; use arrow_array::{ make_array, new_null_array, ArrayRef, DictionaryArray, Int8Array, LargeBinaryArray, - StringArray, UInt8Array, + StringArray, UInt16Array, UInt8Array, }; use arrow_buffer::{BooleanBuffer, NullBuffer}; @@ -1471,6 +1699,26 @@ mod tests { use arrow::compute::concat; use arrow_array::Array; + + #[test] + fn test_sliced_to_data_block() { + let ints = UInt16Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8]); + let ints = ints.slice(2, 4); + let data = DataBlock::from_array(ints); + + let fixed_data = data.as_fixed_width().unwrap(); + assert_eq!(fixed_data.num_values, 4); + assert_eq!(fixed_data.data.len(), 8); + + let nullable_ints = + UInt16Array::from(vec![Some(0), None, Some(2), None, Some(4), None, Some(6)]); + let nullable_ints = nullable_ints.slice(1, 3); + let data = DataBlock::from_array(nullable_ints); + + let nullable = data.as_nullable().unwrap(); + assert_eq!(nullable.nulls, LanceBuffer::Owned(vec![0b00000010])); + } + #[test] fn test_string_to_data_block() { // Converting string arrays that contain nulls to DataBlock @@ -1761,7 +2009,7 @@ mod tests { let array_data = arr.to_data(); let total_buffer_size: usize = array_data.buffers().iter().map(|buffer| buffer.len()).sum(); // the NullBuffer.len() returns the length in bits so we divide_round_up by 8 - let array_nulls_size_in_bytes = (arr.nulls().unwrap().len() + 7) / 8; + let array_nulls_size_in_bytes = arr.nulls().unwrap().len().div_ceil(8); assert!(block.data_size() == (total_buffer_size + array_nulls_size_in_bytes) as u64); let arr = gen.generate(RowCount::from(400), &mut rng).unwrap(); @@ -1769,7 +2017,7 @@ mod tests { let array_data = arr.to_data(); let total_buffer_size: usize = array_data.buffers().iter().map(|buffer| buffer.len()).sum(); - let array_nulls_size_in_bytes = (arr.nulls().unwrap().len() + 7) / 8; + let array_nulls_size_in_bytes = arr.nulls().unwrap().len().div_ceil(8); assert!(block.data_size() == (total_buffer_size + array_nulls_size_in_bytes) as u64); let mut gen = array::rand::().with_nulls(&[true, true, false]); @@ -1778,7 +2026,7 @@ mod tests { let array_data = arr.to_data(); let total_buffer_size: usize = array_data.buffers().iter().map(|buffer| buffer.len()).sum(); - let array_nulls_size_in_bytes = (arr.nulls().unwrap().len() + 7) / 8; + let array_nulls_size_in_bytes = arr.nulls().unwrap().len().div_ceil(8); assert!(block.data_size() == (total_buffer_size + array_nulls_size_in_bytes) as u64); let arr = gen.generate(RowCount::from(400), &mut rng).unwrap(); @@ -1786,7 +2034,7 @@ mod tests { let array_data = arr.to_data(); let total_buffer_size: usize = array_data.buffers().iter().map(|buffer| buffer.len()).sum(); - let array_nulls_size_in_bytes = (arr.nulls().unwrap().len() + 7) / 8; + let array_nulls_size_in_bytes = arr.nulls().unwrap().len().div_ceil(8); assert!(block.data_size() == (total_buffer_size + array_nulls_size_in_bytes) as u64); let mut gen = array::rand::().with_nulls(&[false, true, false]); @@ -1808,7 +2056,7 @@ mod tests { .map(|buffer| buffer.len()) .sum(); - let total_nulls_size_in_bytes = (concatenated_array.nulls().unwrap().len() + 7) / 8; + let total_nulls_size_in_bytes = concatenated_array.nulls().unwrap().len().div_ceil(8); assert!(block.data_size() == (total_buffer_size + total_nulls_size_in_bytes) as u64); } } diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index 552abd68017..ea85f84427a 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -217,43 +217,48 @@ use std::sync::Once; use std::{ops::Range, sync::Arc}; use arrow_array::cast::AsArray; -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_schema::{DataType, Field as ArrowField, Fields, Schema as ArrowSchema}; +use arrow_array::{ArrayRef, RecordBatch, RecordBatchIterator, RecordBatchReader}; +use arrow_schema::{ArrowError, DataType, Field as ArrowField, Fields, Schema as ArrowSchema}; use bytes::Bytes; -use futures::future::BoxFuture; +use futures::future::{maybe_done, BoxFuture, MaybeDone}; use futures::stream::{self, BoxStream}; use futures::{FutureExt, StreamExt}; use lance_arrow::DataTypeExt; use lance_core::cache::{CapacityMode, FileMetadataCache}; use lance_core::datatypes::{Field, Schema, BLOB_DESC_LANCE_FIELD}; use log::{debug, trace, warn}; -use snafu::{location, Location}; +use snafu::location; use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{self, unbounded_channel}; -use lance_core::{Error, Result}; +use lance_core::{ArrowResult, Error, Result}; use tracing::instrument; use crate::buffer::LanceBuffer; -use crate::data::DataBlock; +use crate::data::{DataBlock, FixedWidthDataBlock, VariableWidthBlock}; use crate::encoder::{values_column_encoding, EncodedBatch}; use crate::encodings::logical::binary::BinaryFieldScheduler; use crate::encodings::logical::blob::BlobFieldScheduler; -use crate::encodings::logical::list::{ListFieldScheduler, OffsetPageInfo}; +use crate::encodings::logical::list::{ + ListFieldScheduler, OffsetPageInfo, StructuralListScheduler, +}; use crate::encodings::logical::primitive::{ PrimitiveFieldScheduler, StructuralPrimitiveFieldScheduler, }; use crate::encodings::logical::r#struct::{ SimpleStructDecoder, SimpleStructScheduler, StructuralStructDecoder, StructuralStructScheduler, }; -use crate::encodings::physical::binary::{BinaryBlockDecompressor, BinaryMiniBlockDecompressor}; -use crate::encodings::physical::bitpack_fastlanes::BitpackMiniBlockDecompressor; -use crate::encodings::physical::fixed_size_list::FslPerValueDecompressor; -use crate::encodings::physical::fsst::FsstMiniBlockDecompressor; +use crate::encodings::physical::binary::{ + BinaryBlockDecompressor, BinaryMiniBlockDecompressor, VariableDecoder, +}; +use crate::encodings::physical::bitpack_fastlanes::InlineBitpacking; +use crate::encodings::physical::block_compress::CompressedBufferEncoder; +use crate::encodings::physical::fsst::{FsstMiniBlockDecompressor, FsstPerValueDecompressor}; +use crate::encodings::physical::struct_encoding::PackedStructFixedWidthMiniBlockDecompressor; use crate::encodings::physical::value::{ConstantDecompressor, ValueDecompressor}; use crate::encodings::physical::{ColumnBuffers, FileBuffers}; use crate::format::pb::{self, column_encoding}; -use crate::repdef::{LevelBuffer, RepDefUnraveler}; +use crate::repdef::{CompositeRepDefUnraveler, RepDefUnraveler}; use crate::version::LanceFileVersion; use crate::{BufferScheduler, EncodingsIo}; @@ -454,22 +459,25 @@ impl<'a> ColumnInfoIter<'a> { } pub trait MiniBlockDecompressor: std::fmt::Debug + Send + Sync { - fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result; + fn decompress(&self, data: Vec, num_values: u64) -> Result; } -pub trait PerValueDecompressor: std::fmt::Debug + Send + Sync { +pub trait FixedPerValueDecompressor: std::fmt::Debug + Send + Sync { /// Decompress one or more values - fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result; + fn decompress(&self, data: FixedWidthDataBlock, num_values: u64) -> Result; /// The number of bits in each value /// - /// Returns 0 if the data type is variable-width - /// /// Currently (and probably long term) this must be a multiple of 8 fn bits_per_value(&self) -> u64; } +pub trait VariablePerValueDecompressor: std::fmt::Debug + Send + Sync { + /// Decompress one or more values + fn decompress(&self, data: VariableWidthBlock) -> Result; +} + pub trait BlockDecompressor: std::fmt::Debug + Send + Sync { - fn decompress(&self, data: LanceBuffer) -> Result; + fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result; } pub trait DecompressorStrategy: std::fmt::Debug + Send + Sync { @@ -478,10 +486,15 @@ pub trait DecompressorStrategy: std::fmt::Debug + Send + Sync { description: &pb::ArrayEncoding, ) -> Result>; - fn create_per_value_decompressor( + fn create_fixed_per_value_decompressor( &self, description: &pb::ArrayEncoding, - ) -> Result>; + ) -> Result>; + + fn create_variable_per_value_decompressor( + &self, + description: &pb::ArrayEncoding, + ) -> Result>; fn create_block_decompressor( &self, @@ -489,7 +502,7 @@ pub trait DecompressorStrategy: std::fmt::Debug + Send + Sync { ) -> Result>; } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct CoreDecompressorStrategy {} impl DecompressorStrategy for CoreDecompressorStrategy { @@ -499,38 +512,65 @@ impl DecompressorStrategy for CoreDecompressorStrategy { ) -> Result> { match description.array_encoding.as_ref().unwrap() { pb::array_encoding::ArrayEncoding::Flat(flat) => { - Ok(Box::new(ValueDecompressor::new(flat))) + Ok(Box::new(ValueDecompressor::from_flat(flat))) } - pb::array_encoding::ArrayEncoding::Bitpack2(description) => { - Ok(Box::new(BitpackMiniBlockDecompressor::new(description))) + pb::array_encoding::ArrayEncoding::InlineBitpacking(description) => { + Ok(Box::new(InlineBitpacking::from_description(description))) } - pb::array_encoding::ArrayEncoding::BinaryMiniBlock(_) => { + pb::array_encoding::ArrayEncoding::Variable(_) => { Ok(Box::new(BinaryMiniBlockDecompressor::default())) } - pb::array_encoding::ArrayEncoding::FsstMiniBlock(description) => { + pb::array_encoding::ArrayEncoding::Fsst(description) => { Ok(Box::new(FsstMiniBlockDecompressor::new(description))) } + pb::array_encoding::ArrayEncoding::PackedStructFixedWidthMiniBlock(description) => { + Ok(Box::new(PackedStructFixedWidthMiniBlockDecompressor::new( + description, + ))) + } + pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) => { + // In the future, we might need to do something more complex here if FSL supports + // compression. + Ok(Box::new(ValueDecompressor::from_fsl(fsl))) + } _ => todo!(), } } - fn create_per_value_decompressor( + fn create_fixed_per_value_decompressor( &self, description: &pb::ArrayEncoding, - ) -> Result> { + ) -> Result> { match description.array_encoding.as_ref().unwrap() { pb::array_encoding::ArrayEncoding::Flat(flat) => { - Ok(Box::new(ValueDecompressor::new(flat))) + Ok(Box::new(ValueDecompressor::from_flat(flat))) } pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) => { - let items_decompressor = - self.create_per_value_decompressor(fsl.items.as_ref().unwrap())?; - Ok(Box::new(FslPerValueDecompressor::new( - items_decompressor, - fsl.dimension as u64, + Ok(Box::new(ValueDecompressor::from_fsl(fsl))) + } + _ => todo!("fixed-per-value decompressor for {:?}", description), + } + } + + fn create_variable_per_value_decompressor( + &self, + description: &pb::ArrayEncoding, + ) -> Result> { + match *description.array_encoding.as_ref().unwrap() { + pb::array_encoding::ArrayEncoding::Variable(variable) => { + assert!(variable.bits_per_offset < u8::MAX as u32); + Ok(Box::new(VariableDecoder::default())) + } + pb::array_encoding::ArrayEncoding::Fsst(ref fsst) => { + Ok(Box::new(FsstPerValueDecompressor::new( + LanceBuffer::from_bytes(fsst.symbol_table.clone(), 1), + Box::new(VariableDecoder::default()), ))) } - _ => todo!(), + pb::array_encoding::ArrayEncoding::Block(ref block) => Ok(Box::new( + CompressedBufferEncoder::from_scheme(&block.scheme)?, + )), + _ => todo!("variable-per-value decompressor for {:?}", description), } } @@ -540,16 +580,13 @@ impl DecompressorStrategy for CoreDecompressorStrategy { ) -> Result> { match description.array_encoding.as_ref().unwrap() { pb::array_encoding::ArrayEncoding::Flat(flat) => { - Ok(Box::new(ValueDecompressor::new(flat))) + Ok(Box::new(ValueDecompressor::from_flat(flat))) } pb::array_encoding::ArrayEncoding::Constant(constant) => { - let scalar = LanceBuffer::Owned(constant.value.clone()); - Ok(Box::new(ConstantDecompressor::new( - scalar, - constant.num_values, - ))) + let scalar = LanceBuffer::from_bytes(constant.value.clone(), 1); + Ok(Box::new(ConstantDecompressor::new(scalar))) } - pb::array_encoding::ArrayEncoding::BinaryBlock(_) => { + pb::array_encoding::ArrayEncoding::Variable(_) => { Ok(Box::new(BinaryBlockDecompressor::default())) } _ => todo!(), @@ -750,11 +787,26 @@ impl CoreFieldDecoderStrategy { column_info.as_ref(), self.decompressor_strategy.as_ref(), )?); + + // advance to the next top level column column_infos.next_top_level(); + return Ok(scheduler); } match &data_type { DataType::Struct(fields) => { + if field.is_packed_struct() { + let column_info = column_infos.expect_next()?; + let scheduler = Box::new(StructuralPrimitiveFieldScheduler::try_new( + column_info.as_ref(), + self.decompressor_strategy.as_ref(), + )?); + + // advance to the next top level column + column_infos.next_top_level(); + + return Ok(scheduler); + } let mut child_schedulers = Vec::with_capacity(field.children.len()); for field in field.children.iter() { let field_scheduler = @@ -777,6 +829,16 @@ impl CoreFieldDecoderStrategy { column_infos.next_top_level(); Ok(scheduler) } + DataType::List(_) | DataType::LargeList(_) => { + let child = field + .children + .first() + .expect("List field must have a child"); + let child_scheduler = + self.create_structural_field_scheduler(child, column_infos)?; + Ok(Box::new(StructuralListScheduler::new(child_scheduler)) + as Box) + } _ => todo!(), } } @@ -893,6 +955,11 @@ impl CoreFieldDecoderStrategy { } else { // use default struct encoding Self::check_simple_struct(column_info, &field.name).unwrap(); + let num_rows = column_info + .page_infos + .iter() + .map(|page| page.num_rows) + .sum(); let mut child_schedulers = Vec::with_capacity(field.children.len()); for field in &field.children { column_infos.next_top_level(); @@ -905,6 +972,7 @@ impl CoreFieldDecoderStrategy { Ok(Box::new(SimpleStructScheduler::new( child_schedulers, fields, + num_rows, ))) } } @@ -1272,12 +1340,25 @@ impl DecodeBatchScheduler { return; } trace!("Scheduling take of {} rows", indices.len()); - let ranges = indices - .iter() - .map(|&idx| idx..(idx + 1)) - .collect::>(); + let ranges = Self::indices_to_ranges(indices); self.schedule_ranges(&ranges, filter, sink, scheduler) } + + // coalesce continuous indices if possible (the input indices must be sorted and non-empty) + fn indices_to_ranges(indices: &[u64]) -> Vec> { + let mut ranges = Vec::new(); + let mut start = indices[0]; + + for window in indices.windows(2) { + if window[1] != window[0] + 1 { + ranges.push(start..window[0] + 1); + start = window[1]; + } + } + + ranges.push(start..*indices.last().unwrap() + 1); + ranges + } } pub struct ReadBatchTask { @@ -1408,34 +1489,6 @@ impl BatchDecodeStream { Ok(Some(next_task)) } - #[instrument(level = "debug", skip_all)] - fn task_to_batch( - task: NextDecodeTask, - emitted_batch_size_warning: Arc, - ) -> Result { - let struct_arr = task.task.decode(); - match struct_arr { - Ok(struct_arr) => { - let batch = RecordBatch::from(struct_arr.as_struct()); - let size_bytes = batch.get_array_memory_size() as u64; - if size_bytes > BATCH_SIZE_BYTES_WARNING { - emitted_batch_size_warning.call_once(|| { - let size_mb = size_bytes / 1024 / 1024; - debug!("Lance read in a single batch that contained more than {}MiB of data. You may want to consider reducing the batch size.", size_mb); - }); - } - Ok(batch) - } - Err(e) => { - let e = Error::Internal { - message: format!("Error decoding batch: {}", e), - location: location!(), - }; - Err(e) - } - } - } - pub fn into_stream(self) -> BoxStream<'static, ReadBatchTask> { let stream = futures::stream::unfold(self, |mut slf| async move { let next_task = slf.next_batch_task().await; @@ -1444,7 +1497,7 @@ impl BatchDecodeStream { let emitted_batch_size_warning = slf.emitted_batch_size_warning.clone(); let task = tokio::spawn(async move { let next_task = next_task?; - Self::task_to_batch(next_task, emitted_batch_size_warning) + next_task.into_batch(emitted_batch_size_warning) }); (task, num_rows) }); @@ -1463,6 +1516,195 @@ impl BatchDecodeStream { } } +// Utility types to smooth out the differences between the 2.0 and 2.1 decoders so that +// we can have a single implementation of the batch decode iterator +enum RootDecoderMessage { + LoadedPage(LoadedPage), + LegacyPage(DecoderReady), +} +trait RootDecoderType { + fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()>; + fn drain_batch(&mut self, num_rows: u64) -> Result; + fn wait(&mut self, loaded_need: u64, runtime: &tokio::runtime::Runtime) -> Result<()>; +} +impl RootDecoderType for StructuralStructDecoder { + fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()> { + let RootDecoderMessage::LoadedPage(loaded_page) = message else { + unreachable!() + }; + self.accept_page(loaded_page) + } + fn drain_batch(&mut self, num_rows: u64) -> Result { + self.drain_batch_task(num_rows) + } + fn wait(&mut self, _: u64, _: &tokio::runtime::Runtime) -> Result<()> { + // Waiting happens elsewhere (not as part of the decoder) + Ok(()) + } +} +impl RootDecoderType for SimpleStructDecoder { + fn accept_message(&mut self, message: RootDecoderMessage) -> Result<()> { + let RootDecoderMessage::LegacyPage(legacy_page) = message else { + unreachable!() + }; + self.accept_child(legacy_page) + } + fn drain_batch(&mut self, num_rows: u64) -> Result { + self.drain(num_rows) + } + fn wait(&mut self, loaded_need: u64, runtime: &tokio::runtime::Runtime) -> Result<()> { + runtime.block_on(self.wait_for_loaded(loaded_need)) + } +} + +/// A blocking batch decoder that performs synchronous decoding +struct BatchDecodeIterator { + messages: VecDeque>, + root_decoder: T, + rows_remaining: u64, + rows_per_batch: u32, + rows_scheduled: u64, + rows_drained: u64, + emitted_batch_size_warning: Arc, + // Note: this is not the runtime on which I/O happens. + // That's always in the scheduler. This is just a runtime we use to + // sleep the current thread if I/O is unready + wait_for_io_runtime: tokio::runtime::Runtime, + schema: Arc, +} + +impl BatchDecodeIterator { + /// Create a new instance of a batch decode iterator + pub fn new( + messages: VecDeque>, + rows_per_batch: u32, + num_rows: u64, + root_decoder: T, + schema: Arc, + ) -> Self { + Self { + messages, + root_decoder, + rows_remaining: num_rows, + rows_per_batch, + rows_scheduled: 0, + rows_drained: 0, + wait_for_io_runtime: tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(), + emitted_batch_size_warning: Arc::new(Once::new()), + schema, + } + } + + /// Wait for a single page of data to finish loading + /// + /// If the data is not available this will perform a *blocking* wait (put + /// the current thread to sleep) + fn wait_for_page(&self, unloaded_page: UnloadedPage) -> Result { + match maybe_done(unloaded_page.0) { + // Fast path, avoid all runtime shenanigans if the data is ready + MaybeDone::Done(loaded_page) => loaded_page, + // Slow path, we need to wait on I/O, enter the runtime + MaybeDone::Future(fut) => self.wait_for_io_runtime.block_on(fut), + MaybeDone::Gone => unreachable!(), + } + } + + /// Waits for I/O until `scheduled_need` rows have been loaded + /// + /// Note that `scheduled_need` is cumulative. E.g. this method + /// should be called with 5, 10, 15 and not 5, 5, 5 + #[instrument(skip_all)] + fn wait_for_io(&mut self, scheduled_need: u64) -> Result { + while self.rows_scheduled < scheduled_need && !self.messages.is_empty() { + let message = self.messages.pop_front().unwrap()?; + self.rows_scheduled = message.scheduled_so_far; + for decoder_message in message.decoders { + match decoder_message { + MessageType::UnloadedPage(unloaded_page) => { + let loaded_page = self.wait_for_page(unloaded_page)?; + self.root_decoder + .accept_message(RootDecoderMessage::LoadedPage(loaded_page))?; + } + MessageType::DecoderReady(decoder_ready) => { + // The root decoder we can ignore + if !decoder_ready.path.is_empty() { + self.root_decoder + .accept_message(RootDecoderMessage::LegacyPage(decoder_ready))?; + } + } + } + } + } + + let loaded_need = self.rows_drained + self.rows_per_batch as u64 - 1; + + self.root_decoder + .wait(loaded_need, &self.wait_for_io_runtime)?; + Ok(self.rows_scheduled) + } + + #[instrument(level = "debug", skip_all)] + fn next_batch_task(&mut self) -> Result> { + trace!( + "Draining batch task (rows_remaining={} rows_drained={} rows_scheduled={})", + self.rows_remaining, + self.rows_drained, + self.rows_scheduled, + ); + if self.rows_remaining == 0 { + return Ok(None); + } + + let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64); + self.rows_remaining -= to_take; + + let scheduled_need = (self.rows_drained + to_take).saturating_sub(self.rows_scheduled); + trace!("scheduled_need = {} because rows_drained = {} and to_take = {} and rows_scheduled = {}", scheduled_need, self.rows_drained, to_take, self.rows_scheduled); + if scheduled_need > 0 { + let desired_scheduled = scheduled_need + self.rows_scheduled; + trace!( + "Draining from scheduler (desire at least {} scheduled rows)", + desired_scheduled + ); + let actually_scheduled = self.wait_for_io(desired_scheduled)?; + if actually_scheduled < desired_scheduled { + let under_scheduled = desired_scheduled - actually_scheduled; + to_take -= under_scheduled; + } + } + + if to_take == 0 { + return Ok(None); + } + + let next_task = self.root_decoder.drain_batch(to_take)?; + + self.rows_drained += to_take; + + let batch = next_task.into_batch(self.emitted_batch_size_warning.clone())?; + + Ok(Some(batch)) + } +} + +impl Iterator for BatchDecodeIterator { + type Item = ArrowResult; + + fn next(&mut self) -> Option { + self.next_batch_task() + .transpose() + .map(|r| r.map_err(ArrowError::from)) + } +} + +impl RecordBatchReader for BatchDecodeIterator { + fn schema(&self) -> Arc { + self.schema.clone() + } +} + /// A stream that takes scheduled jobs and generates decode tasks from them. pub struct StructuralBatchDecodeStream { context: DecoderContext, @@ -1566,44 +1808,11 @@ impl StructuralBatchDecodeStream { return Ok(None); } - let next_task = self.root_decoder.drain(to_take)?; - let next_task = NextDecodeTask { - has_more: self.rows_remaining > 0, - num_rows: to_take, - task: Box::new(next_task), - }; + let next_task = self.root_decoder.drain_batch_task(to_take)?; self.rows_drained += to_take; Ok(Some(next_task)) } - #[instrument(level = "debug", skip_all)] - fn task_to_batch( - task: NextDecodeTask, - emitted_batch_size_warning: Arc, - ) -> Result { - let struct_arr = task.task.decode(); - match struct_arr { - Ok(struct_arr) => { - let batch = RecordBatch::from(struct_arr.as_struct()); - let size_bytes = batch.get_array_memory_size() as u64; - if size_bytes > BATCH_SIZE_BYTES_WARNING { - emitted_batch_size_warning.call_once(|| { - let size_mb = size_bytes / 1024 / 1024; - debug!("Lance read in a single batch that contained more than {}MiB of data. You may want to consider reducing the batch size.", size_mb); - }); - } - Ok(batch) - } - Err(e) => { - let e = Error::Internal { - message: format!("Error decoding batch: {}", e), - location: location!(), - }; - Err(e) - } - } - } - pub fn into_stream(self) -> BoxStream<'static, ReadBatchTask> { let stream = futures::stream::unfold(self, |mut slf| async move { let next_task = slf.next_batch_task().await; @@ -1612,7 +1821,7 @@ impl StructuralBatchDecodeStream { let emitted_batch_size_warning = slf.emitted_batch_size_warning.clone(); let task = tokio::spawn(async move { let next_task = next_task?; - Self::task_to_batch(next_task, emitted_batch_size_warning) + next_task.into_batch(emitted_batch_size_warning) }); (task, num_rows) }); @@ -1685,7 +1894,11 @@ pub fn create_decode_stream( ) -> BoxStream<'static, ReadBatchTask> { if is_structural { let arrow_schema = ArrowSchema::from(schema); - let structural_decoder = StructuralStructDecoder::new(arrow_schema.fields, should_validate); + let structural_decoder = StructuralStructDecoder::new( + arrow_schema.fields, + should_validate, + /*is_root=*/ true, + ); StructuralBatchDecodeStream::new(rx, batch_size, num_rows, structural_decoder).into_stream() } else { let arrow_schema = ArrowSchema::from(schema); @@ -1696,6 +1909,41 @@ pub fn create_decode_stream( } } +/// Creates a iterator that decodes a set of messages in a blocking fashion +/// +/// See [`schedule_and_decode_blocking`] for more information. +pub fn create_decode_iterator( + schema: &Schema, + num_rows: u64, + batch_size: u32, + should_validate: bool, + is_structural: bool, + messages: VecDeque>, +) -> Box { + let arrow_schema = Arc::new(ArrowSchema::from(schema)); + let root_fields = arrow_schema.fields.clone(); + if is_structural { + let simple_struct_decoder = + StructuralStructDecoder::new(root_fields, should_validate, /*is_root=*/ true); + Box::new(BatchDecodeIterator::new( + messages, + batch_size, + num_rows, + simple_struct_decoder, + arrow_schema, + )) + } else { + let root_decoder = SimpleStructDecoder::new(root_fields, num_rows); + Box::new(BatchDecodeIterator::new( + messages, + batch_size, + num_rows, + root_decoder, + arrow_schema, + )) + } +} + fn create_scheduler_decoder( column_infos: Vec>, requested_rows: RequestedRows, @@ -1790,6 +2038,90 @@ pub fn schedule_and_decode( } } +lazy_static::lazy_static! { + pub static ref WAITER_RT: tokio::runtime::Runtime = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); +} + +/// Schedules and decodes the requested data in a blocking fashion +/// +/// This function is a blocking version of [`schedule_and_decode`]. It schedules the requested data +/// and decodes it in the current thread. +/// +/// This can be useful when the disk is fast (or the data is in memory) and the amount +/// of data is relatively small. For example, when doing a take against NVMe or in-memory data. +/// +/// This should NOT be used for full scans. Even if the data is in memory this function will +/// not parallelize the decode and will be slower than the async version. Full scans typically +/// make relatively few IOPs and so the asynchronous overhead is much smaller. +/// +/// This method will first completely run the scheduling process. Then it will run the +/// decode process. +pub fn schedule_and_decode_blocking( + column_infos: Vec>, + requested_rows: RequestedRows, + filter: FilterExpression, + column_indices: Vec, + target_schema: Arc, + config: SchedulerDecoderConfig, +) -> Result> { + if requested_rows.num_rows() == 0 { + let arrow_schema = Arc::new(ArrowSchema::from(target_schema.as_ref())); + return Ok(Box::new(RecordBatchIterator::new(vec![], arrow_schema))); + } + + let num_rows = requested_rows.num_rows(); + let is_structural = column_infos[0].is_structural(); + + let (tx, mut rx) = mpsc::unbounded_channel(); + + // Initialize the scheduler. This is still "asynchronous" but we run it with a current-thread + // runtime. + let mut decode_scheduler = WAITER_RT.block_on(DecodeBatchScheduler::try_new( + target_schema.as_ref(), + &column_indices, + &column_infos, + &vec![], + num_rows, + config.decoder_plugins, + config.io.clone(), + config.cache, + &filter, + ))?; + + // Schedule the requested rows + match requested_rows { + RequestedRows::Ranges(ranges) => { + decode_scheduler.schedule_ranges(&ranges, &filter, tx, config.io) + } + RequestedRows::Indices(indices) => { + decode_scheduler.schedule_take(&indices, &filter, tx, config.io) + } + } + + // Drain the scheduler queue into a vec of decode messages + let mut messages = Vec::new(); + while rx + .recv_many(&mut messages, usize::MAX) + .now_or_never() + .unwrap() + != 0 + {} + + // Create a decoder to decode the messages + let decode_iterator = create_decode_iterator( + &target_schema, + num_rows, + config.batch_size, + config.should_validate, + is_structural, + messages.into(), + ); + + Ok(decode_iterator) +} + /// A decoder for single-column encodings of primitive data (this includes fixed size /// lists of primitive data) /// @@ -1852,7 +2184,7 @@ pub trait PageScheduler: Send + Sync + std::fmt::Debug { /// # Arguments /// /// * `range` - the range of row offsets (relative to start of page) requested - /// these must be ordered and must not overlap + /// these must be ordered and must not overlap /// * `scheduler` - a scheduler to submit the I/O request to /// * `top_level_row` - the row offset of the top level field currently being /// scheduled. This can be used to assign priority to I/O requests @@ -2173,14 +2505,46 @@ impl DecodeArrayTask for Box { } } -/// A task to decode data into an Arrow array +/// A task to decode data into an Arrow record batch +/// +/// It has a child `task` which decodes a struct array with no nulls. +/// This is then converted into a record batch. pub struct NextDecodeTask { /// The decode task itself pub task: Box, /// The number of rows that will be created pub num_rows: u64, - /// Whether or not the decoder that created this still has more rows to decode - pub has_more: bool, +} + +impl NextDecodeTask { + // Run the task and produce a record batch + // + // If the batch is very large this function will log a warning message + // suggesting the user try a smaller batch size. + #[instrument(name = "task_to_batch", level = "debug", skip_all)] + fn into_batch(self, emitted_batch_size_warning: Arc) -> Result { + let struct_arr = self.task.decode(); + match struct_arr { + Ok(struct_arr) => { + let batch = RecordBatch::from(struct_arr.as_struct()); + let size_bytes = batch.get_array_memory_size() as u64; + if size_bytes > BATCH_SIZE_BYTES_WARNING { + emitted_batch_size_warning.call_once(|| { + let size_mb = size_bytes / 1024 / 1024; + debug!("Lance read in a single batch that contained more than {}MiB of data. You may want to consider reducing the batch size.", size_mb); + }); + } + Ok(batch) + } + Err(e) => { + let e = Error::Internal { + message: format!("Error decoding batch: {}", e), + location: location!(), + }; + Err(e) + } + } + } } #[derive(Debug)] @@ -2305,8 +2669,7 @@ pub trait LogicalPageDecoder: std::fmt::Debug + Send { pub struct DecodedPage { pub data: DataBlock, - pub repetition: Option, - pub definition: Option, + pub repdef: RepDefUnraveler, } pub trait DecodePageTask: Send + std::fmt::Debug { @@ -2347,7 +2710,7 @@ pub struct LoadedPage { pub struct DecodedArray { pub array: ArrayRef, - pub repdef: RepDefUnraveler, + pub repdef: CompositeRepDefUnraveler, } pub trait StructuralDecodeArrayTask: std::fmt::Debug + Send { @@ -2414,3 +2777,30 @@ pub async fn decode_batch( ); decode_stream.next().await.unwrap().task.await } + +#[cfg(test)] +// test coalesce indices to ranges +mod tests { + use super::*; + + #[test] + fn test_coalesce_indices_to_ranges_with_single_index() { + let indices = vec![1]; + let ranges = DecodeBatchScheduler::indices_to_ranges(&indices); + assert_eq!(ranges, vec![1..2]); + } + + #[test] + fn test_coalesce_indices_to_ranges() { + let indices = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + let ranges = DecodeBatchScheduler::indices_to_ranges(&indices); + assert_eq!(ranges, vec![1..10]); + } + + #[test] + fn test_coalesce_indices_to_ranges_with_gaps() { + let indices = vec![1, 2, 3, 5, 6, 7, 9]; + let ranges = DecodeBatchScheduler::indices_to_ranges(&indices); + assert_eq!(ranges, vec![1..4, 5..8, 9..10]); + } +} diff --git a/rust/lance-encoding/src/encoder.rs b/rust/lance-encoding/src/encoder.rs index 53ff59aed8c..4a31862c805 100644 --- a/rust/lance-encoding/src/encoder.rs +++ b/rust/lance-encoding/src/encoder.rs @@ -14,25 +14,30 @@ use lance_core::datatypes::{ }; use lance_core::utils::bit::{is_pwr_two, pad_bytes_to}; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use crate::buffer::LanceBuffer; use crate::data::{DataBlock, FixedWidthDataBlock, VariableWidthBlock}; use crate::decoder::PageEncoding; use crate::encodings::logical::blob::BlobFieldEncoder; +use crate::encodings::logical::list::ListStructuralEncoder; use crate::encodings::logical::primitive::PrimitiveStructuralEncoder; use crate::encodings::logical::r#struct::StructFieldEncoder; use crate::encodings::logical::r#struct::StructStructuralEncoder; -use crate::encodings::physical::binary::{BinaryBlockEncoder, BinaryMiniBlockEncoder}; +use crate::encodings::physical::binary::{BinaryMiniBlockEncoder, VariableEncoder}; use crate::encodings::physical::bitpack_fastlanes::BitpackedForNonNegArrayEncoder; use crate::encodings::physical::bitpack_fastlanes::{ - compute_compressed_bit_width_for_non_neg, BitpackMiniBlockEncoder, + compute_compressed_bit_width_for_non_neg, InlineBitpacking, +}; +use crate::encodings::physical::block_compress::{ + CompressedBufferEncoder, CompressionConfig, CompressionScheme, }; -use crate::encodings::physical::block_compress::{CompressionConfig, CompressionScheme}; use crate::encodings::physical::dictionary::AlreadyDictionaryEncoder; -use crate::encodings::physical::fixed_size_list::FslPerValueCompressor; -use crate::encodings::physical::fsst::{FsstArrayEncoder, FsstMiniBlockEncoder}; +use crate::encodings::physical::fsst::{ + FsstArrayEncoder, FsstMiniBlockEncoder, FsstPerValueEncoder, +}; use crate::encodings::physical::packed_struct::PackedStructEncoder; +use crate::encodings::physical::struct_encoding::PackedStructFixedWidthMiniBlockEncoder; use crate::format::ProtobufUtils; use crate::repdef::RepDefBuilder; use crate::statistics::{GetStat, Stat}; @@ -145,9 +150,10 @@ pub const MAX_MINIBLOCK_VALUES: u64 = 4096; /// Page data that has been compressed into a series of chunks put into /// a single buffer. +#[derive(Debug)] pub struct MiniBlockCompressed { - /// The buffer of compressed data - pub data: LanceBuffer, + /// The buffers of compressed data + pub data: Vec, /// Describes the size of each chunk pub chunks: Vec, /// The number of values in the entire page @@ -165,10 +171,10 @@ pub struct MiniBlockCompressed { /// data (values, repetition, and definition) per mini-block. #[derive(Debug)] pub struct MiniBlockChunk { - // The number of bytes that make up the chunk + // The size in bytes of each buffer in the chunk. // - // This value must be less than or equal to 8Ki - 6 (8188) - pub num_bytes: u16, + // The total size must be less than or equal to 8Ki - 6 (8188) + pub buffer_sizes: Vec, // The log (base 2) of the number of values in the chunk. If this is the final chunk // then this should be 0 (the number of values will be calculated by subtracting the // size of all other chunks from the total size of the page) @@ -216,11 +222,21 @@ pub trait MiniBlockCompressor: std::fmt::Debug + Send + Sync { /// A single buffer of value data and a buffer of offsets /// /// TODO: In the future we may allow metadata buffers +#[derive(Debug)] pub enum PerValueDataBlock { Fixed(FixedWidthDataBlock), Variable(VariableWidthBlock), } +impl PerValueDataBlock { + pub fn data_size(&self) -> u64 { + match self { + Self::Fixed(fixed) => fixed.data_size(), + Self::Variable(variable) => variable.data_size(), + } + } +} + /// Trait for compression algorithms that are suitable for use in the zipped structural encoding /// /// This compression must return either a FixedWidthDataBlock or a VariableWidthBlock. This is because @@ -236,28 +252,6 @@ pub trait PerValueCompressor: std::fmt::Debug + Send + Sync { fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)>; } -/// Trait for compression algorithms that are suitable for use in the zipped structural encoding -/// -/// This encoding is useful for non-short strings, binary, and variable length lists -/// (i.e. when the average value is >= 128 bytes) -/// -/// These compressors can be extremely generic. They only need to produce one buffer of bytes -/// and another buffer of offsets into the bytes, one offset for each value. Both of these buffers -/// will be stored. -/// -/// Note: It is perfectly legal for a value to have 0 bytes. However, we still need to store the -/// offset itself. This means that this compressor, when implemented by something like RLE will not -/// be as efficient (space-wise) as a block version (which could skip the offsets for runs). -/// -/// Accessing this data will require 2 IOPS and accessing in a random-access fashion will require -/// a repetition index. -pub trait VariablePerValueCompressor: std::fmt::Debug + Send + Sync { - /// Compress the data into a single buffer where each value is encoded with a different size - /// - /// Also returns a description of the compression that can be used to decompress when reading the data back - fn compress(&self, data: DataBlock) -> Result<(VariableWidthBlock, pb::ArrayEncoding)>; -} - /// Trait for compression algorithms that compress an entire block of data into one opaque /// and self-described chunk. /// @@ -370,12 +364,21 @@ pub trait FieldEncoder: Send { /// than a single disk page. /// /// It could also return an empty Vec if there is not enough data yet to encode any pages. + /// + /// The `row_number` must be passed which is the top-level row number currently being encoded + /// This is stored in any pages produced by this call so that we can know the priority of the + /// page. + /// + /// The `num_rows` is the number of top level rows. It is initially the same as `array.len()` + /// however it is passed seprately because array will become flattened over time (if there is + /// repetition) and we need to know the original number of rows for various purposes. fn maybe_encode( &mut self, array: ArrayRef, external_buffers: &mut OutOfLineBuffers, repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result>; /// Flush any remaining data from the buffers into encoding tasks /// @@ -422,9 +425,9 @@ pub trait ArrayEncodingStrategy: Send + Sync + std::fmt::Debug { /// width data block. In other words, there is some number of bits per value. /// In addition, each value should be independently decompressible. /// - Mini-block compression results in a small block of opaque data for chunks -/// of rows. Each block is somewhere between 0 and 16KiB in size. This is -/// used for narrow data types (both fixed and variable length) where we can -/// fit many values into an 16KiB block. +/// of rows. Each block is somewhere between 0 and 16KiB in size. This is +/// used for narrow data types (both fixed and variable length) where we can +/// fit many values into an 16KiB block. pub trait CompressionStrategy: Send + Sync + std::fmt::Debug { /// Create a block compressor for the given data fn create_block_compressor( @@ -493,13 +496,26 @@ impl CoreArrayEncodingStrategy { let bin_indices_encoder = Self::choose_array_encoder(arrays, &DataType::UInt64, data_size, false, version, None)?; - let compression = field_meta.and_then(Self::get_field_compression); - - let bin_encoder = Box::new(BinaryEncoder::new(bin_indices_encoder, compression)); - if compression.is_none() && Self::can_use_fsst(data_type, data_size, version) { - Ok(Box::new(FsstArrayEncoder::new(bin_encoder))) + if let Some(compression) = field_meta.and_then(Self::get_field_compression) { + if compression.scheme == CompressionScheme::Fsst { + // User requested FSST + let raw_encoder = Box::new(BinaryEncoder::new(bin_indices_encoder, None)); + Ok(Box::new(FsstArrayEncoder::new(raw_encoder))) + } else { + // Generic compression + Ok(Box::new(BinaryEncoder::new( + bin_indices_encoder, + Some(compression), + ))) + } } else { - Ok(bin_encoder) + // No user-specified compression, use FSST if we can + let bin_encoder = Box::new(BinaryEncoder::new(bin_indices_encoder, None)); + if Self::can_use_fsst(data_type, data_size, version) { + Ok(Box::new(FsstArrayEncoder::new(bin_encoder))) + } else { + Ok(bin_encoder) + } } } @@ -787,78 +803,131 @@ impl ArrayEncodingStrategy for CoreArrayEncodingStrategy { impl CompressionStrategy for CoreArrayEncodingStrategy { fn create_miniblock_compressor( &self, - _field: &Field, + field: &Field, data: &DataBlock, ) -> Result> { - if let DataBlock::FixedWidth(ref fixed_width_data) = data { - let bit_widths = data - .get_stat(Stat::BitWidth) - .expect("FixedWidthDataBlock should have valid `Stat::BitWidth` statistics"); - // Temporary hack to work around https://github.com/lancedb/lance/issues/3102 - // Ideally we should still be able to bit-pack here (either to 0 or 1 bit per value) - let has_all_zeros = bit_widths - .as_primitive::() - .values() - .iter() - .any(|v| *v == 0); - if !has_all_zeros - && (fixed_width_data.bits_per_value == 8 - || fixed_width_data.bits_per_value == 16 - || fixed_width_data.bits_per_value == 32 - || fixed_width_data.bits_per_value == 64) - { - return Ok(Box::new(BitpackMiniBlockEncoder::default())); + match data { + DataBlock::FixedWidth(fixed_width_data) => { + if let Some(compression) = field.metadata.get(COMPRESSION_META_KEY) { + if compression == "none" { + return Ok(Box::new(ValueEncoder::default())); + } + } + + let bit_widths = data.expect_stat(Stat::BitWidth); + let bit_widths = bit_widths.as_primitive::(); + // Temporary hack to work around https://github.com/lancedb/lance/issues/3102 + // Ideally we should still be able to bit-pack here (either to 0 or 1 bit per value) + let has_all_zeros = bit_widths.values().iter().any(|v| *v == 0); + // The minimum bit packing size is a block of 1024 values. For very small pages the uncompressed + // size might be smaller than the compressed size. + let too_small = bit_widths.len() == 1 + && InlineBitpacking::min_size_bytes(bit_widths.value(0)) >= data.data_size(); + if !has_all_zeros + && !too_small + && (fixed_width_data.bits_per_value == 8 + || fixed_width_data.bits_per_value == 16 + || fixed_width_data.bits_per_value == 32 + || fixed_width_data.bits_per_value == 64) + { + Ok(Box::new(InlineBitpacking::new( + fixed_width_data.bits_per_value, + ))) + } else { + Ok(Box::new(ValueEncoder::default())) + } } - } - if let DataBlock::VariableWidth(ref variable_width_data) = data { - if variable_width_data.bits_per_offset == 32 { - let data_size = variable_width_data.get_stat(Stat::DataSize).expect( - "VariableWidth DataBlock should have valid `Stat::DataSize` statistics", - ); - let data_size = data_size.as_primitive::().value(0); - - let max_len = variable_width_data.get_stat(Stat::MaxLength).expect( - "VariableWidth DataBlock should have valid `Stat::DataSize` statistics", - ); - let max_len = max_len.as_primitive::().value(0); - - if max_len >= FSST_LEAST_INPUT_MAX_LENGTH - && data_size >= FSST_LEAST_INPUT_SIZE as u64 + DataBlock::VariableWidth(variable_width_data) => { + if variable_width_data.bits_per_offset == 32 { + let data_size = + variable_width_data.expect_single_stat::(Stat::DataSize); + let max_len = + variable_width_data.expect_single_stat::(Stat::MaxLength); + + if max_len >= FSST_LEAST_INPUT_MAX_LENGTH + && data_size >= FSST_LEAST_INPUT_SIZE as u64 + { + Ok(Box::new(FsstMiniBlockEncoder::default())) + } else { + Ok(Box::new(BinaryMiniBlockEncoder::default())) + } + } else { + todo!("Implement MiniBlockCompression for VariableWidth DataBlock with 64 bits offsets.") + } + } + DataBlock::Struct(struct_data_block) => { + // this condition is actually checked at `PrimitiveStructuralEncoder::do_flush`, + // just being cautious here. + if struct_data_block + .children + .iter() + .any(|child| !matches!(child, DataBlock::FixedWidth(_))) { - return Ok(Box::new(FsstMiniBlockEncoder::default())); + panic!("packed struct encoding currently only supports fixed-width fields.") } - return Ok(Box::new(BinaryMiniBlockEncoder::default())); + Ok(Box::new(PackedStructFixedWidthMiniBlockEncoder::default())) + } + DataBlock::FixedSizeList(_) => { + // Ideally we would compress the list items but this creates something of a challenge. + // We don't want to break lists across chunks and we need to worry about inner validity + // layers. If we try and use a compression scheme then it is unlikely to respect these + // constraints. + // + // For now, we just don't compress. In the future, we might want to consider a more + // sophisticated approach. + Ok(Box::new(ValueEncoder::default())) } + _ => Err(Error::NotSupported { + source: format!( + "Mini-block compression not yet supported for block type {}", + data.name() + ) + .into(), + location: location!(), + }), } - Ok(Box::new(ValueEncoder::default())) } fn create_per_value( &self, - field: &Field, + _field: &Field, data: &DataBlock, ) -> Result> { match data { - DataBlock::FixedWidth(_) => { - let encoder = Box::new(ValueEncoder::default()); - Ok(encoder) - } - DataBlock::VariableWidth(_variable_width) => { - todo!() - } - DataBlock::FixedSizeList(fsl) => { - let DataType::FixedSizeList(inner_field, field_dim) = field.data_type() else { - panic!("FSL data block without FSL field") - }; - debug_assert_eq!(fsl.dimension, field_dim as u64); - let inner_compressor = self.create_per_value( - &inner_field.as_ref().try_into().unwrap(), - fsl.child.as_ref(), - )?; - let fsl_compressor = FslPerValueCompressor::new(inner_compressor, fsl.dimension); - Ok(Box::new(fsl_compressor)) + DataBlock::FixedWidth(_) => Ok(Box::new(ValueEncoder::default())), + DataBlock::FixedSizeList(_) => Ok(Box::new(ValueEncoder::default())), + DataBlock::VariableWidth(variable_width) => { + let max_len = variable_width.expect_single_stat::(Stat::MaxLength); + let data_size = variable_width.expect_single_stat::(Stat::DataSize); + + // If values are very large then use zstd-per-value + // + // TODO: Could maybe use median here + if max_len > 32 * 1024 && data_size >= FSST_LEAST_INPUT_SIZE as u64 { + return Ok(Box::new(CompressedBufferEncoder::default())); + } + + if variable_width.bits_per_offset == 32 { + let data_size = variable_width.expect_single_stat::(Stat::DataSize); + let max_len = variable_width.expect_single_stat::(Stat::MaxLength); + + let variable_compression = Box::new(VariableEncoder::default()); + + if max_len >= FSST_LEAST_INPUT_MAX_LENGTH + && data_size >= FSST_LEAST_INPUT_SIZE as u64 + { + Ok(Box::new(FsstPerValueEncoder::new(variable_compression))) + } else { + Ok(variable_compression) + } + } else { + todo!("Implement MiniBlockCompression for VariableWidth DataBlock with 64 bits offsets.") + } } - _ => unreachable!(), + _ => unreachable!( + "Per-value compression not yet supported for block type: {}", + data.name() + ), } } @@ -876,13 +945,9 @@ impl CompressionStrategy for CoreArrayEncodingStrategy { Ok((encoder, encoding)) } DataBlock::VariableWidth(variable_width) => { - if variable_width.bits_per_offset == 32 { - let encoder = Box::new(BinaryBlockEncoder::default()); - let encoding = ProtobufUtils::binary_block(); - Ok((encoder, encoding)) - } else { - todo!("Implement BlockCompression for VariableWidth DataBlock with 64 bits offsets.") - } + let encoder = Box::new(VariableEncoder::default()); + let encoding = ProtobufUtils::variable(variable_width.bits_per_offset); + Ok((encoder, encoding)) } _ => unreachable!(), } @@ -1192,15 +1257,14 @@ impl StructuralEncodingStrategy { | DataType::LargeUtf8, ) } -} -impl FieldEncodingStrategy for StructuralEncodingStrategy { - fn create_field_encoder( + fn do_create_field_encoder( &self, _encoding_strategy_root: &dyn FieldEncodingStrategy, field: &Field, column_index: &mut ColumnIndexSequence, options: &EncodingOptions, + root_field_metadata: &HashMap, ) -> Result> { let data_type = field.data_type(); if Self::is_primitive_type(&data_type) { @@ -1209,35 +1273,41 @@ impl FieldEncodingStrategy for StructuralEncodingStrategy { self.compression_strategy.clone(), column_index.next_column_index(field.id as u32), field.clone(), + Arc::new(root_field_metadata.clone()), )?)) } else { match data_type { - DataType::List(_child) | DataType::LargeList(_child) => { - todo!() + DataType::List(_) | DataType::LargeList(_) => { + let child = field.children.first().expect("List should have a child"); + let child_encoder = self.do_create_field_encoder( + _encoding_strategy_root, + child, + column_index, + options, + root_field_metadata, + )?; + Ok(Box::new(ListStructuralEncoder::new(child_encoder))) } DataType::Struct(_) => { - let field_metadata = &field.metadata; - if field_metadata - .get("packed") - .map(|v| v == "true") - .unwrap_or(false) - { + if field.is_packed_struct() { Ok(Box::new(PrimitiveStructuralEncoder::try_new( options, self.compression_strategy.clone(), column_index.next_column_index(field.id as u32), field.clone(), + Arc::new(root_field_metadata.clone()), )?)) } else { let children_encoders = field .children .iter() .map(|field| { - self.create_field_encoder( + self.do_create_field_encoder( _encoding_strategy_root, field, column_index, options, + root_field_metadata, ) }) .collect::>>()?; @@ -1252,6 +1322,7 @@ impl FieldEncodingStrategy for StructuralEncodingStrategy { self.compression_strategy.clone(), column_index.next_column_index(field.id as u32), field.clone(), + Arc::new(root_field_metadata.clone()), )?)) } else { // A dictionary of logical is, itself, logical and we don't support that today @@ -1268,6 +1339,24 @@ impl FieldEncodingStrategy for StructuralEncodingStrategy { } } +impl FieldEncodingStrategy for StructuralEncodingStrategy { + fn create_field_encoder( + &self, + encoding_strategy_root: &dyn FieldEncodingStrategy, + field: &Field, + column_index: &mut ColumnIndexSequence, + options: &EncodingOptions, + ) -> Result> { + self.do_create_field_encoder( + encoding_strategy_root, + field, + column_index, + options, + &field.metadata, + ) + } +} + /// A batch encoder that encodes RecordBatch objects by delegating /// to field encoders for each top-level field in the batch. pub struct BatchEncoder { @@ -1377,7 +1466,9 @@ pub async fn encode_batch( OutOfLineBuffers::new(data_buffer.len() as u64, options.buffer_alignment); let repdef = RepDefBuilder::default(); let encoder = encoder.as_mut(); - let mut tasks = encoder.maybe_encode(arr.clone(), &mut external_buffers, repdef, 0)?; + let num_rows = arr.len() as u64; + let mut tasks = + encoder.maybe_encode(arr.clone(), &mut external_buffers, repdef, 0, num_rows)?; tasks.extend(encoder.flush(&mut external_buffers)?); for buffer in external_buffers.take_buffers() { data_buffer.extend_from_slice(&buffer); diff --git a/rust/lance-encoding/src/encodings/logical/binary.rs b/rust/lance-encoding/src/encodings/logical/binary.rs index 1791f31b158..3acfe194941 100644 --- a/rust/lance-encoding/src/encodings/logical/binary.rs +++ b/rust/lance-encoding/src/encodings/logical/binary.rs @@ -27,7 +27,7 @@ pub struct BinarySchedulingJob<'a> { inner: Box, } -impl<'a> SchedulingJob for BinarySchedulingJob<'a> { +impl SchedulingJob for BinarySchedulingJob<'_> { fn schedule_next( &mut self, context: &mut SchedulerContext, @@ -118,7 +118,6 @@ impl LogicalPageDecoder for BinaryPageDecoder { fn drain(&mut self, num_rows: u64) -> Result { let inner_task = self.inner.drain(num_rows)?; Ok(NextDecodeTask { - has_more: inner_task.has_more, num_rows: inner_task.num_rows, task: Box::new(BinaryArrayDecoder { inner: inner_task.task, diff --git a/rust/lance-encoding/src/encodings/logical/blob.rs b/rust/lance-encoding/src/encodings/logical/blob.rs index ea26cb84e21..d235d36cb50 100644 --- a/rust/lance-encoding/src/encodings/logical/blob.rs +++ b/rust/lance-encoding/src/encodings/logical/blob.rs @@ -11,7 +11,7 @@ use arrow_buffer::{ use arrow_schema::DataType; use bytes::Bytes; use futures::{future::BoxFuture, FutureExt}; -use snafu::{location, Location}; +use snafu::location; use lance_core::{datatypes::BLOB_DESC_FIELDS, Error, Result}; @@ -57,7 +57,7 @@ struct BlobFieldSchedulingJob<'a> { descriptions_job: Box, } -impl<'a> SchedulingJob for BlobFieldSchedulingJob<'a> { +impl SchedulingJob for BlobFieldSchedulingJob<'_> { fn schedule_next( &mut self, context: &mut SchedulerContext, @@ -231,7 +231,6 @@ impl LogicalPageDecoder for BlobFieldDecoder { let validity = self.drain_validity(num_rows as usize)?; self.rows_drained += num_rows; Ok(NextDecodeTask { - has_more: self.rows_drained < self.num_rows, num_rows, task: Box::new(BlobArrayDecodeTask::new(bytes, validity)), }) @@ -371,10 +370,16 @@ impl FieldEncoder for BlobFieldEncoder { external_buffers: &mut OutOfLineBuffers, repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { let descriptions = Self::write_bins(array, external_buffers)?; - self.description_encoder - .maybe_encode(descriptions, external_buffers, repdef, row_number) + self.description_encoder.maybe_encode( + descriptions, + external_buffers, + repdef, + row_number, + num_rows, + ) } // If there is any data left in the buffer then create an encode task from it diff --git a/rust/lance-encoding/src/encodings/logical/list.rs b/rust/lance-encoding/src/encodings/logical/list.rs index cfe6a7b1435..a525d296556 100644 --- a/rust/lance-encoding/src/encodings/logical/list.rs +++ b/rust/lance-encoding/src/encodings/logical/list.rs @@ -12,8 +12,9 @@ use arrow_array::{ use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, Buffer, NullBuffer, OffsetBuffer}; use arrow_schema::{DataType, Field, Fields}; use futures::{future::BoxFuture, FutureExt}; +use lance_arrow::list::ListArrayExt; use log::trace; -use snafu::{location, Location}; +use snafu::location; use tokio::task::JoinHandle; use lance_core::{cache::FileMetadataCache, Error, Result}; @@ -22,9 +23,11 @@ use crate::{ buffer::LanceBuffer, data::{BlockInfo, DataBlock, FixedWidthDataBlock}, decoder::{ - DecodeArrayTask, DecodeBatchScheduler, FieldScheduler, FilterExpression, ListPriorityRange, - LogicalPageDecoder, MessageType, NextDecodeTask, PageEncoding, PriorityRange, - ScheduledScanLine, SchedulerContext, SchedulingJob, + DecodeArrayTask, DecodeBatchScheduler, DecodedArray, FieldScheduler, FilterExpression, + ListPriorityRange, LogicalPageDecoder, MessageType, NextDecodeTask, PageEncoding, + PriorityRange, ScheduledScanLine, SchedulerContext, SchedulingJob, + StructuralDecodeArrayTask, StructuralFieldDecoder, StructuralFieldScheduler, + StructuralSchedulingJob, }, encoder::{ ArrayEncoder, EncodeTask, EncodedArray, EncodedColumn, EncodedPage, FieldEncoder, @@ -362,7 +365,7 @@ async fn indirect_schedule_task( // Create a new root scheduler, which has one column, which is our items data let root_fields = Fields::from(vec![Field::new("item", items_type, true)]); let indirect_root_scheduler = - SimpleStructScheduler::new(vec![items_scheduler], root_fields.clone()); + SimpleStructScheduler::new(vec![items_scheduler], root_fields.clone(), num_items); let mut indirect_scheduler = DecodeBatchScheduler::from_scheduler( Arc::new(indirect_root_scheduler), root_fields.clone(), @@ -424,7 +427,7 @@ impl<'a> ListFieldSchedulingJob<'a> { } } -impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> { +impl SchedulingJob for ListFieldSchedulingJob<'_> { fn schedule_next( &mut self, context: &mut SchedulerContext, @@ -783,9 +786,7 @@ impl LogicalPageDecoder for ListPageDecoder { }; self.rows_drained += num_rows; - let has_more = self.rows_left() > 0; Ok(NextDecodeTask { - has_more, num_rows, task: Box::new(ListDecodeTask { offsets, @@ -948,13 +949,18 @@ impl ListOffsetsEncoder { fn maybe_encode_offsets_and_validity(&mut self, list_arr: &dyn Array) -> Option { let offsets = Self::extract_offsets(list_arr); let validity = Self::extract_validity(list_arr); + let num_rows = offsets.len() as u64; // Either inserting the offsets OR inserting the validity could cause the // accumulation queue to fill up - if let Some(mut arrays) = self.accumulation_queue.insert(offsets, /*row_number=*/ 0) { + if let Some(mut arrays) = self + .accumulation_queue + .insert(offsets, /*row_number=*/ 0, num_rows) + { arrays.0.push(validity); Some(self.make_encode_task(arrays.0)) - } else if let Some(arrays) = - self.accumulation_queue.insert(validity, /*row_number=*/ 0) + } else if let Some(arrays) = self + .accumulation_queue + .insert(validity, /*row_number=*/ 0, num_rows) { Some(self.make_encode_task(arrays.0)) } else { @@ -1176,6 +1182,7 @@ impl FieldEncoder for ListFieldEncoder { external_buffers: &mut OutOfLineBuffers, repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { // The list may have an offset / shorter length which means the underlying // values array could be longer than what we need to encode and so we need @@ -1206,9 +1213,13 @@ impl FieldEncoder for ListFieldEncoder { .maybe_encode_offsets_and_validity(array.as_ref()) .map(|task| vec![task]) .unwrap_or_default(); - let mut item_tasks = - self.items_encoder - .maybe_encode(items, external_buffers, repdef, row_number)?; + let mut item_tasks = self.items_encoder.maybe_encode( + items, + external_buffers, + repdef, + row_number, + num_rows, + )?; if !offsets_tasks.is_empty() && item_tasks.is_empty() { // An items page cannot currently be shared by two different offsets pages. This is // a limitation in the current scheduler and could be addressed in the future. As a result @@ -1249,18 +1260,214 @@ impl FieldEncoder for ListFieldEncoder { } } +/// A structural encoder for list fields +/// +/// The list's offsets are added to the rep/def builder +/// and the list array's values are passed to the child encoder +/// +/// The values will have any garbage values removed and will be trimmed +/// to only include the values that are actually used. +pub struct ListStructuralEncoder { + child: Box, +} + +impl ListStructuralEncoder { + pub fn new(child: Box) -> Self { + Self { child } + } +} + +impl FieldEncoder for ListStructuralEncoder { + fn maybe_encode( + &mut self, + array: ArrayRef, + external_buffers: &mut OutOfLineBuffers, + mut repdef: RepDefBuilder, + row_number: u64, + num_rows: u64, + ) -> Result> { + let values = if let Some(list_arr) = array.as_list_opt::() { + let has_garbage_values = + repdef.add_offsets(list_arr.offsets().clone(), array.nulls().cloned()); + if has_garbage_values { + list_arr.filter_garbage_nulls().trimmed_values() + } else { + list_arr.trimmed_values() + } + } else if let Some(list_arr) = array.as_list_opt::() { + let has_garbage_values = + repdef.add_offsets(list_arr.offsets().clone(), array.nulls().cloned()); + if has_garbage_values { + list_arr.filter_garbage_nulls().trimmed_values() + } else { + list_arr.trimmed_values() + } + } else { + panic!("List encoder used for non-list data") + }; + self.child + .maybe_encode(values, external_buffers, repdef, row_number, num_rows) + } + + fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result> { + self.child.flush(external_buffers) + } + + fn num_columns(&self) -> u32 { + self.child.num_columns() + } + + fn finish( + &mut self, + external_buffers: &mut OutOfLineBuffers, + ) -> BoxFuture<'_, Result>> { + self.child.finish(external_buffers) + } +} + +#[derive(Debug)] +pub struct StructuralListScheduler { + child: Box, +} + +impl StructuralListScheduler { + pub fn new(child: Box) -> Self { + Self { child } + } +} + +impl StructuralFieldScheduler for StructuralListScheduler { + fn schedule_ranges<'a>( + &'a self, + ranges: &[Range], + filter: &FilterExpression, + ) -> Result> { + let child = self.child.schedule_ranges(ranges, filter)?; + + Ok(Box::new(StructuralListSchedulingJob::new(child))) + } + + fn initialize<'a>( + &'a mut self, + filter: &'a FilterExpression, + context: &'a SchedulerContext, + ) -> BoxFuture<'a, Result<()>> { + self.child.initialize(filter, context) + } +} + +/// Scheduling job for list data +/// +/// Scheduling is handled by the primitive encoder and nothing special +/// happens here. +#[derive(Debug)] +struct StructuralListSchedulingJob<'a> { + child: Box, +} + +impl<'a> StructuralListSchedulingJob<'a> { + fn new(child: Box) -> Self { + Self { child } + } +} + +impl StructuralSchedulingJob for StructuralListSchedulingJob<'_> { + fn schedule_next( + &mut self, + context: &mut SchedulerContext, + ) -> Result> { + self.child.schedule_next(context) + } +} + +#[derive(Debug)] +pub struct StructuralListDecoder { + child: Box, + data_type: DataType, +} + +impl StructuralListDecoder { + pub fn new(child: Box, data_type: DataType) -> Self { + Self { child, data_type } + } +} + +impl StructuralFieldDecoder for StructuralListDecoder { + fn accept_page(&mut self, child: crate::decoder::LoadedPage) -> Result<()> { + self.child.accept_page(child) + } + + fn drain(&mut self, num_rows: u64) -> Result> { + let child_task = self.child.drain(num_rows)?; + Ok(Box::new(StructuralListDecodeTask::new( + child_task, + self.data_type.clone(), + ))) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[derive(Debug)] +struct StructuralListDecodeTask { + child_task: Box, + data_type: DataType, +} + +impl StructuralListDecodeTask { + fn new(child_task: Box, data_type: DataType) -> Self { + Self { + child_task, + data_type, + } + } +} + +impl StructuralDecodeArrayTask for StructuralListDecodeTask { + fn decode(self: Box) -> Result { + let DecodedArray { array, mut repdef } = self.child_task.decode()?; + match &self.data_type { + DataType::List(child_field) => { + let (offsets, validity) = repdef.unravel_offsets::()?; + let list_array = ListArray::try_new(child_field.clone(), offsets, array, validity)?; + Ok(DecodedArray { + array: Arc::new(list_array), + repdef, + }) + } + DataType::LargeList(child_field) => { + let (offsets, validity) = repdef.unravel_offsets::()?; + let list_array = + LargeListArray::try_new(child_field.clone(), offsets, array, validity)?; + Ok(DecodedArray { + array: Arc::new(list_array), + repdef, + }) + } + _ => panic!("List decoder did not have a list field"), + } + } +} + #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc}; - use arrow::array::{LargeListBuilder, StringBuilder}; + use arrow::array::{Int64Builder, LargeListBuilder, StringBuilder}; use arrow_array::{ builder::{Int32Builder, ListBuilder}, - Array, ArrayRef, BooleanArray, ListArray, StructArray, UInt64Array, + Array, ArrayRef, BooleanArray, DictionaryArray, LargeStringArray, ListArray, StructArray, + UInt64Array, UInt8Array, }; - use arrow_buffer::{OffsetBuffer, ScalarBuffer}; + use arrow_buffer::{BooleanBuffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{DataType, Field, Fields}; + use lance_core::datatypes::{ + STRUCTURAL_ENCODING_FULLZIP, STRUCTURAL_ENCODING_META_KEY, STRUCTURAL_ENCODING_MINIBLOCK, + }; + use rstest::rstest; use crate::{ testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases}, @@ -1275,10 +1482,39 @@ mod tests { DataType::LargeList(Arc::new(Field::new("item", inner_type, true))) } + #[rstest] #[test_log::test(tokio::test)] - async fn test_list() { - let field = Field::new("", make_list_type(DataType::Int32), true); - check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + async fn test_list( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + let field = + Field::new("", make_list_type(DataType::Int32), true).with_metadata(field_metadata); + check_round_trip_encoding_random(field, version).await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_deeply_nested_lists( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + let field = Field::new("item", DataType::Int32, true).with_metadata(field_metadata); + for _ in 0..5 { + let field = Field::new("", make_list_type(field.data_type().clone()), true); + check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + } } #[test_log::test(tokio::test)] @@ -1332,8 +1568,13 @@ mod tests { .await; } + #[rstest] #[test_log::test(tokio::test)] - async fn test_simple_list() { + async fn test_simple_list( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { let items_builder = Int32Builder::new(); let mut list_builder = ListBuilder::new(items_builder); list_builder.append_value([Some(1), Some(2), Some(3)]); @@ -1342,15 +1583,237 @@ mod tests { list_builder.append_value([Some(6), Some(7), Some(8)]); let list_array = list_builder.finish(); + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + let test_cases = TestCases::default() .with_range(0..2) .with_range(0..3) .with_range(1..3) - .with_indices(vec![1, 3]); - check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, HashMap::new()) + .with_indices(vec![1, 3]) + .with_indices(vec![2]) + .with_file_version(version); + check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, field_metadata) + .await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_simple_nested_list_ends_with_null( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + use arrow_array::Int32Array; + + let values = Int32Array::from(vec![1, 2, 3, 4, 5]); + let inner_offsets = ScalarBuffer::::from(vec![0, 1, 2, 3, 4, 5, 5]); + let inner_validity = BooleanBuffer::from(vec![true, true, true, true, true, false]); + let outer_offsets = ScalarBuffer::::from(vec![0, 1, 2, 3, 4, 5, 6, 6]); + let outer_validity = BooleanBuffer::from(vec![true, true, true, true, true, true, false]); + + let inner_list = ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::new(inner_offsets), + Arc::new(values), + Some(NullBuffer::new(inner_validity)), + ); + let outer_list = ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::new(outer_offsets), + Arc::new(inner_list), + Some(NullBuffer::new(outer_validity)), + ); + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(0..3) + .with_range(5..7) + .with_indices(vec![1, 6]) + .with_indices(vec![6]) + .with_file_version(LanceFileVersion::V2_1); + check_round_trip_encoding_of_data(vec![Arc::new(outer_list)], &test_cases, field_metadata) .await; } + #[rstest] + #[test_log::test(tokio::test)] + async fn test_simple_string_list( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + let items_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(items_builder); + list_builder.append_value([Some("a"), Some("bc"), Some("def")]); + list_builder.append_value([Some("gh"), None]); + list_builder.append_null(); + list_builder.append_value([Some("ijk"), Some("lmnop"), Some("qrs")]); + let list_array = list_builder.finish(); + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(0..3) + .with_range(1..3) + .with_indices(vec![1, 3]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, field_metadata) + .await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_simple_sliced_list( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + let items_builder = Int32Builder::new(); + let mut list_builder = ListBuilder::new(items_builder); + list_builder.append_value([Some(1), Some(2), Some(3)]); + list_builder.append_value([Some(4), Some(5)]); + list_builder.append_null(); + list_builder.append_value([Some(6), Some(7), Some(8)]); + let list_array = list_builder.finish(); + + let list_array = list_array.slice(1, 2); + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(1..2) + .with_indices(vec![0]) + .with_indices(vec![1]) + .with_file_version(LanceFileVersion::V2_1); + check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, field_metadata) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_simple_list_dict() { + let values = LargeStringArray::from_iter_values(["a", "bb", "ccc"]); + let indices = UInt8Array::from(vec![0, 1, 2, 0, 1, 2, 0, 1, 2]); + let dict_array = DictionaryArray::new(indices, Arc::new(values)); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 5, 6, 9])); + let list_array = ListArray::new( + Arc::new(Field::new("item", dict_array.data_type().clone(), true)), + offsets, + Arc::new(dict_array), + None, + ); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(1..3) + .with_range(2..4) + .with_indices(vec![1]) + .with_indices(vec![2]); + check_round_trip_encoding_of_data( + vec![Arc::new(list_array)], + &test_cases, + HashMap::default(), + ) + .await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_list_with_garbage_nulls( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + // In Arrow, list nulls are allowed to be non-empty, with masked garbage values + // Here we make a list with a null row in the middle with 3 garbage values + let items = UInt64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let offsets = ScalarBuffer::::from(vec![0, 5, 8, 10]); + let offsets = OffsetBuffer::new(offsets); + let list_validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let list_arr = ListArray::new( + Arc::new(Field::new("item", DataType::UInt64, true)), + offsets, + Arc::new(items), + Some(list_validity), + ); + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_range(1..2) + .with_indices(vec![1]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + check_round_trip_encoding_of_data(vec![Arc::new(list_arr)], &test_cases, field_metadata) + .await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_simple_two_page_list( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + // This is a simple pre-defined list that spans two pages. This test is useful for + // debugging the repetition index + let items_builder = Int64Builder::new(); + let mut list_builder = ListBuilder::new(items_builder); + for i in 0..512 { + list_builder.append_value([Some(i), Some(i * 2)]); + } + let list_array_1 = list_builder.finish(); + + let items_builder = Int64Builder::new(); + let mut list_builder = ListBuilder::new(items_builder); + for i in 0..512 { + let i = i + 512; + list_builder.append_value([Some(i), Some(i * 2)]); + } + let list_array_2 = list_builder.finish(); + + let mut metadata = HashMap::new(); + metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let test_cases = TestCases::default() + .with_file_version(LanceFileVersion::V2_1) + .with_page_sizes(vec![100]) + .with_range(800..900); + check_round_trip_encoding_of_data( + vec![Arc::new(list_array_1), Arc::new(list_array_2)], + &test_cases, + metadata, + ) + .await; + } + #[test_log::test(tokio::test)] async fn test_simple_large_list() { let items_builder = Int32Builder::new(); diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index e73cd2b282c..ab4550291c2 100644 --- a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::{ + any::Any, collections::{HashMap, VecDeque}, fmt::Debug, iter, @@ -12,20 +13,37 @@ use std::{ use arrow::array::AsArray; use arrow_array::{make_array, types::UInt64Type, Array, ArrayRef, PrimitiveArray}; -use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer}; +use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, ScalarBuffer}; use arrow_schema::{DataType, Field as ArrowField}; -use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, TryStreamExt}; +use futures::{future::BoxFuture, stream::FuturesOrdered, FutureExt, TryStreamExt}; +use itertools::Itertools; use lance_arrow::deepcopy::deep_copy_array; -use lance_core::utils::bit::pad_bytes; -use lance_core::utils::hash::U8SliceKey; +use lance_core::{ + cache::{Context, DeepSizeOf}, + datatypes::{ + STRUCTURAL_ENCODING_FULLZIP, STRUCTURAL_ENCODING_META_KEY, STRUCTURAL_ENCODING_MINIBLOCK, + }, + error::Error, + utils::bit::pad_bytes, + utils::hash::U8SliceKey, +}; use log::{debug, trace}; -use snafu::{location, Location}; +use snafu::location; -use crate::data::{AllNullDataBlock, DataBlock, VariableWidthBlock}; -use crate::decoder::PerValueDecompressor; -use crate::encoder::PerValueDataBlock; -use crate::repdef::{build_control_word_iterator, ControlWordIterator, ControlWordParser}; +use crate::repdef::{ + build_control_word_iterator, CompositeRepDefUnraveler, ControlWordIterator, ControlWordParser, + DefinitionInterpretation, RepDefSlicer, +}; use crate::statistics::{ComputeStat, GetStat, Stat}; +use crate::utils::bytepack::ByteUnpacker; +use crate::{ + data::{AllNullDataBlock, DataBlock, VariableWidthBlock}, + utils::bytepack::BytepackedIntegerEncoder, +}; +use crate::{ + decoder::{FixedPerValueDecompressor, VariablePerValueDecompressor}, + encoder::PerValueDataBlock, +}; use lance_core::{datatypes::Field, utils::tokio::spawn_cpu, Result}; use crate::{ @@ -49,6 +67,8 @@ use crate::{ EncodingsIo, }; +const FILL_BYTE: u8 = 0xFE; + #[derive(Debug)] struct PrimitivePage { scheduler: Box, @@ -141,7 +161,7 @@ impl<'a> PrimitiveFieldSchedulingJob<'a> { } } -impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { +impl SchedulingJob for PrimitiveFieldSchedulingJob<'_> { fn schedule_next( &mut self, context: &mut SchedulerContext, @@ -264,12 +284,17 @@ impl FieldScheduler for PrimitiveFieldScheduler { /// a single page. trait StructuralPageScheduler: std::fmt::Debug + Send { /// Fetches any metadata required for the page - fn initialize<'a>(&'a mut self, io: &Arc) -> BoxFuture<'a, Result<()>>; + fn initialize<'a>( + &'a mut self, + io: &Arc, + ) -> BoxFuture<'a, Result>>; + /// Loads metadata from a previous initialize call + fn load(&mut self, data: &Arc); /// Schedules the read of the given ranges in the page fn schedule_ranges( &self, ranges: &[Range], - io: &dyn EncodingsIo, + io: &Arc, ) -> Result>>>; } @@ -278,6 +303,15 @@ trait StructuralPageScheduler: std::fmt::Debug + Send { struct ChunkMeta { num_values: u64, chunk_size_bytes: u64, + offset_bytes: u64, +} + +/// A mini-block chunk that has been decoded and decompressed +#[derive(Debug)] +struct DecodedMiniBlockChunk { + rep: Option>, + def: Option>, + values: DataBlock, } /// A task to decode a one or more mini-blocks of data into an output batch @@ -289,44 +323,27 @@ struct ChunkMeta { /// the decoding of the block. (TODO: test this theory) #[derive(Debug)] struct DecodeMiniBlockTask { - // The decompressors for the rep, def, and value buffers - rep_decompressor: Arc, - def_decompressor: Arc, + rep_decompressor: Option>, + def_decompressor: Option>, value_decompressor: Arc, dictionary_data: Option>, - // The mini-blocks to decode - // - // For each mini-block we also have the ranges of rows that we want to decode - // from that mini-block. For example, if the user asks for rows 10, 10000, and 20000 - // then we will have three chunks here and each chunk will have a small range of 1 row. - chunks: Vec, - // The offset into the first chunk that we want to start decoding from - offset_into_first_chunk: u64, - // The total number of rows that we are decoding - num_rows: u64, + def_meaning: Arc<[DefinitionInterpretation]>, + num_buffers: u64, + max_visible_level: u16, + instructions: Vec<(ChunkDrainInstructions, LoadedChunk)>, } impl DecodeMiniBlockTask { fn decode_levels( rep_decompressor: &dyn BlockDecompressor, levels: LanceBuffer, - ) -> Result>> { - let rep = rep_decompressor.decompress(levels)?; - match rep { - DataBlock::FixedWidth(mut rep) => Ok(Some(rep.data.borrow_to_typed_slice::())), - DataBlock::Constant(constant) => { - assert_eq!(constant.data.len(), 2); - if constant.data[0] == 0 && constant.data[1] == 0 { - Ok(None) - } else { - // Maybe in the future we will encode all-null def or - // constant rep (all 1-item lists?) in a constant encoding - // but that doesn't happen today so we don't need to worry. - todo!() - } - } - _ => unreachable!(), - } + num_levels: u16, + ) -> Result> { + let rep = rep_decompressor.decompress(levels, num_levels as u64)?; + let mut rep = rep.as_fixed_width().unwrap(); + debug_assert_eq!(rep.num_values, num_levels as u64); + debug_assert_eq!(rep.bits_per_value, 16); + Ok(rep.data.borrow_to_typed_slice::()) } // We are building a LevelBuffer (levels) and want to copy into it `total_len` @@ -336,7 +353,6 @@ impl DecodeMiniBlockTask { // yet) and the case where `level_buf` is None (the input we are copying from has // no nulls) fn extend_levels( - offset: usize, range: Range, levels: &mut Option, level_buf: &Option>, @@ -347,8 +363,8 @@ impl DecodeMiniBlockTask { // This is the first non-empty def buf we've hit, fill in the past // with 0 (valid) let mut new_levels_vec = - LevelBuffer::with_capacity(offset + (range.end - range.start) as usize); - new_levels_vec.extend(iter::repeat(0).take(dest_offset)); + LevelBuffer::with_capacity(dest_offset + (range.end - range.start) as usize); + new_levels_vec.extend(iter::repeat_n(0, dest_offset)); *levels = Some(new_levels_vec); } levels.as_mut().unwrap().extend( @@ -360,9 +376,314 @@ impl DecodeMiniBlockTask { let num_values = (range.end - range.start) as usize; // This is an all-valid level_buf but we had nulls earlier and so we // need to materialize it - levels.extend(iter::repeat(0).take(num_values)); + levels.extend(iter::repeat_n(0, num_values)); + } + } + + /// Maps a range of rows to a range of items and a range of levels + /// + /// If there is no repetition information this just returns the range as-is. + /// + /// If there is repetition information then we need to do some work to figure out what + /// range of items corresponds to the requested range of rows. + /// + /// For example, if the data is [[1, 2, 3], [4, 5], [6, 7]] and the range is 1..2 (i.e. just row + /// 1) then the user actually wants items 3..5. In the above case the rep levels would be: + /// + /// Idx: 0 1 2 3 4 5 6 + /// Rep: 1 0 0 1 0 1 0 + /// + /// So the start (1) maps to the second 1 (idx=3) and the end (2) maps to the third 1 (idx=5) + /// + /// If there are invisible items then we don't count them when calcuating the range of items we + /// are interested in but we do count them when calculating the range of levels we are interested + /// in. As a result we have to return both the item range (first return value) and the level range + /// (second return value). + /// + /// For example, if the data is [[1, 2, 3], [4, 5], NULL, [6, 7, 8]] and the range is 2..4 then the + /// user wants items 5..8 but they want levels 5..9. In the above case the rep/def levels would be: + /// + /// Idx: 0 1 2 3 4 5 6 7 8 + /// Rep: 1 0 0 1 0 1 1 0 0 + /// Def: 0 0 0 0 0 1 0 0 0 + /// Itm: 1 2 3 4 5 6 7 8 + /// + /// Finally, we have to contend with the fact that chunks may or may not start with a "preamble" of + /// trailing values that finish up a list from the previous chunk. In this case the first item does + /// not start at max_rep because it is a continuation of the previous chunk. For our purposes we do + /// not consider this a "row" and so the range 0..1 will refer to the first row AFTER the preamble. + /// + /// We have a separate parameter (`preamble_action`) to control whether we want the preamble or not. + /// + /// Note that the "trailer" is considered a "row" and if we want it we should include it in the range. + fn map_range( + range: Range, + rep: Option<&impl AsRef<[u16]>>, + def: Option<&impl AsRef<[u16]>>, + max_rep: u16, + max_visible_def: u16, + // The total number of items (not rows) in the chunk. This is not quite the same as + // rep.len() / def.len() because it doesn't count invisible items + total_items: u64, + preamble_action: PreambleAction, + ) -> (Range, Range) { + if let Some(rep) = rep { + let mut rep = rep.as_ref(); + // If there is a preamble and we need to skip it then do that first. The work is the same + // whether there is def information or not + let mut items_in_preamble = 0; + let first_row_start = match preamble_action { + PreambleAction::Skip | PreambleAction::Take => { + let first_row_start = if let Some(def) = def.as_ref() { + let mut first_row_start = None; + for (idx, (rep, def)) in rep.iter().zip(def.as_ref()).enumerate() { + if *rep == max_rep { + first_row_start = Some(idx); + break; + } + if *def <= max_visible_def { + items_in_preamble += 1; + } + } + first_row_start + } else { + let first_row_start = rep.iter().position(|&r| r == max_rep); + items_in_preamble = first_row_start.unwrap_or(rep.len()); + first_row_start + }; + // It is possible for a chunk to be entirely partial values but if it is then it + // should never show up as a preamble to skip + if first_row_start.is_none() { + assert!(preamble_action == PreambleAction::Take); + return (0..total_items, 0..rep.len() as u64); + } + let first_row_start = first_row_start.unwrap() as u64; + rep = &rep[first_row_start as usize..]; + first_row_start + } + PreambleAction::Absent => { + debug_assert!(rep[0] == max_rep); + 0 + } + }; + + // We hit this case when all we needed was the preamble + if range.start == range.end { + debug_assert!(preamble_action == PreambleAction::Take); + return (0..items_in_preamble as u64, 0..first_row_start); + } + assert!(range.start < range.end); + + let mut rows_seen = 0; + let mut new_start = 0; + let mut new_levels_start = 0; + + if let Some(def) = def { + let def = &def.as_ref()[first_row_start as usize..]; + + // range.start == 0 always maps to 0 (even with invis items), otherwise we need to walk + let mut lead_invis_seen = 0; + + if range.start > 0 { + if def[0] > max_visible_def { + lead_invis_seen += 1; + } + for (idx, (rep, def)) in rep.iter().zip(def).skip(1).enumerate() { + if *rep == max_rep { + rows_seen += 1; + if rows_seen == range.start { + new_start = idx as u64 + 1 - lead_invis_seen; + new_levels_start = idx as u64 + 1; + break; + } + if *def > max_visible_def { + lead_invis_seen += 1; + } + } + } + } + + rows_seen += 1; + + let mut new_end = u64::MAX; + let mut new_levels_end = rep.len() as u64; + let new_start_is_visible = def[new_levels_start as usize] <= max_visible_def; + let mut tail_invis_seen = if new_start_is_visible { 0 } else { 1 }; + for (idx, (rep, def)) in rep[(new_levels_start + 1) as usize..] + .iter() + .zip(&def[(new_levels_start + 1) as usize..]) + .enumerate() + { + if *rep == max_rep { + rows_seen += 1; + if rows_seen == range.end + 1 { + new_end = idx as u64 + new_start + 1 - tail_invis_seen; + new_levels_end = idx as u64 + new_levels_start + 1; + break; + } + if *def > max_visible_def { + tail_invis_seen += 1; + } + } + } + + if new_end == u64::MAX { + new_levels_end = rep.len() as u64; + // This is the total number of visible items (minus any items in the preamble) + let total_invis_seen = lead_invis_seen + tail_invis_seen; + new_end = rep.len() as u64 - total_invis_seen; + } + + assert_ne!(new_end, u64::MAX); + + // Adjust for any skipped preamble + if preamble_action == PreambleAction::Skip { + // TODO: Should this be items_in_preamble? If so, add a + // unit test for this case + new_start += first_row_start; + new_end += first_row_start; + new_levels_start += first_row_start; + new_levels_end += first_row_start; + } else if preamble_action == PreambleAction::Take { + debug_assert_eq!(new_start, 0); + debug_assert_eq!(new_levels_start, 0); + new_end += first_row_start; + new_levels_end += first_row_start; + } + + (new_start..new_end, new_levels_start..new_levels_end) + } else { + // Easy case, there are no invisible items, so we don't need to check for them + // The items range and levels range will be the same. We do still need to walk + // the rep levels to find the row boundaries + + // range.start == 0 always maps to 0, otherwise we need to walk + if range.start > 0 { + for (idx, rep) in rep.iter().skip(1).enumerate() { + if *rep == max_rep { + rows_seen += 1; + if rows_seen == range.start { + new_start = idx as u64 + 1; + break; + } + } + } + } + let mut new_end = rep.len() as u64; + // range.end == max_items always maps to rep.len(), otherwise we need to walk + if range.end < total_items { + for (idx, rep) in rep[(new_start + 1) as usize..].iter().enumerate() { + if *rep == max_rep { + rows_seen += 1; + if rows_seen == range.end { + new_end = idx as u64 + new_start + 1; + break; + } + } + } + } + + // Adjust for any skipped preamble + if preamble_action == PreambleAction::Skip { + new_start += first_row_start; + new_end += first_row_start; + } else if preamble_action == PreambleAction::Take { + debug_assert_eq!(new_start, 0); + new_end += first_row_start; + } + + (new_start..new_end, new_start..new_end) + } + } else { + // No repetition info, easy case, just use the range as-is and the item + // and level ranges are the same + (range.clone(), range) } } + + // Unserialize a miniblock into a collection of vectors + fn decode_miniblock_chunk( + &self, + buf: &LanceBuffer, + items_in_chunk: u64, + ) -> Result { + let mut offset = 0; + let num_levels = u16::from_le_bytes([buf[offset], buf[offset + 1]]); + offset += 2; + + let rep_size = if self.rep_decompressor.is_some() { + let rep_size = u16::from_le_bytes([buf[offset], buf[offset + 1]]); + offset += 2; + Some(rep_size) + } else { + None + }; + let def_size = if self.def_decompressor.is_some() { + let def_size = u16::from_le_bytes([buf[offset], buf[offset + 1]]); + offset += 2; + Some(def_size) + } else { + None + }; + let buffer_sizes = (0..self.num_buffers) + .map(|_| { + let size = u16::from_le_bytes([buf[offset], buf[offset + 1]]); + offset += 2; + size + }) + .collect::>(); + + offset += pad_bytes::(offset); + + let rep = rep_size.map(|rep_size| { + let rep = buf.slice_with_length(offset, rep_size as usize); + offset += rep_size as usize; + offset += pad_bytes::(offset); + rep + }); + + let def = def_size.map(|def_size| { + let def = buf.slice_with_length(offset, def_size as usize); + offset += def_size as usize; + offset += pad_bytes::(offset); + def + }); + + let buffers = buffer_sizes + .into_iter() + .map(|buf_size| { + let buf = buf.slice_with_length(offset, buf_size as usize); + offset += buf_size as usize; + offset += pad_bytes::(offset); + buf + }) + .collect::>(); + + let values = self + .value_decompressor + .decompress(buffers, items_in_chunk)?; + + let rep = rep + .map(|rep| { + Self::decode_levels( + self.rep_decompressor.as_ref().unwrap().as_ref(), + rep, + num_levels, + ) + }) + .transpose()?; + let def = def + .map(|def| { + Self::decode_levels( + self.def_decompressor.as_ref().unwrap().as_ref(), + def, + num_levels, + ) + }) + .transpose()?; + + Ok(DecodedMiniBlockChunk { rep, def, values }) + } } impl DecodePageTask for DecodeMiniBlockTask { @@ -370,104 +691,56 @@ impl DecodePageTask for DecodeMiniBlockTask { // First, we create output buffers for the rep and def and data let mut repbuf: Option = None; let mut defbuf: Option = None; - let rep_decompressor = self.rep_decompressor; - let def_decompressor = self.def_decompressor; - let mut remaining = self.num_rows; + let max_rep = self.def_meaning.iter().filter(|l| l.is_list()).count() as u16; + + // This is probably an over-estimate but it's quick and easy to calculate let estimated_size_bytes = self - .chunks + .instructions .iter() - .map(|chunk| chunk.data.len()) + .map(|(_, chunk)| chunk.data.len()) .sum::() * 2; let mut data_builder = DataBlockBuilder::with_capacity_estimate(estimated_size_bytes as u64); - let mut to_skip = self.offset_into_first_chunk; + // We need to keep track of the offset into repbuf/defbuf that we are building up let mut level_offset = 0; - // Now we iterate through each chunk and decode the data into the output buffers - for chunk in self.chunks.into_iter() { - // We always decode the entire chunk - let buf = chunk.data.into_buffer(); - // The first 6 bytes describe the size of the remaining buffers - let bytes_rep = u16::from_le_bytes([buf[0], buf[1]]) as usize; - let bytes_def = u16::from_le_bytes([buf[2], buf[3]]) as usize; - let bytes_val = u16::from_le_bytes([buf[4], buf[5]]) as usize; - - debug_assert!(buf.len() >= bytes_rep + bytes_def + bytes_val + 6); - debug_assert!( - buf.len() - <= bytes_rep - + bytes_def - + bytes_val - + 6 - + 1 // P1 - + (2 * MINIBLOCK_MAX_PADDING) // P2/P3 + // Now we iterate through each instruction and process it + for (instructions, chunk) in self.instructions.iter() { + // TODO: It's very possible that we have duplicate `buf` in self.instructions and we + // don't want to decode the buf again and again on the same thread. + + let DecodedMiniBlockChunk { rep, def, values } = + self.decode_miniblock_chunk(&chunk.data, chunk.items_in_chunk)?; + + // Our instructions tell us which rows we want to take from this chunk + let row_range_start = + instructions.rows_to_skip + instructions.chunk_instructions.rows_to_skip; + let row_range_end = row_range_start + instructions.rows_to_take; + + // We use the rep info to map the row range to an item range / levels range + let (item_range, level_range) = Self::map_range( + row_range_start..row_range_end, + rep.as_ref(), + def.as_ref(), + max_rep, + self.max_visible_level, + chunk.items_in_chunk, + instructions.preamble_action, ); - let p1 = bytes_rep % 2; - let rep = buf.slice_with_length(6, bytes_rep); - let def = buf.slice_with_length(6 + bytes_rep + p1, bytes_def); - let p2 = pad_bytes::(6 + bytes_rep + p1 + bytes_def); - let values = buf.slice_with_length(6 + bytes_rep + bytes_def + p2, bytes_val); - - let values = self - .value_decompressor - .decompress(LanceBuffer::Borrowed(values), chunk.vals_in_chunk)?; - - let rep = Self::decode_levels(rep_decompressor.as_ref(), LanceBuffer::Borrowed(rep))?; - let def = Self::decode_levels(def_decompressor.as_ref(), LanceBuffer::Borrowed(def))?; - - // We've decoded the entire block. Now we need to factor in: - // - The offset into the first chunk - // - The ranges the user asked for - // - The total # of rows in this task - // - // From these we can figure out which values to keep. - // - // For example, maybe we've are asked to decode 100 rows, with an offset of 50, from - // a block with 1024 values, and the user asked for the ranges 400..500 and 600..700 - // - // In this case we want to take the values 450..500 and 600..650 from the block. - let mut offset = to_skip; - for range in chunk.ranges { - if to_skip > range.end - range.start { - to_skip -= range.end - range.start; - continue; - } - // Subtract skip from start of range - let range = range.start + to_skip..range.end; - to_skip = 0; - // Truncate range to fit remaining - let range_len = range.end - range.start; - let to_take = range_len.min(remaining); - let range = range.start..range.start + to_take; - - // Grab values and add to what we are building - Self::extend_levels( - offset as usize, - range.clone(), - &mut repbuf, - &rep, - level_offset, - ); - Self::extend_levels( - offset as usize, - range.clone(), - &mut defbuf, - &def, - level_offset, - ); - data_builder.append(&values, range); - remaining -= to_take; - offset += to_take; - level_offset += to_take as usize; - } + // Now we append the data to the output buffers + Self::extend_levels(level_range.clone(), &mut repbuf, &rep, level_offset); + Self::extend_levels(level_range.clone(), &mut defbuf, &def, level_offset); + level_offset += (level_range.end - level_range.start) as usize; + data_builder.append(&values, item_range); } - debug_assert_eq!(remaining, 0); let data = data_builder.finish(); + let unraveler = RepDefUnraveler::new(repbuf, defbuf, self.def_meaning.clone()); + // if dictionary encoding is applied, do dictionary decode here. if let Some(dictionary) = &self.dictionary_data { // assume the indices are uniformly distributed. @@ -488,63 +761,273 @@ impl DecodePageTask for DecodeMiniBlockTask { let data = data_builder.finish(); return Ok(DecodedPage { data, - repetition: repbuf, - definition: defbuf, + repdef: unraveler, }); } } Ok(DecodedPage { data, - repetition: repbuf, - definition: defbuf, + repdef: unraveler, }) } } +/// A chunk that has been loaded by the miniblock scheduler (but not +/// yet decoded) +#[derive(Debug)] +struct LoadedChunk { + data: LanceBuffer, + items_in_chunk: u64, + byte_range: Range, + chunk_idx: usize, +} + +impl Clone for LoadedChunk { + fn clone(&self) -> Self { + Self { + // Safe as we always create borrowed buffers here + data: self.data.try_clone().unwrap(), + items_in_chunk: self.items_in_chunk, + byte_range: self.byte_range.clone(), + chunk_idx: self.chunk_idx, + } + } +} + /// Decodes mini-block formatted data. See [`PrimitiveStructuralEncoder`] for more /// details on the different layouts. #[derive(Debug)] struct MiniBlockDecoder { - rep_decompressor: Arc, - def_decompressor: Arc, + rep_decompressor: Option>, + def_decompressor: Option>, value_decompressor: Arc, - data: VecDeque, + def_meaning: Arc<[DefinitionInterpretation]>, + loaded_chunks: VecDeque, + instructions: VecDeque, offset_in_current_chunk: u64, num_rows: u64, + num_buffers: u64, dictionary: Option>, } +/// See [`MiniBlockScheduler`] for more details on the scheduling and decoding +/// process for miniblock encoded data. impl StructuralPageDecoder for MiniBlockDecoder { fn drain(&mut self, num_rows: u64) -> Result> { - let mut remaining = num_rows; - let mut chunks = Vec::new(); - let offset_into_first_chunk = self.offset_in_current_chunk; - while remaining > 0 { - if remaining >= self.data.front().unwrap().vals_targeted - self.offset_in_current_chunk + let mut items_desired = num_rows; + let mut need_preamble = false; + let mut skip_in_chunk = self.offset_in_current_chunk; + let mut drain_instructions = Vec::new(); + while items_desired > 0 || need_preamble { + let (instructions, consumed) = self + .instructions + .front() + .unwrap() + .drain_from_instruction(&mut items_desired, &mut need_preamble, &mut skip_in_chunk); + + while self.loaded_chunks.front().unwrap().chunk_idx + != instructions.chunk_instructions.chunk_idx { - // We are fully consuming the next chunk - let chunk = self.data.pop_front().unwrap(); - remaining -= chunk.vals_targeted - self.offset_in_current_chunk; - chunks.push(chunk); - self.offset_in_current_chunk = 0; - } else { - // We are partially consuming the next chunk - let chunk = self.data.front().unwrap().clone(); - self.offset_in_current_chunk += remaining; - debug_assert!(self.offset_in_current_chunk > 0); - remaining = 0; - chunks.push(chunk); + self.loaded_chunks.pop_front(); + } + drain_instructions.push((instructions, self.loaded_chunks.front().unwrap().clone())); + if consumed { + self.instructions.pop_front(); } } + // We can throw away need_preamble here because it must be false. If it were true it would mean + // we were still in the middle of loading rows. We do need to latch skip_in_chunk though. + self.offset_in_current_chunk = skip_in_chunk; + + let max_visible_level = self + .def_meaning + .iter() + .take_while(|l| !l.is_list()) + .map(|l| l.num_def_levels()) + .sum::(); + Ok(Box::new(DecodeMiniBlockTask { - chunks, - rep_decompressor: self.rep_decompressor.clone(), + instructions: drain_instructions, def_decompressor: self.def_decompressor.clone(), + rep_decompressor: self.rep_decompressor.clone(), value_decompressor: self.value_decompressor.clone(), dictionary_data: self.dictionary.clone(), + def_meaning: self.def_meaning.clone(), + num_buffers: self.num_buffers, + max_visible_level, + })) + } + + fn num_rows(&self) -> u64 { + self.num_rows + } +} + +#[derive(Debug)] +struct CachedComplexAllNullState { + rep: Option>, + def: Option>, +} + +impl DeepSizeOf for CachedComplexAllNullState { + fn deep_size_of_children(&self, _ctx: &mut Context) -> usize { + self.rep.as_ref().map(|buf| buf.len() * 2).unwrap_or(0) + + self.def.as_ref().map(|buf| buf.len() * 2).unwrap_or(0) + } +} + +impl CachedPageData for CachedComplexAllNullState { + fn as_arc_any(self: Arc) -> Arc { + self + } +} + +/// A scheduler for all-null data that has repetition and definition levels +/// +/// We still need to do some I/O in this case because we need to figure out what kind of null we +/// are dealing with (null list, null struct, what level null struct, etc.) +/// +/// TODO: Right now we just load the entire rep/def at initialization time and cache it. This is a touch +/// RAM aggressive and maybe we want something more lazy in the future. On the other hand, it's simple +/// and fast so...maybe not :) +#[derive(Debug)] +pub struct ComplexAllNullScheduler { + // Set from protobuf + buffer_offsets_and_sizes: Arc<[(u64, u64)]>, + def_meaning: Arc<[DefinitionInterpretation]>, + repdef: Option>, +} + +impl ComplexAllNullScheduler { + pub fn new( + buffer_offsets_and_sizes: Arc<[(u64, u64)]>, + def_meaning: Arc<[DefinitionInterpretation]>, + ) -> Self { + Self { + buffer_offsets_and_sizes, + def_meaning, + repdef: None, + } + } +} + +impl StructuralPageScheduler for ComplexAllNullScheduler { + fn initialize<'a>( + &'a mut self, + io: &Arc, + ) -> BoxFuture<'a, Result>> { + // Fully load the rep & def buffers, as needed + let (rep_pos, rep_size) = self.buffer_offsets_and_sizes[0]; + let (def_pos, def_size) = self.buffer_offsets_and_sizes[1]; + let has_rep = rep_size > 0; + let has_def = def_size > 0; + + let mut reads = Vec::with_capacity(2); + if has_rep { + reads.push(rep_pos..rep_pos + rep_size); + } + if has_def { + reads.push(def_pos..def_pos + def_size); + } + + let data = io.submit_request(reads, 0); + + async move { + let data = data.await?; + let mut data_iter = data.into_iter(); + + let rep = if has_rep { + let rep = data_iter.next().unwrap(); + let mut rep = LanceBuffer::from_bytes(rep, 2); + let rep = rep.borrow_to_typed_slice::(); + Some(rep) + } else { + None + }; + + let def = if has_def { + let def = data_iter.next().unwrap(); + let mut def = LanceBuffer::from_bytes(def, 2); + let def = def.borrow_to_typed_slice::(); + Some(def) + } else { + None + }; + + let repdef = Arc::new(CachedComplexAllNullState { rep, def }); + + self.repdef = Some(repdef.clone()); + + Ok(repdef as Arc) + } + .boxed() + } + + fn load(&mut self, data: &Arc) { + self.repdef = Some( + data.clone() + .as_arc_any() + .downcast::() + .unwrap(), + ); + } + + fn schedule_ranges( + &self, + ranges: &[Range], + _io: &Arc, + ) -> Result>>> { + let ranges = VecDeque::from_iter(ranges.iter().cloned()); + let num_rows = ranges.iter().map(|r| r.end - r.start).sum::(); + Ok(std::future::ready(Ok(Box::new(ComplexAllNullPageDecoder { + ranges, + rep: self.repdef.as_ref().unwrap().rep.clone(), + def: self.repdef.as_ref().unwrap().def.clone(), num_rows, - offset_into_first_chunk, + def_meaning: self.def_meaning.clone(), + }) as Box)) + .boxed()) + } +} + +#[derive(Debug)] +pub struct ComplexAllNullPageDecoder { + ranges: VecDeque>, + rep: Option>, + def: Option>, + num_rows: u64, + def_meaning: Arc<[DefinitionInterpretation]>, +} + +impl ComplexAllNullPageDecoder { + fn drain_ranges(&mut self, num_rows: u64) -> Vec> { + let mut rows_desired = num_rows; + let mut ranges = Vec::with_capacity(self.ranges.len()); + while rows_desired > 0 { + let front = self.ranges.front_mut().unwrap(); + let avail = front.end - front.start; + if avail > rows_desired { + ranges.push(front.start..front.start + rows_desired); + front.start += rows_desired; + rows_desired = 0; + } else { + ranges.push(self.ranges.pop_front().unwrap()); + rows_desired -= avail; + } + } + ranges + } +} + +impl StructuralPageDecoder for ComplexAllNullPageDecoder { + fn drain(&mut self, num_rows: u64) -> Result> { + let drained_ranges = self.drain_ranges(num_rows); + Ok(Box::new(DecodeComplexAllNullTask { + ranges: drained_ranges, + rep: self.rep.clone(), + def: self.def.clone(), + def_meaning: self.def_meaning.clone(), })) } @@ -553,6 +1036,50 @@ impl StructuralPageDecoder for MiniBlockDecoder { } } +/// We use `ranges` to slice into `rep` and `def` and create rep/def buffers +/// for the null data. +#[derive(Debug)] +pub struct DecodeComplexAllNullTask { + ranges: Vec>, + rep: Option>, + def: Option>, + def_meaning: Arc<[DefinitionInterpretation]>, +} + +impl DecodeComplexAllNullTask { + fn decode_level( + &self, + levels: &Option>, + num_values: u64, + ) -> Option> { + levels.as_ref().map(|levels| { + let mut referenced_levels = Vec::with_capacity(num_values as usize); + for range in &self.ranges { + referenced_levels.extend( + levels[range.start as usize..range.end as usize] + .iter() + .copied(), + ); + } + referenced_levels + }) + } +} + +impl DecodePageTask for DecodeComplexAllNullTask { + fn decode(self: Box) -> Result { + let num_values = self.ranges.iter().map(|r| r.end - r.start).sum::(); + let data = DataBlock::AllNull(AllNullDataBlock { num_values }); + let rep = self.decode_level(&self.rep, num_values); + let def = self.decode_level(&self.def, num_values); + let unraveler = RepDefUnraveler::new(rep, def, self.def_meaning); + Ok(DecodedPage { + data, + repdef: unraveler, + }) + } +} + /// A scheduler for simple all-null data /// /// "simple" all-null data is data that is all null and only has a single level of definition and @@ -561,14 +1088,19 @@ impl StructuralPageDecoder for MiniBlockDecoder { pub struct SimpleAllNullScheduler {} impl StructuralPageScheduler for SimpleAllNullScheduler { - fn initialize<'a>(&'a mut self, _io: &Arc) -> BoxFuture<'a, Result<()>> { - std::future::ready(Ok(())).boxed() + fn initialize<'a>( + &'a mut self, + _io: &Arc, + ) -> BoxFuture<'a, Result>> { + std::future::ready(Ok(Arc::new(NoCachedPageData) as Arc)).boxed() } + fn load(&mut self, _cache: &Arc) {} + fn schedule_ranges( &self, ranges: &[Range], - _io: &dyn EncodingsIo, + _io: &Arc, ) -> Result>>> { let num_rows = ranges.iter().map(|r| r.end - r.start).sum::(); Ok(std::future::ready(Ok( @@ -586,12 +1118,16 @@ struct SimpleAllNullDecodePageTask { } impl DecodePageTask for SimpleAllNullDecodePageTask { fn decode(self: Box) -> Result { + let unraveler = RepDefUnraveler::new( + None, + Some(vec![1; self.num_values as usize]), + Arc::new([DefinitionInterpretation::NullableItem]), + ); Ok(DecodedPage { data: DataBlock::AllNull(AllNullDataBlock { num_values: self.num_values, }), - repetition: None, - definition: Some(vec![1; self.num_values as usize]), + repdef: unraveler, }) } } @@ -619,57 +1155,189 @@ struct MiniBlockSchedulerDictionary { dictionary_decompressor: Arc, dictionary_buf_position_and_size: (u64, u64), dictionary_data_alignment: u64, + num_dictionary_items: u64, +} - // This is set after initialization - dictionary_data: Arc, +#[derive(Debug)] +struct RepIndexBlock { + // The index of the first row that starts after the beginning of this block. If the block + // has a preamble this will be the row after the preamble. If the block is entirely preamble + // then this will be a row that starts in some future block. + first_row: u64, + // The number of rows in the block, including the trailer but not the preamble. + // Can be 0 if the block is entirely preamble + starts_including_trailer: u64, + // Whether the block has a preamble + has_preamble: bool, + // Whether the block has a trailer + has_trailer: bool, +} + +impl DeepSizeOf for RepIndexBlock { + fn deep_size_of_children(&self, _context: &mut Context) -> usize { + 0 + } +} + +#[derive(Debug)] +struct RepetitionIndex { + blocks: Vec, +} + +impl DeepSizeOf for RepetitionIndex { + fn deep_size_of_children(&self, context: &mut Context) -> usize { + self.blocks.deep_size_of_children(context) + } +} + +impl RepetitionIndex { + fn decode(rep_index: &[Vec]) -> Self { + let mut chunk_has_preamble = false; + let mut offset = 0; + let mut blocks = Vec::with_capacity(rep_index.len()); + for chunk_rep in rep_index { + let ends_count = chunk_rep[0]; + let partial_count = chunk_rep[1]; + + let chunk_has_trailer = partial_count > 0; + let mut starts_including_trailer = ends_count; + if chunk_has_trailer { + starts_including_trailer += 1; + } + if chunk_has_preamble { + starts_including_trailer -= 1; + } + + blocks.push(RepIndexBlock { + first_row: offset, + starts_including_trailer, + has_preamble: chunk_has_preamble, + has_trailer: chunk_has_trailer, + }); + + chunk_has_preamble = chunk_has_trailer; + offset += starts_including_trailer; + } + + Self { blocks } + } +} + +/// State that is loaded once and cached for future lookups +#[derive(Debug)] +struct MiniBlockCacheableState { + /// Metadata that describes each chunk in the page + chunk_meta: Vec, + /// The decoded repetition index + rep_index: RepetitionIndex, + /// The dictionary for the page, if any + dictionary: Option>, +} + +impl DeepSizeOf for MiniBlockCacheableState { + fn deep_size_of_children(&self, context: &mut Context) -> usize { + self.rep_index.deep_size_of_children(context) + + self + .dictionary + .as_ref() + .map(|dict| dict.data_size() as usize) + .unwrap_or(0) + } +} + +impl CachedPageData for MiniBlockCacheableState { + fn as_arc_any(self: Arc) -> Arc { + self + } } /// A scheduler for a page that has been encoded with the mini-block layout +/// +/// Scheduling mini-block encoded data is simple in concept and somewhat complex +/// in practice. +/// +/// First, during initialization, we load the chunk metadata, the repetition index, +/// and the dictionary (these last two may not be present) +/// +/// Then, during scheduling, we use the user's requested row ranges and the repetition +/// index to determine which chunks we need and which rows we need from those chunks. +/// +/// For example, if the repetition index is: [50, 3], [50, 0], [10, 0] and the range +/// from the user is 40..60 then we need to: +/// +/// - Read the first chunk and skip the first 40 rows, then read 10 full rows, and +/// then read 3 items for the 11th row of our range. +/// - Read the second chunk and read the remaining items in our 11th row and then read +/// the remaining 9 full rows. +/// +/// Then, if we are going to decode that in batches of 5, we need to make decode tasks. +/// The first two decode tasks will just need the first chunk. The third decode task will +/// need the first chunk (for the trailer which has the 11th row in our range) and the second +/// chunk. The final decode task will just need the second chunk. +/// +/// The above prose descriptions are what are represented by [`ChunkInstructions`] and +/// [`ChunkDrainInstructions`]. #[derive(Debug)] pub struct MiniBlockScheduler { // These come from the protobuf - meta_buf_position: u64, - meta_buf_size: u64, - data_buf_position: u64, + buffer_offsets_and_sizes: Vec<(u64, u64)>, priority: u64, - rows_in_page: u64, - rep_decompressor: Arc, - def_decompressor: Arc, + items_in_page: u64, + repetition_index_depth: u16, + num_buffers: u64, + rep_decompressor: Option>, + def_decompressor: Option>, value_decompressor: Arc, - - // This is set after initialization - chunk_meta: Vec, - + def_meaning: Arc<[DefinitionInterpretation]>, dictionary: Option, + // This is set after initialization + page_meta: Option>, } impl MiniBlockScheduler { fn try_new( buffer_offsets_and_sizes: &[(u64, u64)], priority: u64, - rows_in_page: u64, + items_in_page: u64, layout: &pb::MiniBlockLayout, decompressors: &dyn DecompressorStrategy, ) -> Result { - let (meta_buf_position, meta_buf_size) = buffer_offsets_and_sizes[0]; - // We don't use the data buf size since we can get it from the metadata - let (data_buf_position, _) = buffer_offsets_and_sizes[1]; - let rep_decompressor = - decompressors.create_block_decompressor(layout.rep_compression.as_ref().unwrap())?; - let def_decompressor = - decompressors.create_block_decompressor(layout.def_compression.as_ref().unwrap())?; + let rep_decompressor = layout + .rep_compression + .as_ref() + .map(|rep_compression| { + decompressors + .create_block_decompressor(rep_compression) + .map(Arc::from) + }) + .transpose()?; + let def_decompressor = layout + .def_compression + .as_ref() + .map(|def_compression| { + decompressors + .create_block_decompressor(def_compression) + .map(Arc::from) + }) + .transpose()?; + let def_meaning = layout + .layers + .iter() + .map(|l| ProtobufUtils::repdef_layer_to_def_interp(*l)) + .collect::>(); let value_decompressor = decompressors .create_miniblock_decompressor(layout.value_compression.as_ref().unwrap())?; let dictionary = if let Some(dictionary_encoding) = layout.dictionary.as_ref() { + let num_dictionary_items = layout.num_dictionary_items; match dictionary_encoding.array_encoding.as_ref().unwrap() { - pb::array_encoding::ArrayEncoding::BinaryBlock(_) => { + pb::array_encoding::ArrayEncoding::Variable(_) => { Some(MiniBlockSchedulerDictionary { dictionary_decompressor: decompressors .create_block_decompressor(dictionary_encoding)? .into(), dictionary_buf_position_and_size: buffer_offsets_and_sizes[2], dictionary_data_alignment: 4, - dictionary_data: Arc::new(DataBlock::Empty()), + num_dictionary_items, }) } pb::array_encoding::ArrayEncoding::Flat(_) => Some(MiniBlockSchedulerDictionary { @@ -678,7 +1346,7 @@ impl MiniBlockScheduler { .into(), dictionary_buf_position_and_size: buffer_offsets_and_sizes[2], dictionary_data_alignment: 16, - dictionary_data: Arc::new(DataBlock::Empty()), + num_dictionary_items, }), _ => { unreachable!("Currently only encodings `BinaryBlock` and `Flat` used for encoding MiniBlock dictionary.") @@ -689,88 +1357,314 @@ impl MiniBlockScheduler { }; Ok(Self { - meta_buf_position, - meta_buf_size, - data_buf_position, - rep_decompressor: rep_decompressor.into(), - def_decompressor: def_decompressor.into(), + buffer_offsets_and_sizes: buffer_offsets_and_sizes.to_vec(), + rep_decompressor, + def_decompressor, value_decompressor: value_decompressor.into(), + repetition_index_depth: layout.repetition_index_depth as u16, + num_buffers: layout.num_buffers, priority, - rows_in_page, - chunk_meta: Vec::new(), + items_in_page, dictionary, + def_meaning: def_meaning.into(), + page_meta: None, }) } - /// Calculates the overlap between a user-supplied range and a chunk of mini-block data - fn calc_overlap( - range: &mut Range, - chunk: &ChunkMeta, - rows_offset: u64, - dst: &mut ScheduledChunk, - ) -> ChunkOverlap { - if range.start > chunk.num_values + rows_offset { - ChunkOverlap::RangeAfterChunk - } else { - let start_in_chunk = range.start - rows_offset; - let end_in_chunk = (start_in_chunk + range.end - range.start).min(chunk.num_values); - let rows_in_chunk = end_in_chunk - start_in_chunk; - range.start += rows_in_chunk; - dst.ranges.push(start_in_chunk..end_in_chunk); - ChunkOverlap::Overlap - } + fn lookup_chunks(&self, chunk_indices: &[usize]) -> Vec { + let page_meta = self.page_meta.as_ref().unwrap(); + chunk_indices + .iter() + .map(|&chunk_idx| { + let chunk_meta = &page_meta.chunk_meta[chunk_idx]; + let bytes_start = chunk_meta.offset_bytes; + let bytes_end = bytes_start + chunk_meta.chunk_size_bytes; + LoadedChunk { + byte_range: bytes_start..bytes_end, + items_in_chunk: chunk_meta.num_values, + chunk_idx, + data: LanceBuffer::empty(), + } + }) + .collect() } } -#[derive(Debug)] -struct ScheduledChunk { - data: LanceBuffer, - // The total number of values in the chunk, not all values may be targeted - vals_in_chunk: u64, - // The number of values that are targeted by the ranges. This should be the - // same as the sum of `Self::ranges` - vals_targeted: u64, - ranges: Vec>, +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum PreambleAction { + Take, + Skip, + Absent, } -impl Clone for ScheduledChunk { - fn clone(&self) -> Self { - Self { - data: self.data.try_clone().unwrap(), - vals_in_chunk: self.vals_in_chunk, - ranges: self.ranges.clone(), - vals_targeted: self.vals_targeted, +// TODO: Add test cases for the all-preamble and all-trailer cases + +// When we schedule a chunk we use the repetition index (or, if none exists, just the # of items +// in each chunk) to map a user requested range into a set of ChunkInstruction objects which tell +// us how exactly to read from the chunk. +#[derive(Clone, Debug, PartialEq, Eq)] +struct ChunkInstructions { + // The index of the chunk to read + chunk_idx: usize, + // A "preamble" is when a chunk begins with a continuation of the previous chunk's list. If there + // is no repetition index there is never a preamble. + // + // It's possible for a chunk to be entirely premable. For example, if there is a really large list + // that spans several chunks. + preamble: PreambleAction, + // How many complete rows (not including the preamble or trailer) to skip + // + // If this is non-zero then premable must not be Take + rows_to_skip: u64, + // How many complete (non-preamble / non-trailer) rows to take + rows_to_take: u64, + // A "trailer" is when a chunk ends with a partial list. If there is no repetition index there is + // never a trailer. + // + // It's possible for a chunk to be entirely trailer. This would mean the chunk starts with the beginning + // of a list and that list is continued in the next chunk. + // + // If this is true then we want to include the trailer + take_trailer: bool, +} + +// First, we schedule a bunch of [`ChunkInstructions`] based on the users ranges. Then we +// start decoding them, based on a batch size, which might not align with what we scheduled. +// +// This results in `ChunkDrainInstructions` which targets a contiguous slice of a `ChunkInstructions` +// +// So if `ChunkInstructions` is "skip preamble, skip 10, take 50, take trailer" and we are decoding in +// batches of size 10 we might have a `ChunkDrainInstructions` that targets that chunk and has its own +// skip of 17 and take of 10. This would mean we decode the chunk, skip the preamble and 27 rows, and +// then take 10 rows. +// +// One very confusing bit is that `rows_to_take` includes the trailer. So if we have two chunks: +// -no preamble, skip 5, take 10, take trailer +// -take preamble, skip 0, take 50, no trailer +// +// and we are draining 20 rows then the drain instructions for the first batch will be: +// - no preamble, skip 0 (from chunk 0), take 11 (from chunk 0) +// - take preamble (from chunk 1), skip 0 (from chunk 1), take 9 (from chunk 1) +#[derive(Debug, PartialEq, Eq)] +struct ChunkDrainInstructions { + chunk_instructions: ChunkInstructions, + rows_to_skip: u64, + rows_to_take: u64, + preamble_action: PreambleAction, +} + +impl ChunkInstructions { + // Given a repetition index and a set of user ranges we need to figure out how to read from the chunks + // + // We assume that `user_ranges` are in sorted order and non-overlapping + // + // The output will be a set of `ChunkInstructions` which tell us how to read from the chunks + fn schedule_instructions(rep_index: &RepetitionIndex, user_ranges: &[Range]) -> Vec { + // This is an in-exact capacity guess but pretty good. The actual capacity can be + // smaller if instructions are merged. It can be larger if there are multiple instructions + // per row which can happen with lists. + let mut chunk_instructions = Vec::with_capacity(user_ranges.len()); + + for user_range in user_ranges { + let mut rows_needed = user_range.end - user_range.start; + let mut need_preamble = false; + + // Need to find the first chunk with a first row >= user_range.start. If there are + // multiple chunks with the same first row we need to take the first one. + let mut block_index = match rep_index + .blocks + .binary_search_by_key(&user_range.start, |block| block.first_row) + { + Ok(idx) => { + // Slightly tricky case, we may need to walk backwards a bit to make sure we + // are grabbing first eligible chunk + let mut idx = idx; + while idx > 0 && rep_index.blocks[idx - 1].first_row == user_range.start { + idx -= 1; + } + idx + } + // Easy case. idx is greater, and idx - 1 is smaller, so idx - 1 contains the start + Err(idx) => idx - 1, + }; + + let mut to_skip = user_range.start - rep_index.blocks[block_index].first_row; + + while rows_needed > 0 || need_preamble { + let chunk = &rep_index.blocks[block_index]; + let rows_avail = chunk.starts_including_trailer - to_skip; + debug_assert!(rows_avail > 0); + + let rows_to_take = rows_avail.min(rows_needed); + rows_needed -= rows_to_take; + + let mut take_trailer = false; + let preamble = if chunk.has_preamble { + if need_preamble { + PreambleAction::Take + } else { + PreambleAction::Skip + } + } else { + PreambleAction::Absent + }; + let mut rows_to_take_no_trailer = rows_to_take; + + // Are we taking the trailer? If so, make sure we mark that we need the preamble + if rows_to_take == rows_avail && chunk.has_trailer { + take_trailer = true; + need_preamble = true; + rows_to_take_no_trailer -= 1; + } else { + need_preamble = false; + }; + + chunk_instructions.push(Self { + preamble, + chunk_idx: block_index, + rows_to_skip: to_skip, + rows_to_take: rows_to_take_no_trailer, + take_trailer, + }); + + to_skip = 0; + block_index += 1; + } + } + + // If there were multiple ranges we may have multiple instructions for a single chunk. Merge them now if they + // are _adjacent_ (i.e. don't merge "take first row of chunk 0" and "take third row of chunk 0" into "take 2 + // rows of chunk 0 starting at 0") + if user_ranges.len() > 1 { + // TODO: Could probably optimize this allocation away + let mut merged_instructions = Vec::with_capacity(chunk_instructions.len()); + let mut instructions_iter = chunk_instructions.into_iter(); + merged_instructions.push(instructions_iter.next().unwrap()); + for instruction in instructions_iter { + let last = merged_instructions.last_mut().unwrap(); + if last.chunk_idx == instruction.chunk_idx + && last.rows_to_take + last.rows_to_skip == instruction.rows_to_skip + { + last.rows_to_take += instruction.rows_to_take; + last.take_trailer |= instruction.take_trailer; + } else { + merged_instructions.push(instruction); + } + } + merged_instructions + } else { + chunk_instructions } } -} -pub enum ChunkOverlap { - RangeAfterChunk, - Overlap, + fn drain_from_instruction( + &self, + rows_desired: &mut u64, + need_preamble: &mut bool, + skip_in_chunk: &mut u64, + ) -> (ChunkDrainInstructions, bool) { + // If we need the premable then we shouldn't be skipping anything + debug_assert!(!*need_preamble || *skip_in_chunk == 0); + let mut rows_avail = self.rows_to_take - *skip_in_chunk; + let has_preamble = self.preamble != PreambleAction::Absent; + let preamble_action = match (*need_preamble, has_preamble) { + (true, true) => PreambleAction::Take, + (true, false) => panic!("Need preamble but there isn't one"), + (false, true) => PreambleAction::Skip, + (false, false) => PreambleAction::Absent, + }; + + // Did the scheduled chunk have a trailer? If so, we have one extra row available + if self.take_trailer { + rows_avail += 1; + } + + // How many rows are we actually taking in this take step (including the preamble + // and trailer both as individual rows) + let rows_taking = if *rows_desired >= rows_avail { + // We want all the rows. If there is a trailer we are grabbing it and will need + // the preamble of the next chunk + *need_preamble = self.take_trailer; + rows_avail + } else { + // We aren't taking all the rows. Even if there is a trailer we aren't taking + // it so we will not need the preamble + *need_preamble = false; + *rows_desired + }; + let rows_skipped = *skip_in_chunk; + + // Update the state for the next iteration + let consumed_chunk = if *rows_desired >= rows_avail { + *rows_desired -= rows_avail; + *skip_in_chunk = 0; + true + } else { + *skip_in_chunk += *rows_desired; + *rows_desired = 0; + false + }; + + ( + ChunkDrainInstructions { + chunk_instructions: self.clone(), + rows_to_skip: rows_skipped, + rows_to_take: rows_taking, + preamble_action, + }, + consumed_chunk, + ) + } } impl StructuralPageScheduler for MiniBlockScheduler { - fn initialize<'a>(&'a mut self, io: &Arc) -> BoxFuture<'a, Result<()>> { - let metadata = io.submit_single( - self.meta_buf_position..self.meta_buf_position + self.meta_buf_size, - 0, - ); - let dictionary_data = self.dictionary.as_ref().map(|dictionary| { - io.submit_single( + fn initialize<'a>( + &'a mut self, + io: &Arc, + ) -> BoxFuture<'a, Result>> { + // We always need to fetch chunk metadata. We may also need to fetch a dictionary and + // we may also need to fetch the repetition index. Here, we gather what buffers we + // need. + let (meta_buf_position, meta_buf_size) = self.buffer_offsets_and_sizes[0]; + let value_buf_position = self.buffer_offsets_and_sizes[1].0; + let mut bufs_needed = 1; + if self.dictionary.is_some() { + bufs_needed += 1; + } + if self.repetition_index_depth > 0 { + bufs_needed += 1; + } + let mut required_ranges = Vec::with_capacity(bufs_needed); + required_ranges.push(meta_buf_position..meta_buf_position + meta_buf_size); + if let Some(ref dictionary) = self.dictionary { + required_ranges.push( dictionary.dictionary_buf_position_and_size.0 ..dictionary.dictionary_buf_position_and_size.0 + dictionary.dictionary_buf_position_and_size.1, - 0, - ) - }); + ); + } + if self.repetition_index_depth > 0 { + let (rep_index_pos, rep_index_size) = self.buffer_offsets_and_sizes.last().unwrap(); + required_ranges.push(*rep_index_pos..*rep_index_pos + *rep_index_size); + } + let io_req = io.submit_request(required_ranges, 0); + async move { - let bytes = metadata.await?; - assert!(bytes.len() % 2 == 0); - let mut bytes = LanceBuffer::from_bytes(bytes, 2); + let mut buffers = io_req.await?.into_iter().fuse(); + let meta_bytes = buffers.next().unwrap(); + let dictionary_bytes = self.dictionary.as_ref().and_then(|_| buffers.next()); + let rep_index_bytes = buffers.next(); + + // Parse the metadata and build the chunk meta + assert!(meta_bytes.len() % 2 == 0); + let mut bytes = LanceBuffer::from_bytes(meta_bytes, 2); let words = bytes.borrow_to_typed_slice::(); let words = words.as_ref(); - self.chunk_meta.reserve(words.len()); + + let mut chunk_meta = Vec::with_capacity(words.len()); + let mut rows_counter = 0; + let mut offset_bytes = value_buf_position; for (word_idx, word) in words.iter().enumerate() { let log_num_values = word & 0x0F; let divided_bytes = word >> 4; @@ -780,131 +1674,172 @@ impl StructuralPageScheduler for MiniBlockScheduler { debug_assert!(log_num_values > 0); 1 << log_num_values } else { - debug_assert_eq!(log_num_values, 0); - self.rows_in_page - rows_counter + debug_assert!( + log_num_values == 0 + || (1 << log_num_values) == (self.items_in_page - rows_counter) + ); + self.items_in_page - rows_counter }; rows_counter += num_values; - self.chunk_meta.push(ChunkMeta { + chunk_meta.push(ChunkMeta { num_values, chunk_size_bytes: num_bytes as u64, + offset_bytes, }); + offset_bytes += num_bytes as u64; } + + // Build the repetition index + let rep_index = if let Some(rep_index_data) = rep_index_bytes { + // If we have a repetition index then we use that + // TODO: Compress the repetition index :) + assert!(rep_index_data.len() % 8 == 0); + let mut repetition_index_vals = LanceBuffer::from_bytes(rep_index_data, 8); + let repetition_index_vals = repetition_index_vals.borrow_to_typed_slice::(); + // Unflatten + repetition_index_vals + .as_ref() + .chunks_exact(self.repetition_index_depth as usize + 1) + .map(|c| c.to_vec()) + .collect::>() + } else { + // Default rep index is just the number of items in each chunk + // with 0 partials/leftovers + chunk_meta + .iter() + .map(|c| vec![c.num_values, 0]) + .collect::>() + }; + + let mut page_meta = MiniBlockCacheableState { + chunk_meta, + rep_index: RepetitionIndex::decode(&rep_index), + dictionary: None, + }; + // decode dictionary if let Some(ref mut dictionary) = self.dictionary { - let dictionary_data = dictionary_data.unwrap().await?; - dictionary.dictionary_data = - Arc::new(dictionary.dictionary_decompressor.decompress( + let dictionary_data = dictionary_bytes.unwrap(); + page_meta.dictionary = + Some(Arc::new(dictionary.dictionary_decompressor.decompress( LanceBuffer::from_bytes( dictionary_data, dictionary.dictionary_data_alignment, ), - )?) + dictionary.num_dictionary_items, + )?)); }; - Ok(()) + let page_meta = Arc::new(page_meta); + self.page_meta = Some(page_meta.clone()); + Ok(page_meta as Arc) } .boxed() } + fn load(&mut self, data: &Arc) { + self.page_meta = Some( + data.clone() + .as_arc_any() + .downcast::() + .unwrap(), + ); + } + fn schedule_ranges( &self, ranges: &[Range], - io: &dyn EncodingsIo, + io: &Arc, ) -> Result>>> { - let mut chunk_meta_iter = self.chunk_meta.iter(); - let mut current_chunk = chunk_meta_iter.next().unwrap(); - let mut row_offset = 0; - let mut bytes_offset = 0; + let num_rows = ranges.iter().map(|r| r.end - r.start).sum(); - let mut scheduled_chunks = VecDeque::with_capacity(self.chunk_meta.len()); - let mut ranges_to_req = Vec::with_capacity(self.chunk_meta.len()); - let mut num_rows = 0; + let page_meta = self.page_meta.as_ref().unwrap(); - let mut current_scheduled_chunk = ScheduledChunk { - data: LanceBuffer::empty(), - ranges: Vec::new(), - vals_in_chunk: current_chunk.num_values, - vals_targeted: 0, - }; + let chunk_instructions = + ChunkInstructions::schedule_instructions(&page_meta.rep_index, ranges); - // There can be both multiple ranges per chunk and multiple chunks per range - for range in ranges { - num_rows += range.end - range.start; - let mut range = range.clone(); - while !range.is_empty() { - Self::calc_overlap( - &mut range, - current_chunk, - row_offset, - &mut current_scheduled_chunk, - ); - // Might be empty if entire chunk is skipped - if !range.is_empty() { - if !current_scheduled_chunk.ranges.is_empty() { - scheduled_chunks.push_back(current_scheduled_chunk); - ranges_to_req.push( - (self.data_buf_position + bytes_offset) - ..(self.data_buf_position - + bytes_offset - + current_chunk.chunk_size_bytes), - ); - } - row_offset += current_chunk.num_values; - bytes_offset += current_chunk.chunk_size_bytes; - if let Some(next_chunk) = chunk_meta_iter.next() { - current_chunk = next_chunk; + debug_assert_eq!( + num_rows, + chunk_instructions + .iter() + .map(|ci| { + let taken = ci.rows_to_take; + if ci.take_trailer { + taken + 1 + } else { + taken } - current_scheduled_chunk = ScheduledChunk { - data: LanceBuffer::empty(), - ranges: Vec::new(), - vals_in_chunk: current_chunk.num_values, - vals_targeted: 0, - }; - } - } - } - if !current_scheduled_chunk.ranges.is_empty() { - scheduled_chunks.push_back(current_scheduled_chunk); - ranges_to_req.push( - (self.data_buf_position + bytes_offset) - ..(self.data_buf_position + bytes_offset + current_chunk.chunk_size_bytes), - ); - } + }) + .sum::() + ); - let data = io.submit_request(ranges_to_req, self.priority); + let chunks_needed = chunk_instructions + .iter() + .map(|ci| ci.chunk_idx) + .unique() + .collect::>(); + let mut loaded_chunks = self.lookup_chunks(&chunks_needed); + let chunk_ranges = loaded_chunks + .iter() + .map(|c| c.byte_range.clone()) + .collect::>(); + let loaded_chunk_data = io.submit_request(chunk_ranges, self.priority); let rep_decompressor = self.rep_decompressor.clone(); let def_decompressor = self.def_decompressor.clone(); let value_decompressor = self.value_decompressor.clone(); - let dictionary = self + let num_buffers = self.num_buffers; + let dictionary = page_meta .dictionary .as_ref() - .map(|dictionary| dictionary.dictionary_data.clone()); + .map(|dictionary| dictionary.clone()); + let def_meaning = self.def_meaning.clone(); - for scheduled_chunk in scheduled_chunks.iter_mut() { - scheduled_chunk.vals_targeted = - scheduled_chunk.ranges.iter().map(|r| r.end - r.start).sum(); - } - - Ok(async move { - let data = data.await?; - for (chunk, data) in scheduled_chunks.iter_mut().zip(data) { - chunk.data = LanceBuffer::from_bytes(data, 1); + let res = async move { + let loaded_chunk_data = loaded_chunk_data.await?; + for (loaded_chunk, chunk_data) in loaded_chunks.iter_mut().zip(loaded_chunk_data) { + loaded_chunk.data = LanceBuffer::from_bytes(chunk_data, 1); } + Ok(Box::new(MiniBlockDecoder { rep_decompressor, def_decompressor, value_decompressor, - data: scheduled_chunks, + def_meaning, + loaded_chunks: VecDeque::from_iter(loaded_chunks), + instructions: VecDeque::from(chunk_instructions), offset_in_current_chunk: 0, - num_rows, dictionary, + num_rows, + num_buffers, }) as Box) } - .boxed()) + .boxed(); + Ok(res) } } +#[derive(Debug)] +struct FullZipRepIndexDetails { + buf_position: u64, + bytes_per_value: u64, // Will be 1, 2, 4, or 8 +} + +#[derive(Debug)] +enum PerValueDecompressor { + Fixed(Arc), + Variable(Arc), +} + +#[derive(Debug)] +struct FullZipDecodeDetails { + value_decompressor: PerValueDecompressor, + def_meaning: Arc<[DefinitionInterpretation]>, + ctrl_word_parser: ControlWordParser, + max_rep: u16, + max_visible_def: u16, +} + /// A scheduler for full-zip encoded data /// /// When the data type has a fixed-width then we simply need to map from @@ -915,10 +1850,11 @@ impl StructuralPageScheduler for MiniBlockScheduler { #[derive(Debug)] pub struct FullZipScheduler { data_buf_position: u64, + rep_index: Option, priority: u64, rows_in_page: u64, - value_decompressor: Arc, - ctrl_word_parser: ControlWordParser, + bits_per_offset: u8, + details: Arc, } impl FullZipScheduler { @@ -928,53 +1864,223 @@ impl FullZipScheduler { rows_in_page: u64, layout: &pb::FullZipLayout, decompressors: &dyn DecompressorStrategy, + bits_per_offset: u8, ) -> Result { - // We don't need the data_buf_size because we either the data type is + // We don't need the data_buf_size because either the data type is // fixed-width (and we can tell size from rows_in_page) or it is not // and we have a repetition index. let (data_buf_position, _) = buffer_offsets_and_sizes[0]; - let value_decompressor = decompressors - .create_per_value_decompressor(layout.value_compression.as_ref().unwrap())?; + let rep_index = buffer_offsets_and_sizes.get(1).map(|(pos, len)| { + let num_reps = rows_in_page + 1; + let bytes_per_rep = len / num_reps; + debug_assert_eq!(len % num_reps, 0); + debug_assert!( + bytes_per_rep == 1 + || bytes_per_rep == 2 + || bytes_per_rep == 4 + || bytes_per_rep == 8 + ); + FullZipRepIndexDetails { + buf_position: *pos, + bytes_per_value: bytes_per_rep, + } + }); + + let value_decompressor = match layout.details { + Some(pb::full_zip_layout::Details::BitsPerValue(_)) => { + let decompressor = decompressors.create_fixed_per_value_decompressor( + layout.value_compression.as_ref().unwrap(), + )?; + PerValueDecompressor::Fixed(decompressor.into()) + } + Some(pb::full_zip_layout::Details::BitsPerOffset(_)) => { + let decompressor = decompressors.create_variable_per_value_decompressor( + layout.value_compression.as_ref().unwrap(), + )?; + PerValueDecompressor::Variable(decompressor.into()) + } + None => { + panic!("Full-zip layout must have a `details` field"); + } + }; let ctrl_word_parser = ControlWordParser::new( layout.bits_rep.try_into().unwrap(), layout.bits_def.try_into().unwrap(), ); + let def_meaning = layout + .layers + .iter() + .map(|l| ProtobufUtils::repdef_layer_to_def_interp(*l)) + .collect::>(); + + let max_rep = def_meaning.iter().filter(|d| d.is_list()).count() as u16; + let max_visible_def = def_meaning + .iter() + .filter(|d| !d.is_list()) + .map(|d| d.num_def_levels()) + .sum(); + + let details = Arc::new(FullZipDecodeDetails { + value_decompressor, + def_meaning: def_meaning.into(), + ctrl_word_parser, + max_rep, + max_visible_def, + }); Ok(Self { data_buf_position, - value_decompressor: value_decompressor.into(), + rep_index, + details, priority, rows_in_page, - ctrl_word_parser, + bits_per_offset, }) } -} -impl StructuralPageScheduler for FullZipScheduler { - fn initialize<'a>(&'a mut self, _io: &Arc) -> BoxFuture<'a, Result<()>> { - std::future::ready(Ok(())).boxed() + /// Schedules indirectly by first fetching the data ranges from the + /// repetition index and then fetching the data + /// + /// This approach is needed whenever we have a repetition index and + /// the data has a variable length. + #[allow(clippy::too_many_arguments)] + async fn indirect_schedule_ranges( + data_buffer_pos: u64, + row_ranges: Vec>, + rep_index_ranges: Vec>, + bytes_per_rep: u64, + io: Arc, + priority: u64, + bits_per_offset: u8, + details: Arc, + ) -> Result> { + let byte_ranges = io + .submit_request(rep_index_ranges, priority) + .await? + .into_iter() + .map(|d| LanceBuffer::from_bytes(d, 1)) + .collect::>(); + let byte_ranges = LanceBuffer::concat(&byte_ranges); + let byte_ranges = ByteUnpacker::new(byte_ranges, bytes_per_rep as usize) + .chunks(2) + .into_iter() + .map(|mut c| { + let start = c.next().unwrap() + data_buffer_pos; + let end = c.next().unwrap() + data_buffer_pos; + start..end + }) + .collect::>(); + + let data = io.submit_request(byte_ranges, priority); + + let data = data.await?; + let data = data + .into_iter() + .map(|d| LanceBuffer::from_bytes(d, 1)) + .collect(); + let num_rows = row_ranges.into_iter().map(|r| r.end - r.start).sum(); + + match &details.value_decompressor { + PerValueDecompressor::Fixed(decompressor) => { + let bits_per_value = decompressor.bits_per_value(); + assert!(bits_per_value > 0); + if bits_per_value % 8 != 0 { + // Unlikely we will ever want this since full-zip values are so large the few bits we shave off don't + // make much difference. + unimplemented!("Bit-packed full-zip"); + } + let bytes_per_value = bits_per_value / 8; + let total_bytes_per_value = + bytes_per_value as usize + details.ctrl_word_parser.bytes_per_word(); + Ok(Box::new(FixedFullZipDecoder { + details, + data, + num_rows, + offset_in_current: 0, + bytes_per_value: bytes_per_value as usize, + total_bytes_per_value, + }) as Box) + } + PerValueDecompressor::Variable(_decompressor) => { + // Variable full-zip + + Ok(Box::new(VariableFullZipDecoder::new( + details, + data, + num_rows, + bits_per_offset, + bits_per_offset, + ))) + } + } + } + + /// Schedules ranges in the presence of a repetition index + fn schedule_ranges_rep( + &self, + ranges: &[Range], + io: &Arc, + rep_index: &FullZipRepIndexDetails, + ) -> Result>>> { + let rep_index_ranges = ranges + .iter() + .flat_map(|r| { + let first_val_start = + rep_index.buf_position + (r.start * rep_index.bytes_per_value); + let first_val_end = first_val_start + rep_index.bytes_per_value; + let last_val_start = rep_index.buf_position + (r.end * rep_index.bytes_per_value); + let last_val_end = last_val_start + rep_index.bytes_per_value; + [first_val_start..first_val_end, last_val_start..last_val_end] + }) + .collect::>(); + + // Create the decoder + + Ok(Self::indirect_schedule_ranges( + self.data_buf_position, + ranges.to_vec(), + rep_index_ranges, + rep_index.bytes_per_value, + io.clone(), + self.priority, + self.bits_per_offset, + self.details.clone(), + ) + .boxed()) } - fn schedule_ranges( + // In the simple case there is no repetition and we just have large fixed-width + // rows of data. We can just map row ranges to byte ranges directly using the + // fixed-width of the data type. + fn schedule_ranges_simple( &self, ranges: &[Range], io: &dyn EncodingsIo, ) -> Result>>> { - let bits_per_value = self.value_decompressor.bits_per_value(); + // Convert row ranges to item ranges (i.e. multiply by items per row) + let num_rows = ranges.iter().map(|r| r.end - r.start).sum(); + + let PerValueDecompressor::Fixed(decompressor) = &self.details.value_decompressor else { + unreachable!() + }; + + // Convert item ranges to byte ranges (i.e. multiply by bytes per item) + let bits_per_value = decompressor.bits_per_value(); assert_eq!(bits_per_value % 8, 0); let bytes_per_value = bits_per_value / 8; - let bytes_per_cw = self.ctrl_word_parser.bytes_per_word(); + let bytes_per_cw = self.details.ctrl_word_parser.bytes_per_word(); let total_bytes_per_value = bytes_per_value + bytes_per_cw as u64; - // We simply map row ranges into byte ranges let byte_ranges = ranges.iter().map(|r| { debug_assert!(r.end <= self.rows_in_page); let start = self.data_buf_position + r.start * total_bytes_per_value; let end = self.data_buf_position + r.end * total_bytes_per_value; start..end }); + + // Request byte ranges let data = io.submit_request(byte_ranges.collect(), self.priority); - let value_decompressor = self.value_decompressor.clone(); - let num_rows = ranges.iter().map(|r| r.end - r.start).sum(); - let ctrl_word_parser = self.ctrl_word_parser; + + let details = self.details.clone(); + Ok(async move { let data = data.await?; let data = data @@ -982,10 +2088,9 @@ impl StructuralPageScheduler for FullZipScheduler { .map(|d| LanceBuffer::from_bytes(d, 1)) .collect(); Ok(Box::new(FixedFullZipDecoder { - value_decompressor, + details, data, num_rows, - ctrl_word_parser, offset_in_current: 0, bytes_per_value: bytes_per_value as usize, total_bytes_per_value: total_bytes_per_value as usize, @@ -995,6 +2100,30 @@ impl StructuralPageScheduler for FullZipScheduler { } } +impl StructuralPageScheduler for FullZipScheduler { + // TODO: Add opt-in caching of repetition index + fn initialize<'a>( + &'a mut self, + _io: &Arc, + ) -> BoxFuture<'a, Result>> { + std::future::ready(Ok(Arc::new(NoCachedPageData) as Arc)).boxed() + } + + fn load(&mut self, _cache: &Arc) {} + + fn schedule_ranges( + &self, + ranges: &[Range], + io: &Arc, + ) -> Result>>> { + if let Some(rep_index) = self.rep_index.as_ref() { + self.schedule_ranges_rep(ranges, io, rep_index) + } else { + self.schedule_ranges_simple(ranges, io.as_ref()) + } + } +} + /// A decoder for full-zip encoded data when the data has a fixed-width /// /// Here we need to unzip the control words from the values themselves and @@ -1004,8 +2133,7 @@ impl StructuralPageScheduler for FullZipScheduler { /// requested data. This decoder / scheduler does not do any read amplification. #[derive(Debug)] struct FixedFullZipDecoder { - value_decompressor: Arc, - ctrl_word_parser: ControlWordParser, + details: Arc, data: VecDeque, offset_in_current: usize, bytes_per_value: usize, @@ -1013,37 +2141,356 @@ struct FixedFullZipDecoder { num_rows: u64, } -impl StructuralPageDecoder for FixedFullZipDecoder { - fn drain(&mut self, num_rows: u64) -> Result> { - let mut task_data = Vec::with_capacity(self.data.len()); - let mut remaining = num_rows; - while remaining > 0 { - let cur_buf = self.data.front_mut().unwrap(); - let bytes_avail = cur_buf.len() - self.offset_in_current; - - let bytes_needed = remaining as usize * self.total_bytes_per_value; - let bytes_to_take = bytes_needed.min(bytes_avail); +impl FixedFullZipDecoder { + fn slice_next_task(&mut self, num_rows: u64) -> FullZipDecodeTaskItem { + debug_assert!(num_rows > 0); + let cur_buf = self.data.front_mut().unwrap(); + let start = self.offset_in_current; + if self.details.ctrl_word_parser.has_rep() { + // This is a slightly slower path. In order to figure out where to split we need to + // examine the rep index so we can convert num_lists to num_rows + let mut rows_started = 0; + // We always need at least one value. Now loop through until we have passed num_rows + // values + let mut num_items = 0; + while self.offset_in_current < cur_buf.len() { + let control = self.details.ctrl_word_parser.parse_desc( + &cur_buf[self.offset_in_current..], + self.details.max_rep, + self.details.max_visible_def, + ); + if control.is_new_row { + if rows_started == num_rows { + break; + } + rows_started += 1; + } + num_items += 1; + if control.is_visible { + self.offset_in_current += self.total_bytes_per_value; + } else { + self.offset_in_current += self.details.ctrl_word_parser.bytes_per_word(); + } + } - let task_slice = cur_buf.slice_with_length(self.offset_in_current, bytes_to_take); - let rows_in_task = (bytes_to_take / self.total_bytes_per_value) as u64; + let task_slice = cur_buf.slice_with_length(start, self.offset_in_current - start); + if self.offset_in_current == cur_buf.len() { + self.data.pop_front(); + self.offset_in_current = 0; + } - task_data.push((task_slice, rows_in_task)); + FullZipDecodeTaskItem { + data: PerValueDataBlock::Fixed(FixedWidthDataBlock { + data: task_slice, + bits_per_value: self.bytes_per_value as u64 * 8, + num_values: num_items, + block_info: BlockInfo::new(), + }), + rows_in_buf: rows_started, + } + } else { + // If there's no repetition we can calculate the slicing point by just multiplying + // the number of rows by the total bytes per value + let cur_buf = self.data.front_mut().unwrap(); + let bytes_avail = cur_buf.len() - self.offset_in_current; + let offset_in_cur = self.offset_in_current; - remaining -= rows_in_task; - if bytes_to_take + self.offset_in_current == cur_buf.len() { - self.data.pop_front(); + let bytes_needed = num_rows as usize * self.total_bytes_per_value; + let mut rows_taken = num_rows; + let task_slice = if bytes_needed >= bytes_avail { self.offset_in_current = 0; + rows_taken = bytes_avail as u64 / self.total_bytes_per_value as u64; + self.data + .pop_front() + .unwrap() + .slice_with_length(offset_in_cur, bytes_avail) } else { - self.offset_in_current += bytes_to_take; + self.offset_in_current += bytes_needed; + cur_buf.slice_with_length(offset_in_cur, bytes_needed) + }; + FullZipDecodeTaskItem { + data: PerValueDataBlock::Fixed(FixedWidthDataBlock { + data: task_slice, + bits_per_value: self.bytes_per_value as u64 * 8, + num_values: rows_taken, + block_info: BlockInfo::new(), + }), + rows_in_buf: rows_taken, } } - let num_rows = task_data.iter().map(|td| td.1).sum::() as usize; + } +} + +impl StructuralPageDecoder for FixedFullZipDecoder { + fn drain(&mut self, num_rows: u64) -> Result> { + let mut task_data = Vec::with_capacity(self.data.len()); + let mut remaining = num_rows; + while remaining > 0 { + let task_item = self.slice_next_task(remaining); + remaining -= task_item.rows_in_buf; + task_data.push(task_item); + } Ok(Box::new(FixedFullZipDecodeTask { - value_decompressor: self.value_decompressor.clone(), - ctrl_word_parser: self.ctrl_word_parser, + details: self.details.clone(), data: task_data, bytes_per_value: self.bytes_per_value, + num_rows: num_rows as usize, + })) + } + + fn num_rows(&self) -> u64 { + self.num_rows + } +} + +/// A decoder for full-zip encoded data when the data has a variable-width +/// +/// Here we need to unzip the control words AND lengths from the values and +/// then decompress the requested values. +#[derive(Debug)] +struct VariableFullZipDecoder { + details: Arc, + decompressor: Arc, + data: LanceBuffer, + offsets: LanceBuffer, + rep: ScalarBuffer, + def: ScalarBuffer, + repdef_starts: Vec, + data_starts: Vec, + offset_starts: Vec, + visible_item_counts: Vec, + bits_per_offset: u8, + current_idx: usize, + num_rows: u64, +} + +impl VariableFullZipDecoder { + fn new( + details: Arc, + data: VecDeque, + num_rows: u64, + in_bits_per_length: u8, + out_bits_per_offset: u8, + ) -> Self { + let decompressor = match details.value_decompressor { + PerValueDecompressor::Variable(ref d) => d.clone(), + _ => unreachable!(), + }; + + assert_eq!(in_bits_per_length % 8, 0); + assert!(out_bits_per_offset == 32 || out_bits_per_offset == 64); + + let mut decoder = Self { + details, + decompressor, + data: LanceBuffer::empty(), + offsets: LanceBuffer::empty(), + rep: LanceBuffer::empty().borrow_to_typed_slice(), + def: LanceBuffer::empty().borrow_to_typed_slice(), + bits_per_offset: out_bits_per_offset, + repdef_starts: Vec::with_capacity(num_rows as usize + 1), + data_starts: Vec::with_capacity(num_rows as usize + 1), + offset_starts: Vec::with_capacity(num_rows as usize + 1), + visible_item_counts: Vec::with_capacity(num_rows as usize + 1), + current_idx: 0, num_rows, + }; + + // There's no great time to do this and this is the least worst time. If we don't unzip then + // we can't slice the data during the decode phase. This is because we need the offsets to be + // unpacked to know where the values start and end. + // + // We don't want to unzip on the decode thread because that is a single-threaded path + // We don't want to unzip on the scheduling thread because that is a single-threaded path + // + // Fortunately, we know variable length data will always be read indirectly and so we can do it + // here, which should be on the indirect thread. The primary disadvantage to doing it here is that + // we load all the data into memory and then throw it away only to load it all into memory again during + // the decode. + // + // There are some alternatives to investigate: + // - Instead of just reading the beginning and end of the rep index we could read the entire + // range in between. This will give us the break points that we need for slicing and won't increase + // the number of IOPs but it will mean we are doing more total I/O and we need to load the rep index + // even when doing a full scan. + // - We could force each decode task to do a full unzip of all the data. Each decode task now + // has to do more work but the work is all fused. + // - We could just try doing this work on the decode thread and see if it is a problem. + decoder.unzip(data, in_bits_per_length, out_bits_per_offset, num_rows); + + decoder + } + + unsafe fn parse_length(data: &[u8], bits_per_offset: u8) -> u64 { + match bits_per_offset { + 8 => *data.get_unchecked(0) as u64, + 16 => u16::from_le_bytes([*data.get_unchecked(0), *data.get_unchecked(1)]) as u64, + 32 => u32::from_le_bytes([ + *data.get_unchecked(0), + *data.get_unchecked(1), + *data.get_unchecked(2), + *data.get_unchecked(3), + ]) as u64, + 64 => u64::from_le_bytes([ + *data.get_unchecked(0), + *data.get_unchecked(1), + *data.get_unchecked(2), + *data.get_unchecked(3), + *data.get_unchecked(4), + *data.get_unchecked(5), + *data.get_unchecked(6), + *data.get_unchecked(7), + ]), + _ => unreachable!(), + } + } + + fn unzip( + &mut self, + data: VecDeque, + in_bits_per_length: u8, + out_bits_per_offset: u8, + num_rows: u64, + ) { + // This undercounts if there are lists but, at this point, we don't really know how many items we have + let mut rep = Vec::with_capacity(num_rows as usize); + let mut def = Vec::with_capacity(num_rows as usize); + let bytes_cw = self.details.ctrl_word_parser.bytes_per_word() * num_rows as usize; + + // This undercounts if there are lists + // It can also overcount if there are invisible items + let bytes_per_offset = out_bits_per_offset as usize / 8; + let bytes_offsets = bytes_per_offset * (num_rows as usize + 1); + let mut offsets_data = Vec::with_capacity(bytes_offsets); + + let bytes_per_length = in_bits_per_length as usize / 8; + let bytes_lengths = bytes_per_length * num_rows as usize; + + let bytes_data = data.iter().map(|d| d.len()).sum::(); + // This overcounts since bytes_lengths and bytes_cw are undercounts + // It can also undercount if there are invisible items (hence the saturating_sub) + let mut unzipped_data = + Vec::with_capacity((bytes_data - bytes_cw).saturating_sub(bytes_lengths)); + + let mut current_offset = 0_u64; + let mut visible_item_count = 0_u64; + for databuf in data.into_iter() { + let mut databuf = databuf.as_ref(); + while !databuf.is_empty() { + let data_start = unzipped_data.len(); + let offset_start = offsets_data.len(); + // We might have only-rep or only-def, neither, or both. They move at the same + // speed though so we only need one index into it + let repdef_start = rep.len().max(def.len()); + // TODO: Kind of inefficient we parse the control word twice here + let ctrl_desc = self.details.ctrl_word_parser.parse_desc( + databuf, + self.details.max_rep, + self.details.max_visible_def, + ); + self.details + .ctrl_word_parser + .parse(databuf, &mut rep, &mut def); + databuf = &databuf[self.details.ctrl_word_parser.bytes_per_word()..]; + + if ctrl_desc.is_new_row { + self.repdef_starts.push(repdef_start); + self.data_starts.push(data_start); + self.offset_starts.push(offset_start); + self.visible_item_counts.push(visible_item_count); + } + if ctrl_desc.is_visible { + visible_item_count += 1; + if ctrl_desc.is_valid_item { + // Safety: Data should have at least bytes_per_length bytes remaining + debug_assert!(databuf.len() >= bytes_per_length); + let length = unsafe { Self::parse_length(databuf, in_bits_per_length) }; + match out_bits_per_offset { + 32 => offsets_data + .extend_from_slice(&(current_offset as u32).to_le_bytes()), + 64 => offsets_data.extend_from_slice(¤t_offset.to_le_bytes()), + _ => unreachable!(), + }; + databuf = &databuf[bytes_per_offset..]; + unzipped_data.extend_from_slice(&databuf[..length as usize]); + databuf = &databuf[length as usize..]; + current_offset += length; + } else { + // Null items still get an offset + match out_bits_per_offset { + 32 => offsets_data + .extend_from_slice(&(current_offset as u32).to_le_bytes()), + 64 => offsets_data.extend_from_slice(¤t_offset.to_le_bytes()), + _ => unreachable!(), + } + } + } + } + } + self.repdef_starts.push(rep.len().max(def.len())); + self.data_starts.push(unzipped_data.len()); + self.offset_starts.push(offsets_data.len()); + self.visible_item_counts.push(visible_item_count); + match out_bits_per_offset { + 32 => offsets_data.extend_from_slice(&(current_offset as u32).to_le_bytes()), + 64 => offsets_data.extend_from_slice(¤t_offset.to_le_bytes()), + _ => unreachable!(), + }; + self.rep = ScalarBuffer::from(rep); + self.def = ScalarBuffer::from(def); + self.data = LanceBuffer::Owned(unzipped_data); + self.offsets = LanceBuffer::Owned(offsets_data); + } +} + +impl StructuralPageDecoder for VariableFullZipDecoder { + fn drain(&mut self, num_rows: u64) -> Result> { + let start = self.current_idx; + let end = start + num_rows as usize; + + // This might seem a little peculiar. We are returning the entire data for every single + // batch. This is because the offsets are relative to the start of the data. In other words + // imagine we have a data buffer that is 100 bytes long and the offsets are [0, 10, 20, 30, 40] + // and we return in batches of two. The second set of offsets will be [20, 30, 40]. + // + // So either we pay for a copy to normalize the offsets or we just return the entire data buffer + // which is slightly cheaper. + let data = self.data.borrow_and_clone(); + + let offset_start = self.offset_starts[start]; + let offset_end = self.offset_starts[end] + (self.bits_per_offset as usize / 8); + let offsets = self + .offsets + .slice_with_length(offset_start, offset_end - offset_start); + + let repdef_start = self.repdef_starts[start]; + let repdef_end = self.repdef_starts[end]; + let rep = if self.rep.is_empty() { + self.rep.clone() + } else { + self.rep.slice(repdef_start, repdef_end - repdef_start) + }; + let def = if self.def.is_empty() { + self.def.clone() + } else { + self.def.slice(repdef_start, repdef_end - repdef_start) + }; + + let visible_item_counts_start = self.visible_item_counts[start]; + let visible_item_counts_end = self.visible_item_counts[end]; + let num_visible_items = visible_item_counts_end - visible_item_counts_start; + + self.current_idx += num_rows as usize; + + Ok(Box::new(VariableFullZipDecodeTask { + details: self.details.clone(), + decompressor: self.decompressor.clone(), + data, + offsets, + bits_per_offset: self.bits_per_offset, + num_visible_items, + rep, + def, })) } @@ -1052,13 +2499,51 @@ impl StructuralPageDecoder for FixedFullZipDecoder { } } +#[derive(Debug)] +struct VariableFullZipDecodeTask { + details: Arc, + decompressor: Arc, + data: LanceBuffer, + offsets: LanceBuffer, + bits_per_offset: u8, + num_visible_items: u64, + rep: ScalarBuffer, + def: ScalarBuffer, +} + +impl DecodePageTask for VariableFullZipDecodeTask { + fn decode(self: Box) -> Result { + let block = VariableWidthBlock { + data: self.data, + offsets: self.offsets, + bits_per_offset: self.bits_per_offset, + num_values: self.num_visible_items, + block_info: BlockInfo::new(), + }; + let decomopressed = self.decompressor.decompress(block)?; + let rep = self.rep.to_vec(); + let def = self.def.to_vec(); + let unraveler = + RepDefUnraveler::new(Some(rep), Some(def), self.details.def_meaning.clone()); + Ok(DecodedPage { + data: decomopressed, + repdef: unraveler, + }) + } +} + +#[derive(Debug)] +struct FullZipDecodeTaskItem { + data: PerValueDataBlock, + rows_in_buf: u64, +} + /// A task to unzip and decompress full-zip encoded data when that data /// has a fixed-width. #[derive(Debug)] struct FixedFullZipDecodeTask { - value_decompressor: Arc, - ctrl_word_parser: ControlWordParser, - data: Vec<(LanceBuffer, u64)>, + details: Arc, + data: Vec, num_rows: usize, bytes_per_value: usize, } @@ -1066,60 +2551,101 @@ struct FixedFullZipDecodeTask { impl DecodePageTask for FixedFullZipDecodeTask { fn decode(self: Box) -> Result { // Multiply by 2 to make a stab at the size of the output buffer (which will be decompressed and thus bigger) - let estimated_size_bytes = self.data.iter().map(|data| data.0.len()).sum::() * 2; + let estimated_size_bytes = self + .data + .iter() + .map(|task_item| task_item.data.data_size() as usize) + .sum::() + * 2; let mut data_builder = DataBlockBuilder::with_capacity_estimate(estimated_size_bytes as u64); - if self.ctrl_word_parser.bytes_per_word() == 0 { + if self.details.ctrl_word_parser.bytes_per_word() == 0 { // Fast path, no need to unzip because there is no rep/def // // We decompress each buffer and add it to our output buffer - for (buf, rows_in_buf) in self.data.into_iter() { - let decompressed = self.value_decompressor.decompress(buf, rows_in_buf)?; - data_builder.append(&decompressed, 0..rows_in_buf); + for task_item in self.data.into_iter() { + let PerValueDataBlock::Fixed(fixed_data) = task_item.data else { + unreachable!() + }; + let PerValueDecompressor::Fixed(decompressor) = &self.details.value_decompressor + else { + unreachable!() + }; + debug_assert_eq!(fixed_data.num_values, task_item.rows_in_buf); + let decompressed = decompressor.decompress(fixed_data, task_item.rows_in_buf)?; + data_builder.append(&decompressed, 0..task_item.rows_in_buf); } + let unraveler = RepDefUnraveler::new(None, None, self.details.def_meaning.clone()); + Ok(DecodedPage { data: data_builder.finish(), - repetition: None, - definition: None, + repdef: unraveler, }) } else { // Slow path, unzipping needed let mut rep = Vec::with_capacity(self.num_rows); let mut def = Vec::with_capacity(self.num_rows); - for (buf, rows_in_buf) in self.data.into_iter() { - let mut buf_slice = buf.as_ref(); + for task_item in self.data.into_iter() { + let PerValueDataBlock::Fixed(fixed_data) = task_item.data else { + unreachable!() + }; + let mut buf_slice = fixed_data.data.as_ref(); + let num_values = fixed_data.num_values as usize; // We will be unzipping repdef in to `rep` and `def` and the // values into `values` (which contains the compressed values) let mut values = Vec::with_capacity( - buf.len() - (self.ctrl_word_parser.bytes_per_word() * rows_in_buf as usize), + fixed_data.data.len() + - (self.details.ctrl_word_parser.bytes_per_word() * num_values), ); - for _ in 0..rows_in_buf { + let mut visible_items = 0; + for _ in 0..num_values { // Extract rep/def - self.ctrl_word_parser.parse(buf_slice, &mut rep, &mut def); - buf_slice = &buf_slice[self.ctrl_word_parser.bytes_per_word()..]; - // Extract value - values.extend_from_slice(buf_slice[..self.bytes_per_value].as_ref()); - buf_slice = &buf_slice[self.bytes_per_value..]; + self.details + .ctrl_word_parser + .parse(buf_slice, &mut rep, &mut def); + buf_slice = &buf_slice[self.details.ctrl_word_parser.bytes_per_word()..]; + + let is_visible = def + .last() + .map(|d| *d <= self.details.max_visible_def) + .unwrap_or(true); + if is_visible { + // Extract value + values.extend_from_slice(buf_slice[..self.bytes_per_value].as_ref()); + buf_slice = &buf_slice[self.bytes_per_value..]; + visible_items += 1; + } } // Finally, we decompress the values and add them to our output buffer let values_buf = LanceBuffer::Owned(values); - let decompressed = self - .value_decompressor - .decompress(values_buf, rows_in_buf)?; - data_builder.append(&decompressed, 0..rows_in_buf); + let fixed_data = FixedWidthDataBlock { + bits_per_value: self.bytes_per_value as u64 * 8, + block_info: BlockInfo::new(), + data: values_buf, + num_values: visible_items, + }; + let PerValueDecompressor::Fixed(decompressor) = &self.details.value_decompressor + else { + unreachable!() + }; + let decompressed = decompressor.decompress(fixed_data, visible_items)?; + data_builder.append(&decompressed, 0..visible_items); } let repetition = if rep.is_empty() { None } else { Some(rep) }; let definition = if def.is_empty() { None } else { Some(def) }; + let unraveler = + RepDefUnraveler::new(repetition, definition, self.details.def_meaning.clone()); + let data = data_builder.finish(); + Ok(DecodedPage { - data: data_builder.finish(), - repetition, - definition, + data, + repdef: unraveler, }) } } @@ -1131,7 +2657,6 @@ struct StructuralPrimitiveFieldSchedulingJob<'a> { ranges: Vec>, page_idx: usize, range_idx: usize, - range_offset: u64, global_row_offset: u64, } @@ -1142,13 +2667,12 @@ impl<'a> StructuralPrimitiveFieldSchedulingJob<'a> { ranges, page_idx: 0, range_idx: 0, - range_offset: 0, global_row_offset: 0, } } } -impl<'a> StructuralSchedulingJob for StructuralPrimitiveFieldSchedulingJob<'a> { +impl StructuralSchedulingJob for StructuralPrimitiveFieldSchedulingJob<'_> { fn schedule_next( &mut self, context: &mut SchedulerContext, @@ -1158,7 +2682,6 @@ impl<'a> StructuralSchedulingJob for StructuralPrimitiveFieldSchedulingJob<'a> { } // Get our current range let mut range = self.ranges[self.range_idx].clone(); - range.start += self.range_offset; let priority = range.start; let mut cur_page = &self.scheduler.page_schedulers[self.page_idx]; @@ -1214,7 +2737,7 @@ impl<'a> StructuralSchedulingJob for StructuralPrimitiveFieldSchedulingJob<'a> { let page_decoder = cur_page .scheduler - .schedule_ranges(&ranges_in_page, context.io().as_ref())?; + .schedule_ranges(&ranges_in_page, context.io())?; let cur_path = context.current_path(); let page_index = cur_page.page_index; @@ -1262,7 +2785,12 @@ impl StructuralPrimitiveFieldScheduler { .iter() .enumerate() .map(|(page_index, page_info)| { - Self::page_info_to_scheduler(page_info, page_index, decompressors) + Self::page_info_to_scheduler( + page_info, + page_index, + column_info.index as usize, + decompressors, + ) }) .collect::>>()?; Ok(Self { @@ -1274,6 +2802,7 @@ impl StructuralPrimitiveFieldScheduler { fn page_info_to_scheduler( page_info: &PageInfo, page_index: usize, + _column_index: usize, decompressors: &dyn DecompressorStrategy, ) -> Result { let scheduler: Box = @@ -1282,7 +2811,7 @@ impl StructuralPrimitiveFieldScheduler { Box::new(MiniBlockScheduler::try_new( &page_info.buffer_offsets_and_sizes, page_info.priority, - page_info.num_rows, + mini_block.num_items, mini_block, decompressors, )?) @@ -1294,10 +2823,26 @@ impl StructuralPrimitiveFieldScheduler { page_info.num_rows, full_zip, decompressors, + /*bits_per_offset=*/ 32, )?) } - Some(pb::page_layout::Layout::AllNullLayout(_)) => { - Box::new(SimpleAllNullScheduler::default()) as Box + Some(pb::page_layout::Layout::AllNullLayout(all_null)) => { + let def_meaning = all_null + .layers + .iter() + .map(|l| ProtobufUtils::repdef_layer_to_def_interp(*l)) + .collect::>(); + if def_meaning.len() == 1 + && def_meaning[0] == DefinitionInterpretation::NullableItem + { + Box::new(SimpleAllNullScheduler::default()) + as Box + } else { + Box::new(ComplexAllNullScheduler::new( + page_info.buffer_offsets_and_sizes.clone(), + def_meaning.into(), + )) as Box + } } _ => todo!(), }; @@ -1309,19 +2854,61 @@ impl StructuralPrimitiveFieldScheduler { } } +pub trait CachedPageData: Any + Send + Sync + DeepSizeOf + 'static { + fn as_arc_any(self: Arc) -> Arc; +} + +pub struct NoCachedPageData; + +impl DeepSizeOf for NoCachedPageData { + fn deep_size_of_children(&self, _ctx: &mut Context) -> usize { + 0 + } +} +impl CachedPageData for NoCachedPageData { + fn as_arc_any(self: Arc) -> Arc { + self + } +} + +pub struct CachedFieldData { + pages: Vec>, +} + +impl DeepSizeOf for CachedFieldData { + fn deep_size_of_children(&self, ctx: &mut Context) -> usize { + self.pages.deep_size_of_children(ctx) + } +} + impl StructuralFieldScheduler for StructuralPrimitiveFieldScheduler { fn initialize<'a>( &'a mut self, _filter: &'a FilterExpression, context: &'a SchedulerContext, ) -> BoxFuture<'a, Result<()>> { - let page_init = self + let cache_key = self.column_index.to_string(); + if let Some(cached_data) = context.cache().get_by_str::(&cache_key) { + self.page_schedulers + .iter_mut() + .zip(cached_data.pages.iter()) + .for_each(|(page_scheduler, cached_data)| { + page_scheduler.scheduler.load(cached_data); + }); + return std::future::ready(Ok(())).boxed(); + }; + + let cache = context.cache().clone(); + let page_data = self .page_schedulers .iter_mut() .map(|s| s.scheduler.initialize(context.io())) - .collect::>(); + .collect::>(); + async move { - page_init.try_collect::>().await?; + let page_data = page_data.try_collect::>().await?; + let cached_data = Arc::new(CachedFieldData { pages: page_data }); + cache.insert_by_str::(&cache_key, cached_data); Ok(()) } .boxed() @@ -1464,7 +3051,6 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { Ok(NextDecodeTask { task, num_rows: rows_to_take, - has_more: self.rows_drained != self.num_rows, }) } @@ -1498,35 +3084,43 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { #[derive(Debug)] pub struct StructuralCompositeDecodeArrayTask { tasks: Vec>, - num_values: u64, - data_type: DataType, should_validate: bool, + data_type: DataType, +} + +impl StructuralCompositeDecodeArrayTask { + fn restore_validity( + array: Arc, + unraveler: &mut CompositeRepDefUnraveler, + ) -> Arc { + let validity = unraveler.unravel_validity(array.len()); + let Some(validity) = validity else { + return array; + }; + if array.data_type() == &DataType::Null { + // We unravel from a null array but we don't add the null buffer because arrow-rs doesn't like it + return array; + } + assert_eq!(validity.len(), array.len()); + // SAFETY: We've should have already asserted the buffers are all valid, we are just + // adding null buffers to the array here + make_array(unsafe { + array + .to_data() + .into_builder() + .nulls(Some(validity)) + .build_unchecked() + }) + } } impl StructuralDecodeArrayTask for StructuralCompositeDecodeArrayTask { fn decode(self: Box) -> Result { let mut arrays = Vec::with_capacity(self.tasks.len()); - let mut all_rep = LevelBuffer::with_capacity(self.num_values as usize); - let mut all_def = LevelBuffer::with_capacity(self.num_values as usize); - let mut offset = 0; - let mut has_def = false; + let mut unravelers = Vec::with_capacity(self.tasks.len()); for task in self.tasks { let decoded = task.decode()?; - - if let Some(rep) = &decoded.repetition { - // Note: if one chunk has repetition, all chunks will have repetition - // and so all_rep will either end up with len=num_values or len=0 - all_rep.extend(rep); - } - if let Some(def) = &decoded.definition { - if !has_def { - // This is the first validity we have seen, need to backfill with all-valid - // if we've processed any all-valid pages - has_def = true; - all_def.extend(iter::repeat(0).take(offset)); - } - all_def.extend(def); - } + unravelers.push(decoded.repdef); let array = make_array( decoded @@ -1534,41 +3128,14 @@ impl StructuralDecodeArrayTask for StructuralCompositeDecodeArrayTask { .into_arrow(self.data_type.clone(), self.should_validate)?, ); - offset += array.len(); arrays.push(array); } let array_refs = arrays.iter().map(|arr| arr.as_ref()).collect::>(); let array = arrow_select::concat::concat(&array_refs)?; - let all_rep = if all_rep.is_empty() { - None - } else { - Some(all_rep) - }; - let all_def = if all_def.is_empty() { - None - } else { - Some(all_def) - }; - let mut repdef = RepDefUnraveler::new(all_rep, all_def); + let mut repdef = CompositeRepDefUnraveler::new(unravelers); + + let array = Self::restore_validity(array, &mut repdef); - // The primitive array itself has a validity - let mut validity = repdef.unravel_validity(); - if matches!(self.data_type, DataType::Null) { - // Null arrays don't have a validity but we still pretend they do for consistency's sake - // up until this point. We need to remove it here. - validity = None; - } - if let Some(validity) = validity.as_ref() { - assert!(validity.len() == array.len()); - } - // SAFETY: We are just replacing the validity and asserted it is the correct size - let array = make_array(unsafe { - array - .to_data() - .into_builder() - .nulls(validity) - .build_unchecked() - }); Ok(DecodedArray { array, repdef }) } } @@ -1621,9 +3188,8 @@ impl StructuralFieldDecoder for StructuralPrimitiveFieldDecoder { } Ok(Box::new(StructuralCompositeDecodeArrayTask { tasks, - data_type: self.field.data_type().clone(), should_validate: self.should_validate, - num_values: num_rows, + data_type: self.field.data_type().clone(), })) } @@ -1640,6 +3206,8 @@ pub struct AccumulationQueue { current_bytes: u64, // Row number of the first item in buffered_arrays, reset on flush row_number: u64, + // Number of top level rows represented in buffered_arrays, reset on flush + num_rows: u64, // This is only for logging / debugging purposes column_index: u32, } @@ -1653,15 +3221,22 @@ impl AccumulationQueue { column_index, keep_original_array, row_number: u64::MAX, + num_rows: 0, } } /// Adds an array to the queue, if there is enough data then the queue is flushed /// and returned - pub fn insert(&mut self, array: ArrayRef, row_number: u64) -> Option<(Vec, u64)> { + pub fn insert( + &mut self, + array: ArrayRef, + row_number: u64, + num_rows: u64, + ) -> Option<(Vec, u64, u64)> { if self.row_number == u64::MAX { self.row_number = row_number; } + self.num_rows += num_rows; self.current_bytes += array.get_array_memory_size() as u64; if self.current_bytes > self.cache_bytes { debug!( @@ -1673,7 +3248,13 @@ impl AccumulationQueue { self.current_bytes = 0; let row_number = self.row_number; self.row_number = u64::MAX; - Some((std::mem::take(&mut self.buffered_arrays), row_number)) + let num_rows = self.num_rows; + self.num_rows = 0; + Some(( + std::mem::take(&mut self.buffered_arrays), + row_number, + num_rows, + )) } else { trace!( "Accumulating data for column {}. Now at {} bytes", @@ -1689,7 +3270,7 @@ impl AccumulationQueue { } } - pub fn flush(&mut self) -> Option<(Vec, u64)> { + pub fn flush(&mut self) -> Option<(Vec, u64, u64)> { if self.buffered_arrays.is_empty() { trace!( "No final flush since no data at column {}", @@ -1704,8 +3285,14 @@ impl AccumulationQueue { ); self.current_bytes = 0; let row_number = self.row_number; - self.row_number = 0; - Some((std::mem::take(&mut self.buffered_arrays), row_number)) + self.row_number = u64::MAX; + let num_rows = self.num_rows; + self.num_rows = 0; + Some(( + std::mem::take(&mut self.buffered_arrays), + row_number, + num_rows, + )) } } } @@ -1784,6 +3371,9 @@ impl PrimitiveFieldEncoder { let part_size = bit_util::ceil(array.len(), num_parts); for _ in 0..num_parts { let avail = array.len() - offset; + if avail == 0 { + break; + } let chunk_size = avail.min(part_size); let part = array.slice(offset, chunk_size); let task = self.create_encode_task(vec![part])?; @@ -1808,9 +3398,10 @@ impl FieldEncoder for PrimitiveFieldEncoder { array: ArrayRef, _external_buffers: &mut OutOfLineBuffers, _repdef: RepDefBuilder, - _row_number: u64, + row_number: u64, + num_rows: u64, ) -> Result> { - if let Some(arrays) = self.accumulation_queue.insert(array, /*row_number=*/ 0) { + if let Some(arrays) = self.accumulation_queue.insert(array, row_number, num_rows) { Ok(self.do_flush(arrays.0)?) } else { Ok(vec![]) @@ -1838,6 +3429,14 @@ impl FieldEncoder for PrimitiveFieldEncoder { } } +/// The serialized representation of full-zip data +struct SerializedFullZip { + /// The zipped values buffer + values: LanceBuffer, + /// The repetition index (only present if there is repetition) + repetition_index: Option, +} + // We align and pad mini-blocks to 8 byte boundaries for two reasons. First, // to allow us to store a chunk size in 12 bits. // @@ -1858,7 +3457,6 @@ impl FieldEncoder for PrimitiveFieldEncoder { // Note: by "aligned to 8 bytes" we mean BOTH "aligned to 8 bytes from the start of // the page" and "aligned to 8 bytes from the start of the file." const MINIBLOCK_ALIGNMENT: usize = 8; -const MINIBLOCK_MAX_PADDING: usize = MINIBLOCK_ALIGNMENT - 1; /// An encoder for primitive (leaf) arrays /// @@ -1894,6 +3492,24 @@ pub struct PrimitiveStructuralEncoder { compression_strategy: Arc, column_index: u32, field: Field, + encoding_metadata: Arc>, +} + +struct CompressedLevelsChunk { + data: LanceBuffer, + num_levels: u16, +} + +struct CompressedLevels { + data: Vec, + compression: pb::ArrayEncoding, + rep_index: Option, +} + +struct SerializedMiniBlockPage { + num_buffers: u64, + data: LanceBuffer, + metadata: LanceBuffer, } impl PrimitiveStructuralEncoder { @@ -1902,6 +3518,7 @@ impl PrimitiveStructuralEncoder { compression_strategy: Arc, column_index: u32, field: Field, + encoding_metadata: Arc>, ) -> Result { Ok(Self { accumulation_queue: AccumulationQueue::new( @@ -1913,6 +3530,7 @@ impl PrimitiveStructuralEncoder { column_index, compression_strategy, field, + encoding_metadata, }) } @@ -1935,25 +3553,41 @@ impl PrimitiveStructuralEncoder { return true; } } - false + false + } + + fn prefers_miniblock( + data_block: &DataBlock, + encoding_metadata: &HashMap, + ) -> bool { + // If the user specifically requested miniblock then use it + if let Some(user_requested) = encoding_metadata.get(STRUCTURAL_ENCODING_META_KEY) { + return user_requested.to_lowercase() == STRUCTURAL_ENCODING_MINIBLOCK; + } + // Otherwise only use miniblock if it is narrow + Self::is_narrow(data_block) + } + + fn prefers_fullzip(encoding_metadata: &HashMap) -> bool { + // Fullzip is the backup option so the only reason we wouldn't use it is if the + // user specifically requested not to use it (in which case we're probably going + // to emit an error) + if let Some(user_requested) = encoding_metadata.get(STRUCTURAL_ENCODING_META_KEY) { + return user_requested.to_lowercase() == STRUCTURAL_ENCODING_FULLZIP; + } + true } // Converts value data, repetition levels, and definition levels into a single // buffer of mini-blocks. In addition, creates a buffer of mini-block metadata - // which tells us the size of each block. + // which tells us the size of each block. Finally, if repetition is present then + // we also create a buffer for the repetition index. // // Each chunk is serialized as: - // | rep_len (2 bytes) | def_len (2 bytes) | values_len (2 bytes) | rep | P1 | def | P2 | values | P3 | - // - // P1 - Up to 1 padding byte to ensure `def` is 2-byte aligned - // P2 - Up to 7 padding bytes to ensure `values` is 8-byte aligned - // P3 - Up to 7 padding bytes to ensure the chunk is a multiple of 8 bytes (this also ensures - // that the next `chunk` is 8-byte aligned) + // | num_bufs (1 byte) | buf_lens (2 bytes per buffer) | P | buf0 | P | buf1 | ... | bufN | P | // - // rep is guaranteed to be 2-byte aligned - // def is guaranteed to be 2-byte aligned - // values is guaranteed to be 8-byte aligned - // rep_len, def_len, and values_len are guaranteed to be 2-byte aligned but this shouldn't matter. + // P - Padding inserted to ensure each buffer is 8-byte aligned and the buffer size is a multiple + // of 8 bytes (so that the next chunk is 8-byte aligned). // // Each block has a u16 word of metadata. The upper 12 bits contain 1/6 the // # of bytes in the block (if the block does not have an even number of bytes @@ -1968,63 +3602,117 @@ impl PrimitiveStructuralEncoder { // // All metadata words are serialized (as little endian) into a single buffer // of metadata values. + // + // If there is repetition then we also create a repetition index. This is a + // single buffer of integer vectors (stored in row major order). There is one + // entry for each chunk. The size of the vector is based on the depth of random + // access we want to support. + // + // A vector of size 2 is the minimum and will support row-based random access (e.g. + // "take the 57th row"). A vector of size 3 will support 1 level of nested access + // (e.g. "take the 3rd item in the 57th row"). A vector of size 4 will support 2 + // levels of nested access and so on. + // + // The first number in the vector is the number of top-level rows that complete in + // the chunk. The second number is the number of second-level rows that complete + // after the final top-level row completed (or beginning of the chunk if no top-level + // row completes in the chunk). And so on. The final number in the vector is always + // the number of leftover items not covered by earlier entries in the vector. + // + // Currently we are limited to 0 levels of nested access but that will change in the + // future. + // + // The repetition index and the chunk metadata are read at initialization time and + // cached in memory. fn serialize_miniblocks( miniblocks: MiniBlockCompressed, - rep: Vec, - def: Vec, - ) -> (LanceBuffer, LanceBuffer) { - let bytes_rep = rep.iter().map(|r| r.len()).sum::(); - let bytes_def = def.iter().map(|d| d.len()).sum::(); - let max_bytes_repdef_len = rep.len() * 4; - let max_padding = miniblocks.chunks.len() * (1 + (2 * MINIBLOCK_MAX_PADDING)); - let mut data_buffer = Vec::with_capacity( - miniblocks.data.len() // `values` - + bytes_rep // `rep_len * num_blocks` - + bytes_def // `def_len * num_blocks` - + max_bytes_repdef_len // `rep` and `def` - + max_padding, // `P1`, `P2`, and `P3` for each block - ); - let mut meta_buffer = Vec::with_capacity(miniblocks.data.len() * 2); + rep: Option>, + def: Option>, + ) -> SerializedMiniBlockPage { + let bytes_rep = rep + .as_ref() + .map(|rep| rep.iter().map(|r| r.data.len()).sum::()) + .unwrap_or(0); + let bytes_def = def + .as_ref() + .map(|def| def.iter().map(|d| d.data.len()).sum::()) + .unwrap_or(0); + let bytes_data = miniblocks.data.iter().map(|d| d.len()).sum::(); + let mut num_buffers = miniblocks.data.len(); + if rep.is_some() { + num_buffers += 1; + } + if def.is_some() { + num_buffers += 1; + } + // 2 bytes for the length of each buffer and up to 7 bytes of padding per buffer + let max_extra = 9 * num_buffers; + let mut data_buffer = Vec::with_capacity(bytes_rep + bytes_def + bytes_data + max_extra); + let mut meta_buffer = Vec::with_capacity(miniblocks.chunks.len() * 2); - let mut value_offset = 0; - for ((chunk, rep), def) in miniblocks.chunks.into_iter().zip(rep).zip(def) { - let start_len = data_buffer.len(); + let mut rep_iter = rep.map(|r| r.into_iter()); + let mut def_iter = def.map(|d| d.into_iter()); + + let mut buffer_offsets = vec![0; miniblocks.data.len()]; + for chunk in miniblocks.chunks { + let start_pos = data_buffer.len(); // Start of chunk should be aligned - debug_assert_eq!(start_len % MINIBLOCK_ALIGNMENT, 0); - - assert!(rep.len() < u16::MAX as usize); - assert!(def.len() < u16::MAX as usize); - let bytes_rep = rep.len() as u16; - let bytes_def = def.len() as u16; - let bytes_val = chunk.num_bytes; - - // Each chunk starts with the size of the rep buffer (2 bytes) the size of - // the def buffer (2 bytes) and the size of the values buffer (2 bytes) - data_buffer.extend_from_slice(&bytes_rep.to_le_bytes()); - data_buffer.extend_from_slice(&bytes_def.to_le_bytes()); - data_buffer.extend_from_slice(&bytes_val.to_le_bytes()); - - data_buffer.extend_from_slice(&rep); - // In theory we should insert P1 here. However, since we do not have bit-packing of rep - // def levels yet we can skip this step. - debug_assert_eq!(data_buffer.len() % 2, 0); - data_buffer.extend_from_slice(&def); - - let p2 = pad_bytes::(data_buffer.len()); - // SAFETY: We ensured the data buffer would be large enough when we allocated - data_buffer.extend(iter::repeat(0).take(p2)); - - let num_value_bytes = chunk.num_bytes as usize; - let values = - &miniblocks.data[value_offset as usize..value_offset as usize + num_value_bytes]; - debug_assert_eq!(data_buffer.len() % MINIBLOCK_ALIGNMENT, 0); - data_buffer.extend_from_slice(values); - - let p3 = pad_bytes::(data_buffer.len()); - data_buffer.extend(iter::repeat(0).take(p3)); - value_offset += num_value_bytes as u64; - - let chunk_bytes = data_buffer.len() - start_len; + debug_assert_eq!(start_pos % MINIBLOCK_ALIGNMENT, 0); + + let rep = rep_iter.as_mut().map(|r| r.next().unwrap()); + let def = def_iter.as_mut().map(|d| d.next().unwrap()); + + // Write the number of levels, or 0 if there is no rep/def + let num_levels = rep + .as_ref() + .map(|r| r.num_levels) + .unwrap_or(def.as_ref().map(|d| d.num_levels).unwrap_or(0)); + data_buffer.extend_from_slice(&num_levels.to_le_bytes()); + + // Write the buffer lengths + if let Some(rep) = rep.as_ref() { + let bytes_rep = u16::try_from(rep.data.len()).unwrap(); + data_buffer.extend_from_slice(&bytes_rep.to_le_bytes()); + } + if let Some(def) = def.as_ref() { + let bytes_def = u16::try_from(def.data.len()).unwrap(); + data_buffer.extend_from_slice(&bytes_def.to_le_bytes()); + } + + for buffer_size in &chunk.buffer_sizes { + let bytes = *buffer_size; + data_buffer.extend_from_slice(&bytes.to_le_bytes()); + } + + // Pad + let add_padding = |data_buffer: &mut Vec| { + let pad = pad_bytes::(data_buffer.len()); + data_buffer.extend(iter::repeat_n(FILL_BYTE, pad)); + }; + add_padding(&mut data_buffer); + + // Write the buffers themselves + if let Some(rep) = rep.as_ref() { + data_buffer.extend_from_slice(&rep.data); + add_padding(&mut data_buffer); + } + if let Some(def) = def.as_ref() { + data_buffer.extend_from_slice(&def.data); + add_padding(&mut data_buffer); + } + for (buffer_size, (buffer, buffer_offset)) in chunk + .buffer_sizes + .iter() + .zip(miniblocks.data.iter().zip(buffer_offsets.iter_mut())) + { + let start = *buffer_offset; + let end = start + *buffer_size as usize; + *buffer_offset += *buffer_size as usize; + data_buffer.extend_from_slice(&buffer[start..end]); + add_padding(&mut data_buffer); + } + + let chunk_bytes = data_buffer.len() - start_pos; assert!(chunk_bytes <= 16 * 1024); assert!(chunk_bytes > 0); assert_eq!(chunk_bytes % 8, 0); @@ -2038,63 +3726,129 @@ impl PrimitiveStructuralEncoder { meta_buffer.extend_from_slice(&metadata.to_le_bytes()); } - ( - LanceBuffer::Owned(data_buffer), - LanceBuffer::Owned(meta_buffer), - ) + let data_buffer = LanceBuffer::Owned(data_buffer); + let metadata_buffer = LanceBuffer::Owned(meta_buffer); + + SerializedMiniBlockPage { + num_buffers: miniblocks.data.len() as u64, + data: data_buffer, + metadata: metadata_buffer, + } } /// Compresses a buffer of levels into chunks /// - /// TODO: Use bit-packing here + /// If these are repetition levels then we also calculate the repetition index here (that + /// is the third return value) fn compress_levels( - levels: Option, - num_values: u64, + mut levels: RepDefSlicer<'_>, + num_elements: u64, compression_strategy: &dyn CompressionStrategy, chunks: &[MiniBlockChunk], - ) -> Result<(Vec, pb::ArrayEncoding)> { - if let Some(levels) = levels { - debug_assert_eq!(num_values as usize, levels.len()); - // Make the levels into a FixedWidth data block - let mut levels_buf = LanceBuffer::reinterpret_vec(levels); - let levels_block = DataBlock::FixedWidth(FixedWidthDataBlock { - data: levels_buf.borrow_and_clone(), + // This will be 0 if we are compressing def levels + max_rep: u16, + ) -> Result { + let mut rep_index = if max_rep > 0 { + Vec::with_capacity(chunks.len()) + } else { + vec![] + }; + // Make the levels into a FixedWidth data block + let num_levels = levels.num_levels() as u64; + let mut levels_buf = levels.all_levels().try_clone().unwrap(); + let levels_block = DataBlock::FixedWidth(FixedWidthDataBlock { + data: levels_buf.borrow_and_clone(), + bits_per_value: 16, + num_values: num_levels, + block_info: BlockInfo::new(), + }); + let levels_field = Field::new_arrow("", DataType::UInt16, false)?; + // Pick a block compressor + let (compressor, compressor_desc) = + compression_strategy.create_block_compressor(&levels_field, &levels_block)?; + // Compress blocks of levels (sized according to the chunks) + let mut level_chunks = Vec::with_capacity(chunks.len()); + let mut values_counter = 0; + for (chunk_idx, chunk) in chunks.iter().enumerate() { + let chunk_num_values = chunk.num_values(values_counter, num_elements); + values_counter += chunk_num_values; + let mut chunk_levels = if chunk_idx < chunks.len() - 1 { + levels.slice_next(chunk_num_values as usize) + } else { + levels.slice_rest() + }; + let num_chunk_levels = (chunk_levels.len() / 2) as u64; + if max_rep > 0 { + // If max_rep > 0 then we are working with rep levels and we need + // to calculate the repetition index. The repetition index for a + // chunk is currently 2 values (in the future it may be more). + // + // The first value is the number of rows that _finish_ in the + // chunk. + // + // The second value is the number of "leftovers" after the last + // finished row in the chunk. + let rep_values = chunk_levels.borrow_to_typed_slice::(); + let rep_values = rep_values.as_ref(); + + // We skip 1 here because a max_rep at spot 0 doesn't count as a finished list (we + // will count it in the previous chunk) + let mut num_rows = rep_values.iter().skip(1).filter(|v| **v == max_rep).count(); + let num_leftovers = if chunk_idx < chunks.len() - 1 { + rep_values + .iter() + .rev() + .position(|v| *v == max_rep) + // # of leftovers includes the max_rep spot + .map(|pos| pos + 1) + .unwrap_or(rep_values.len()) + } else { + // Last chunk can't have leftovers + 0 + }; + + if chunk_idx != 0 && rep_values[0] == max_rep { + // This chunk starts with a new row and so, if we thought we had leftovers + // in the previous chunk, we were mistaken + // TODO: Can use unchecked here + let rep_len = rep_index.len(); + if rep_index[rep_len - 1] != 0 { + // We thought we had leftovers but that was actually a full row + rep_index[rep_len - 2] += 1; + rep_index[rep_len - 1] = 0; + } + } + + if chunk_idx == chunks.len() - 1 { + // The final list + num_rows += 1; + } + rep_index.push(num_rows as u64); + rep_index.push(num_leftovers as u64); + } + let chunk_levels_block = DataBlock::FixedWidth(FixedWidthDataBlock { + data: chunk_levels, bits_per_value: 16, - num_values, + num_values: num_chunk_levels, block_info: BlockInfo::new(), }); - let levels_field = Field::new_arrow("", DataType::UInt16, false)?; - // Pick a block compressor - let (compressor, compressor_desc) = - compression_strategy.create_block_compressor(&levels_field, &levels_block)?; - // Compress blocks of levels (sized according to the chunks) - let mut buffers = Vec::with_capacity(chunks.len()); - let mut off = 0; - let mut values_counter = 0; - for chunk in chunks { - let chunk_num_values = chunk.num_values(values_counter, num_values); - values_counter += chunk_num_values; - let level_bytes = chunk_num_values as usize * 2; - let chunk_levels = levels_buf.slice_with_length(off, level_bytes); - let chunk_levels_block = DataBlock::FixedWidth(FixedWidthDataBlock { - data: chunk_levels, - bits_per_value: 16, - num_values: chunk_num_values, - block_info: BlockInfo::new(), - }); - let compressed_levels = compressor.compress(chunk_levels_block)?; - off += level_bytes; - buffers.push(compressed_levels); - } - Ok((buffers, compressor_desc)) - } else { - // Everything is valid or we have no repetition so we encode as a constant - // array of 0 - let data = chunks.iter().map(|_| LanceBuffer::empty()).collect(); - let scalar = 0_u16.to_le_bytes().to_vec(); - let encoding = ProtobufUtils::constant(scalar, num_values); - Ok((data, encoding)) + let compressed_levels = compressor.compress(chunk_levels_block)?; + level_chunks.push(CompressedLevelsChunk { + data: compressed_levels, + num_levels: num_chunk_levels as u16, + }); } + debug_assert_eq!(levels.num_levels_remaining(), 0); + let rep_index = if rep_index.is_empty() { + None + } else { + Some(LanceBuffer::reinterpret_vec(rep_index)) + }; + Ok(CompressedLevels { + data: level_chunks, + compression: compressor_desc, + rep_index, + }) } fn encode_simple_all_null( @@ -2112,6 +3866,41 @@ impl PrimitiveStructuralEncoder { }) } + // Encodes a page where all values are null but we have rep/def + // information that we need to store (e.g. to distinguish between + // different kinds of null) + fn encode_complex_all_null( + column_idx: u32, + repdefs: Vec, + row_number: u64, + num_rows: u64, + ) -> Result { + let repdef = RepDefBuilder::serialize(repdefs); + + // TODO: Actually compress repdef + let rep_bytes = if let Some(rep) = repdef.repetition_levels.as_ref() { + LanceBuffer::reinterpret_slice(rep.clone()) + } else { + LanceBuffer::empty() + }; + + let def_bytes = if let Some(def) = repdef.definition_levels.as_ref() { + LanceBuffer::reinterpret_slice(def.clone()) + } else { + LanceBuffer::empty() + }; + + let description = ProtobufUtils::all_null_layout(&repdef.def_meaning); + Ok(EncodedPage { + column_idx, + data: vec![rep_bytes, def_bytes], + description: PageEncoding::Structural(description), + num_rows, + row_number, + }) + } + + #[allow(clippy::too_many_arguments)] fn encode_miniblock( column_idx: u32, field: &Field, @@ -2120,6 +3909,7 @@ impl PrimitiveStructuralEncoder { repdefs: Vec, row_number: u64, dictionary_data: Option, + num_rows: u64, ) -> Result { let repdef = RepDefBuilder::serialize(repdefs); @@ -2129,36 +3919,71 @@ impl PrimitiveStructuralEncoder { todo!() } - let num_values = data.num_values(); - // The validity is encoded in repdef so we can remove it - let data = data.remove_validity(); + // The top-level validity is encoded in repdef so we can remove it. There may be inner + // validities if we have FSL fields but those are not included in the repdef and need to + // be encoded. + let data = data.remove_outer_validity(); + + let num_items = data.num_values(); let compressor = compression_strategy.create_miniblock_compressor(field, &data)?; let (compressed_data, value_encoding) = compressor.compress(data)?; - let (compressed_rep, rep_encoding) = Self::compress_levels( - repdef.repetition_levels, - num_values, - compression_strategy, - &compressed_data.chunks, - )?; + let max_rep = repdef.def_meaning.iter().filter(|l| l.is_list()).count() as u16; + + let mut compressed_rep = repdef + .rep_slicer() + .map(|rep_slicer| { + Self::compress_levels( + rep_slicer, + num_items, + compression_strategy, + &compressed_data.chunks, + max_rep, + ) + }) + .transpose()?; - let (compressed_def, def_encoding) = Self::compress_levels( - repdef.definition_levels, - num_values, - compression_strategy, - &compressed_data.chunks, - )?; + let (rep_index, rep_index_depth) = + match compressed_rep.as_mut().and_then(|cr| cr.rep_index.as_mut()) { + Some(rep_index) => (Some(rep_index.borrow_and_clone()), 1), + None => (None, 0), + }; + + let mut compressed_def = repdef + .def_slicer() + .map(|def_slicer| { + Self::compress_levels( + def_slicer, + num_items, + compression_strategy, + &compressed_data.chunks, + /*max_rep=*/ 0, + ) + }) + .transpose()?; // TODO: Parquet sparsely encodes values here. We could do the same but // then we won't have log2 values per chunk. This means more metadata // and potentially more decoder asymmetry. However, it may be worth // investigating at some point - let (block_value_buffer, block_meta_buffer) = - Self::serialize_miniblocks(compressed_data, compressed_rep, compressed_def); + let rep_data = compressed_rep + .as_mut() + .map(|cr| std::mem::take(&mut cr.data)); + let def_data = compressed_def + .as_mut() + .map(|cd| std::mem::take(&mut cd.data)); + + let serialized = Self::serialize_miniblocks(compressed_data, rep_data, def_data); + + // Metadata, Data, Dictionary, (maybe) Repetition Index + let mut data = Vec::with_capacity(4); + data.push(serialized.metadata); + data.push(serialized.data); if let Some(dictionary_data) = dictionary_data { + let num_dictionary_items = dictionary_data.num_values(); // field in `create_block_compressor` is not used currently. let dummy_dictionary_field = Field::new_arrow("", DataType::UInt16, false)?; @@ -2166,26 +3991,52 @@ impl PrimitiveStructuralEncoder { .create_block_compressor(&dummy_dictionary_field, &dictionary_data)?; let dictionary_buffer = compressor.compress(dictionary_data)?; + data.push(dictionary_buffer); + if let Some(rep_index) = rep_index { + data.push(rep_index); + } + let description = ProtobufUtils::miniblock_layout( - rep_encoding, - def_encoding, + compressed_rep.map(|cr| cr.compression), + compressed_def.map(|cd| cd.compression), value_encoding, - Some(dictionary_encoding), + rep_index_depth, + serialized.num_buffers, + Some((dictionary_encoding, num_dictionary_items)), + &repdef.def_meaning, + num_items, ); Ok(EncodedPage { - num_rows: num_values, + num_rows, column_idx, - data: vec![block_meta_buffer, block_value_buffer, dictionary_buffer], + data, description: PageEncoding::Structural(description), row_number, }) } else { - let description = - ProtobufUtils::miniblock_layout(rep_encoding, def_encoding, value_encoding, None); + let description = ProtobufUtils::miniblock_layout( + compressed_rep.map(|cr| cr.compression), + compressed_def.map(|cd| cd.compression), + value_encoding, + rep_index_depth, + serialized.num_buffers, + None, + &repdef.def_meaning, + num_items, + ); + + if let Some(mut rep_index) = rep_index { + let view = rep_index.borrow_to_typed_slice::(); + let total = view.chunks_exact(2).map(|c| c[0]).sum::(); + debug_assert_eq!(total, num_rows); + + data.push(rep_index); + } + Ok(EncodedPage { - num_rows: num_values, + num_rows, column_idx, - data: vec![block_meta_buffer, block_value_buffer], + data, description: PageEncoding::Structural(description), row_number, }) @@ -2196,9 +4047,19 @@ impl PrimitiveStructuralEncoder { fn serialize_full_zip_fixed( fixed: FixedWidthDataBlock, mut repdef: ControlWordIterator, - ) -> LanceBuffer { - let len = fixed.data.len() + repdef.bytes_per_word() * fixed.num_values as usize; - let mut buf = Vec::with_capacity(len); + num_values: u64, + ) -> SerializedFullZip { + let len = fixed.data.len() + repdef.bytes_per_word() * num_values as usize; + let mut zipped_data = Vec::with_capacity(len); + + let max_rep_index_val = if repdef.has_repetition() { + len as u64 + } else { + // Setting this to 0 means we won't write a repetition index + 0 + }; + let mut rep_index_builder = + BytepackedIntegerEncoder::with_capacity(num_values as usize + 1, max_rep_index_val); // I suppose we can just pad to the nearest byte but I'm not sure we need to worry about this anytime soon // because it is unlikely compression of large values is going to yield a result that is not byte aligned @@ -2210,19 +4071,50 @@ impl PrimitiveStructuralEncoder { let bytes_per_value = fixed.bits_per_value as usize / 8; - for value in fixed.data.chunks_exact(bytes_per_value) { - repdef.append_next(&mut buf); - buf.extend_from_slice(value); + let mut data_iter = fixed.data.chunks_exact(bytes_per_value); + let mut offset = 0; + while let Some(control) = repdef.append_next(&mut zipped_data) { + if control.is_new_row { + // We have finished a row + debug_assert!(offset <= len); + // SAFETY: We know that `start <= len` + unsafe { rep_index_builder.append(offset as u64) }; + } + if control.is_visible { + let value = data_iter.next().unwrap(); + zipped_data.extend_from_slice(value); + } + offset = zipped_data.len(); + } + + debug_assert_eq!(zipped_data.len(), len); + // Put the final value in the rep index + // SAFETY: `zipped_data.len() == len` + unsafe { + rep_index_builder.append(zipped_data.len() as u64); } - LanceBuffer::Owned(buf) + let zipped_data = LanceBuffer::Owned(zipped_data); + let rep_index = rep_index_builder.into_data(); + let rep_index = if rep_index.is_empty() { + None + } else { + Some(LanceBuffer::Owned(rep_index)) + }; + SerializedFullZip { + values: zipped_data, + repetition_index: rep_index, + } } // For variable-size data we encode < control word | length | data > for each value + // + // In addition, we create a second buffer, the repetition index fn serialize_full_zip_variable( mut variable: VariableWidthBlock, mut repdef: ControlWordIterator, - ) -> LanceBuffer { + num_items: u64, + ) -> SerializedFullZip { let bytes_per_offset = variable.bits_per_offset as usize / 8; assert_eq!( variable.bits_per_offset % 8, @@ -2230,34 +4122,82 @@ impl PrimitiveStructuralEncoder { "Only byte-aligned offsets supported" ); let len = variable.data.len() - + repdef.bytes_per_word() * variable.num_values as usize + + repdef.bytes_per_word() * num_items as usize + bytes_per_offset * variable.num_values as usize; let mut buf = Vec::with_capacity(len); - // TODO: We may want to bit-pack lengths in the future. We probably don't need - // full bitpacking (which would cause the data to become unaligned) but we could - // bitpack to the nearest word size (e.g. u8 / u16 / u32) + let max_rep_index_val = len as u64; + let mut rep_index_builder = + BytepackedIntegerEncoder::with_capacity(num_items as usize + 1, max_rep_index_val); + + // TODO: byte pack the item lengths with varint encoding match bytes_per_offset { 4 => { let offs = variable.offsets.borrow_to_typed_slice::(); - for offsets in offs.as_ref().windows(2) { - repdef.append_next(&mut buf); - buf.extend_from_slice(&(offsets[1] - offsets[0]).to_le_bytes()); - buf.extend_from_slice(&variable.data[offsets[0] as usize..offsets[1] as usize]); + let mut rep_offset = 0; + let mut windows_iter = offs.as_ref().windows(2); + while let Some(control) = repdef.append_next(&mut buf) { + if control.is_new_row { + // We have finished a row + debug_assert!(rep_offset <= len); + // SAFETY: We know that `buf.len() <= len` + unsafe { rep_index_builder.append(rep_offset as u64) }; + } + if control.is_visible { + let window = windows_iter.next().unwrap(); + if control.is_valid_item { + buf.extend_from_slice(&(window[1] - window[0]).to_le_bytes()); + buf.extend_from_slice( + &variable.data[window[0] as usize..window[1] as usize], + ); + } + } + rep_offset = buf.len(); } } 8 => { let offs = variable.offsets.borrow_to_typed_slice::(); - for offsets in offs.as_ref().windows(2) { - repdef.append_next(&mut buf); - buf.extend_from_slice(&(offsets[1] - offsets[0]).to_le_bytes()); - buf.extend_from_slice(&variable.data[offsets[0] as usize..offsets[1] as usize]); + let mut rep_offset = 0; + let mut windows_iter = offs.as_ref().windows(2); + while let Some(control) = repdef.append_next(&mut buf) { + if control.is_new_row { + // We have finished a row + debug_assert!(rep_offset <= len); + // SAFETY: We know that `buf.len() <= len` + unsafe { rep_index_builder.append(rep_offset as u64) }; + } + if control.is_visible { + let window = windows_iter.next().unwrap(); + if control.is_valid_item { + buf.extend_from_slice(&(window[1] - window[0]).to_le_bytes()); + buf.extend_from_slice( + &variable.data[window[0] as usize..window[1] as usize], + ); + } + } + rep_offset = buf.len(); } } _ => panic!("Unsupported offset size"), } - LanceBuffer::Owned(buf) + // We might have saved a few bytes by not copying lengths when the length was zero. However, + // if we are over `len` then we have a bug. + debug_assert!(buf.len() <= len); + // Put the final value in the rep index + // SAFETY: `zipped_data.len() == len` + unsafe { + rep_index_builder.append(buf.len() as u64); + } + + let zipped_data = LanceBuffer::Owned(buf); + let rep_index = rep_index_builder.into_data(); + debug_assert!(!rep_index.is_empty()); + let rep_index = Some(LanceBuffer::Owned(rep_index)); + SerializedFullZip { + values: zipped_data, + repetition_index: rep_index, + } } /// Serializes data into a single buffer according to the full-zip format which zips @@ -2265,10 +4205,15 @@ impl PrimitiveStructuralEncoder { fn serialize_full_zip( compressed_data: PerValueDataBlock, repdef: ControlWordIterator, - ) -> LanceBuffer { + num_items: u64, + ) -> SerializedFullZip { match compressed_data { - PerValueDataBlock::Fixed(fixed) => Self::serialize_full_zip_fixed(fixed, repdef), - PerValueDataBlock::Variable(var) => Self::serialize_full_zip_variable(var, repdef), + PerValueDataBlock::Fixed(fixed) => { + Self::serialize_full_zip_fixed(fixed, repdef, num_items) + } + PerValueDataBlock::Variable(var) => { + Self::serialize_full_zip_variable(var, repdef, num_items) + } } } @@ -2279,6 +4224,7 @@ impl PrimitiveStructuralEncoder { data: DataBlock, repdefs: Vec, row_number: u64, + num_lists: u64, ) -> Result { let repdef = RepDefBuilder::serialize(repdefs); let max_rep = repdef @@ -2289,29 +4235,72 @@ impl PrimitiveStructuralEncoder { .definition_levels .as_ref() .map_or(0, |d| d.iter().max().copied().unwrap_or(0)); + + // The top-level validity is encoded in repdef so we can remove it + let data = data.remove_outer_validity(); + + // To handle FSL we just flatten + // let data = data.flatten(); + + let (num_items, num_visible_items) = + if let Some(rep_levels) = repdef.repetition_levels.as_ref() { + // If there are rep levels there may be "invisible" items and we need to encode + // rep_levels.len() things which might be larger than data.num_values() + (rep_levels.len() as u64, data.num_values()) + } else { + // If there are no rep levels then we encode data.num_values() things + (data.num_values(), data.num_values()) + }; + + let max_visible_def = repdef.max_visible_level.unwrap_or(u16::MAX); + let repdef_iter = build_control_word_iterator( - repdef.repetition_levels, + repdef.repetition_levels.as_deref(), max_rep, - repdef.definition_levels, + repdef.definition_levels.as_deref(), max_def, + max_visible_def, + num_items as usize, ); let bits_rep = repdef_iter.bits_rep(); let bits_def = repdef_iter.bits_def(); - let num_values = data.num_values(); - // The validity is encoded in repdef so we can remove it - let data = data.remove_validity(); - let compressor = compression_strategy.create_per_value(field, &data)?; let (compressed_data, value_encoding) = compressor.compress(data)?; - let zipped = Self::serialize_full_zip(compressed_data, repdef_iter); + let description = match &compressed_data { + PerValueDataBlock::Fixed(fixed) => ProtobufUtils::fixed_full_zip_layout( + bits_rep, + bits_def, + fixed.bits_per_value as u32, + value_encoding, + &repdef.def_meaning, + num_items as u32, + num_visible_items as u32, + ), + PerValueDataBlock::Variable(variable) => ProtobufUtils::variable_full_zip_layout( + bits_rep, + bits_def, + variable.bits_per_offset as u32, + value_encoding, + &repdef.def_meaning, + num_items as u32, + num_visible_items as u32, + ), + }; + + let zipped = Self::serialize_full_zip(compressed_data, repdef_iter, num_items); + + let data = if let Some(repindex) = zipped.repetition_index { + vec![zipped.values, repindex] + } else { + vec![zipped.values] + }; - let description = ProtobufUtils::full_zip_layout(bits_rep, bits_def, value_encoding); Ok(EncodedPage { - num_rows: num_values, + num_rows: num_lists, column_idx, - data: vec![zipped], + data, description: PageEncoding::Structural(description), row_number, }) @@ -2420,7 +4409,7 @@ impl PrimitiveStructuralEncoder { } } _ => { - unreachable!() + unreachable!("dictionary encode called with data block {:?}", data_block) } } } @@ -2431,27 +4420,54 @@ impl PrimitiveStructuralEncoder { arrays: Vec, repdefs: Vec, row_number: u64, + num_rows: u64, ) -> Result> { let column_idx = self.column_index; let compression_strategy = self.compression_strategy.clone(); let field = self.field.clone(); + let encoding_metadata = self.encoding_metadata.clone(); let task = spawn_cpu(move || { let num_values = arrays.iter().map(|arr| arr.len() as u64).sum(); + if num_values == 0 { + // We should not encode empty arrays. So if we get here that should mean that we + // either have all empty lists or all null lists (or a mix). We still need to encode + // the rep/def information but we can skip the data encoding. + return Self::encode_complex_all_null(column_idx, repdefs, row_number, num_rows); + } let num_nulls = arrays .iter() .map(|arr| arr.logical_nulls().map(|n| n.null_count()).unwrap_or(0) as u64) .sum::(); - if num_values == num_nulls && repdefs.iter().all(|rd| rd.is_simple_validity()) { - log::debug!( - "Encoding column {} with {} rows using simple-null layout", - column_idx, - num_values - ); - Self::encode_simple_all_null(column_idx, num_values, row_number) + if num_values == num_nulls { + if repdefs.iter().all(|rd| rd.is_simple_validity()) { + log::debug!( + "Encoding column {} with {} items using simple-null layout", + column_idx, + num_values + ); + // Simple case, no rep/def and all nulls, we don't need to encode any data + Self::encode_simple_all_null(column_idx, num_values, row_number) + } else { + // If we get here then we have definition levels (presumably due to FSL) and + // we need to store those + Self::encode_complex_all_null(column_idx, repdefs, row_number, num_rows) + } } else { let data_block = DataBlock::from_arrays(&arrays, num_values); - const DICTIONARY_ENCODING_THRESHOLD: u64 = 100; + + // if the `data_block` is a `StructDataBlock`, then this is a struct with packed struct encoding. + if let DataBlock::Struct(ref struct_data_block) = data_block { + if struct_data_block + .children + .iter() + .any(|child| !matches!(child, DataBlock::FixedWidth(_))) + { + panic!("packed struct encoding currently only supports fixed-width fields.") + } + } + + let dictionary_encoding_threshold: u64 = 100.max(data_block.num_values() / 4); let cardinality = if let Some(cardinality_array) = data_block.get_stat(Stat::Cardinality) { cardinality_array.as_primitive::().value(0) @@ -2460,7 +4476,7 @@ impl PrimitiveStructuralEncoder { }; // The triggering threshold for dictionary encoding can be further tuned. - if cardinality <= DICTIONARY_ENCODING_THRESHOLD + if cardinality <= dictionary_encoding_threshold && data_block.num_values() >= 10 * cardinality { let (indices_data_block, dictionary_data_block) = @@ -2473,10 +4489,11 @@ impl PrimitiveStructuralEncoder { repdefs, row_number, Some(dictionary_data_block), + num_rows, ) - } else if Self::is_narrow(&data_block) { + } else if Self::prefers_miniblock(&data_block, encoding_metadata.as_ref()) { log::debug!( - "Encoding column {} with {} rows using mini-block layout", + "Encoding column {} with {} items using mini-block layout", column_idx, num_values ); @@ -2488,10 +4505,11 @@ impl PrimitiveStructuralEncoder { repdefs, row_number, None, + num_rows, ) - } else { + } else if Self::prefers_fullzip(encoding_metadata.as_ref()) { log::debug!( - "Encoding column {} with {} rows using full-zip layout", + "Encoding column {} with {} items using full-zip layout", column_idx, num_values ); @@ -2502,7 +4520,10 @@ impl PrimitiveStructuralEncoder { data_block, repdefs, row_number, + num_rows, ) + } else { + Err(Error::InvalidInput { source: format!("Cannot determine structural encoding for field {}. This typically indicates an invalid value of the field metadata key {}", field.name, STRUCTURAL_ENCODING_META_KEY).into(), location: location!() }) } } }) @@ -2526,6 +4547,14 @@ impl PrimitiveStructuralEncoder { DataType::Dictionary(_, _) => { unreachable!() } + // Extract our validity buf but NOT any child validity bufs. (they will be encoded in + // as part of the values). Note: for FSL we do not use repdef.add_fsl because we do + // NOT want to increase the repdef depth. + // + // This would be quite catasrophic for something like vector embeddings. Imagine we + // had thousands of vectors and some were null but no vector contained null items. If + // we treated the vectors (primitive FSL) like we treat structural FSL we would end up + // with a rep/def value for every single item in the vector. _ => Self::extract_validity_buf(array, repdef), } } @@ -2539,13 +4568,16 @@ impl FieldEncoder for PrimitiveStructuralEncoder { _external_buffers: &mut OutOfLineBuffers, mut repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { Self::extract_validity(array.as_ref(), &mut repdef); self.accumulated_repdefs.push(repdef); - if let Some((arrays, row_number)) = self.accumulation_queue.insert(array, row_number) { + if let Some((arrays, row_number, num_rows)) = + self.accumulation_queue.insert(array, row_number, num_rows) + { let accumulated_repdefs = std::mem::take(&mut self.accumulated_repdefs); - Ok(self.do_flush(arrays, accumulated_repdefs, row_number)?) + Ok(self.do_flush(arrays, accumulated_repdefs, row_number, num_rows)?) } else { Ok(vec![]) } @@ -2553,9 +4585,9 @@ impl FieldEncoder for PrimitiveStructuralEncoder { // If there is any data left in the buffer then create an encode task from it fn flush(&mut self, _external_buffers: &mut OutOfLineBuffers) -> Result> { - if let Some((arrays, row_number)) = self.accumulation_queue.flush() { + if let Some((arrays, row_number, num_rows)) = self.accumulation_queue.flush() { let accumulated_repdefs = std::mem::take(&mut self.accumulated_repdefs); - Ok(self.do_flush(arrays, accumulated_repdefs, row_number)?) + Ok(self.do_flush(arrays, accumulated_repdefs, row_number, num_rows)?) } else { Ok(vec![]) } @@ -2574,14 +4606,19 @@ impl FieldEncoder for PrimitiveStructuralEncoder { } #[cfg(test)] +#[allow(clippy::single_range_in_vec_init)] mod tests { - use std::sync::Arc; + use std::{collections::VecDeque, sync::Arc}; use arrow_array::{ArrayRef, Int8Array, StringArray}; - use crate::encodings::logical::primitive::PrimitiveStructuralEncoder; + use crate::encodings::logical::primitive::{ + ChunkDrainInstructions, PrimitiveStructuralEncoder, + }; - use super::DataBlock; + use super::{ + ChunkInstructions, DataBlock, DecodeMiniBlockTask, PreambleAction, RepetitionIndex, + }; #[test] fn test_is_narrow() { @@ -2602,4 +4639,591 @@ mod tests { let block = DataBlock::from_array(string_array); assert!((!PrimitiveStructuralEncoder::is_narrow(&block))); } + + #[test] + fn test_map_range() { + // Null in the middle + // [[A, B, C], [D, E], NULL, [F, G, H]] + let rep = Some(vec![1, 0, 0, 1, 0, 1, 1, 0, 0]); + let def = Some(vec![0, 0, 0, 0, 0, 1, 0, 0, 0]); + let max_visible_def = 0; + let total_items = 8; + let max_rep = 1; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Absent, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 0..3, 0..3); + check(1..2, 3..5, 3..5); + check(2..3, 5..5, 5..6); + check(3..4, 5..8, 6..9); + check(0..2, 0..5, 0..5); + check(1..3, 3..5, 3..6); + check(2..4, 5..8, 5..9); + check(0..3, 0..5, 0..6); + check(1..4, 3..8, 3..9); + check(0..4, 0..8, 0..9); + + // Null at start + // [NULL, [A, B], [C]] + let rep = Some(vec![1, 1, 0, 1]); + let def = Some(vec![1, 0, 0, 0]); + let max_visible_def = 0; + let total_items = 3; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Absent, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 0..0, 0..1); + check(1..2, 0..2, 1..3); + check(2..3, 2..3, 3..4); + check(0..2, 0..2, 0..3); + check(1..3, 0..3, 1..4); + check(0..3, 0..3, 0..4); + + // Null at end + // [[A], [B, C], NULL] + let rep = Some(vec![1, 1, 0, 1]); + let def = Some(vec![0, 0, 0, 1]); + let max_visible_def = 0; + let total_items = 3; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Absent, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 0..1, 0..1); + check(1..2, 1..3, 1..3); + check(2..3, 3..3, 3..4); + check(0..2, 0..3, 0..3); + check(1..3, 1..3, 1..4); + check(0..3, 0..3, 0..4); + + // No nulls, with repetition + // [[A, B], [C, D], [E, F]] + let rep = Some(vec![1, 0, 1, 0, 1, 0]); + let def: Option<&[u16]> = None; + let max_visible_def = 0; + let total_items = 6; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Absent, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 0..2, 0..2); + check(1..2, 2..4, 2..4); + check(2..3, 4..6, 4..6); + check(0..2, 0..4, 0..4); + check(1..3, 2..6, 2..6); + check(0..3, 0..6, 0..6); + + // No repetition, with nulls (this case is trivial) + // [A, B, NULL, C] + let rep: Option<&[u16]> = None; + let def = Some(vec![0, 0, 1, 0]); + let max_visible_def = 1; + let total_items = 4; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Absent, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 0..1, 0..1); + check(1..2, 1..2, 1..2); + check(2..3, 2..3, 2..3); + check(0..2, 0..2, 0..2); + check(1..3, 1..3, 1..3); + check(0..3, 0..3, 0..3); + + // Tricky case, this chunk is a continuation and starts with a rep-index = 0 + // [[..., A] [B, C], NULL] + // + // What we do will depend on the preamble action + let rep = Some(vec![0, 1, 0, 1]); + let def = Some(vec![0, 0, 0, 1]); + let max_visible_def = 0; + let total_items = 3; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Take, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + // If we are taking the preamble then the range must start at 0 + check(0..1, 0..3, 0..3); + check(0..2, 0..3, 0..4); + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Skip, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 1..3, 1..3); + check(1..2, 3..3, 3..4); + check(0..2, 1..3, 1..4); + + // Another preamble case but now it doesn't end with a new list + // [[..., A], NULL, [D, E]] + // + // What we do will depend on the preamble action + let rep = Some(vec![0, 1, 1, 0]); + let def = Some(vec![0, 1, 0, 0]); + let max_visible_def = 0; + let total_items = 4; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Take, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + // If we are taking the preamble then the range must start at 0 + check(0..1, 0..1, 0..2); + check(0..2, 0..3, 0..4); + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Skip, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + // If we are taking the preamble then the range must start at 0 + check(0..1, 1..1, 1..2); + check(1..2, 1..3, 2..4); + check(0..2, 1..3, 1..4); + + // Now a preamble case without any definition levels + // [[..., A] [B, C], [D]] + let rep = Some(vec![0, 1, 0, 1]); + let def: Option> = None; + let max_visible_def = 0; + let total_items = 4; + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Take, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + // If we are taking the preamble then the range must start at 0 + check(0..1, 0..3, 0..3); + check(0..2, 0..4, 0..4); + + let check = |range, expected_item_range, expected_level_range| { + let (item_range, level_range) = DecodeMiniBlockTask::map_range( + range, + rep.as_ref(), + def.as_ref(), + max_rep, + max_visible_def, + total_items, + PreambleAction::Skip, + ); + assert_eq!(item_range, expected_item_range); + assert_eq!(level_range, expected_level_range); + }; + + check(0..1, 1..3, 1..3); + check(1..2, 3..4, 3..4); + check(0..2, 1..4, 1..4); + } + + #[test] + fn test_schedule_instructions() { + let repetition_index = vec![vec![5, 2], vec![3, 0], vec![4, 7], vec![2, 0]]; + let repetition_index = RepetitionIndex::decode(&repetition_index); + + let check = |user_ranges, expected_instructions| { + let instructions = + ChunkInstructions::schedule_instructions(&repetition_index, user_ranges); + assert_eq!(instructions, expected_instructions); + }; + + // The instructions we expect if we're grabbing the whole range + let expected_take_all = vec![ + ChunkInstructions { + chunk_idx: 0, + preamble: PreambleAction::Absent, + rows_to_skip: 0, + rows_to_take: 5, + take_trailer: true, + }, + ChunkInstructions { + chunk_idx: 1, + preamble: PreambleAction::Take, + rows_to_skip: 0, + rows_to_take: 2, + take_trailer: false, + }, + ChunkInstructions { + chunk_idx: 2, + preamble: PreambleAction::Absent, + rows_to_skip: 0, + rows_to_take: 4, + take_trailer: true, + }, + ChunkInstructions { + chunk_idx: 3, + preamble: PreambleAction::Take, + rows_to_skip: 0, + rows_to_take: 1, + take_trailer: false, + }, + ]; + + // Take all as 1 range + check(&[0..14], expected_take_all.clone()); + + // Take all a individual rows + check( + &[ + 0..1, + 1..2, + 2..3, + 3..4, + 4..5, + 5..6, + 6..7, + 7..8, + 8..9, + 9..10, + 10..11, + 11..12, + 12..13, + 13..14, + ], + expected_take_all, + ); + + // Test some partial takes + + // 2 rows in the same chunk but not contiguous + check( + &[0..1, 3..4], + vec![ + ChunkInstructions { + chunk_idx: 0, + preamble: PreambleAction::Absent, + rows_to_skip: 0, + rows_to_take: 1, + take_trailer: false, + }, + ChunkInstructions { + chunk_idx: 0, + preamble: PreambleAction::Absent, + rows_to_skip: 3, + rows_to_take: 1, + take_trailer: false, + }, + ], + ); + + // Taking just a trailer/preamble + check( + &[5..6], + vec![ + ChunkInstructions { + chunk_idx: 0, + preamble: PreambleAction::Absent, + rows_to_skip: 5, + rows_to_take: 0, + take_trailer: true, + }, + ChunkInstructions { + chunk_idx: 1, + preamble: PreambleAction::Take, + rows_to_skip: 0, + rows_to_take: 0, + take_trailer: false, + }, + ], + ); + + // Skipping an entire chunk + check( + &[7..10], + vec![ + ChunkInstructions { + chunk_idx: 1, + preamble: PreambleAction::Skip, + rows_to_skip: 1, + rows_to_take: 1, + take_trailer: false, + }, + ChunkInstructions { + chunk_idx: 2, + preamble: PreambleAction::Absent, + rows_to_skip: 0, + rows_to_take: 2, + take_trailer: false, + }, + ], + ); + } + + #[test] + fn test_drain_instructions() { + fn drain_from_instructions( + instructions: &mut VecDeque, + mut rows_desired: u64, + need_preamble: &mut bool, + skip_in_chunk: &mut u64, + ) -> Vec { + // Note: instructions.len() is an upper bound, we typically take much fewer + let mut drain_instructions = Vec::with_capacity(instructions.len()); + while rows_desired > 0 || *need_preamble { + let (next_instructions, consumed_chunk) = instructions + .front() + .unwrap() + .drain_from_instruction(&mut rows_desired, need_preamble, skip_in_chunk); + if consumed_chunk { + instructions.pop_front(); + } + drain_instructions.push(next_instructions); + } + drain_instructions + } + + let repetition_index = vec![vec![5, 2], vec![3, 0], vec![4, 7], vec![2, 0]]; + let repetition_index = RepetitionIndex::decode(&repetition_index); + let user_ranges = vec![1..7, 10..14]; + + // First, schedule the ranges + let scheduled = ChunkInstructions::schedule_instructions(&repetition_index, &user_ranges); + + let mut to_drain = VecDeque::from(scheduled.clone()); + + // Now we drain in batches of 4 + + let mut need_preamble = false; + let mut skip_in_chunk = 0; + + let next_batch = + drain_from_instructions(&mut to_drain, 4, &mut need_preamble, &mut skip_in_chunk); + + assert!(!need_preamble); + assert_eq!(skip_in_chunk, 4); + assert_eq!( + next_batch, + vec![ChunkDrainInstructions { + chunk_instructions: scheduled[0].clone(), + rows_to_take: 4, + rows_to_skip: 0, + preamble_action: PreambleAction::Absent, + }] + ); + + let next_batch = + drain_from_instructions(&mut to_drain, 4, &mut need_preamble, &mut skip_in_chunk); + + assert!(!need_preamble); + assert_eq!(skip_in_chunk, 2); + + assert_eq!( + next_batch, + vec![ + ChunkDrainInstructions { + chunk_instructions: scheduled[0].clone(), + rows_to_take: 1, + rows_to_skip: 4, + preamble_action: PreambleAction::Absent, + }, + ChunkDrainInstructions { + chunk_instructions: scheduled[1].clone(), + rows_to_take: 1, + rows_to_skip: 0, + preamble_action: PreambleAction::Take, + }, + ChunkDrainInstructions { + chunk_instructions: scheduled[2].clone(), + rows_to_take: 2, + rows_to_skip: 0, + preamble_action: PreambleAction::Absent, + } + ] + ); + + let next_batch = + drain_from_instructions(&mut to_drain, 2, &mut need_preamble, &mut skip_in_chunk); + + assert!(!need_preamble); + assert_eq!(skip_in_chunk, 0); + + assert_eq!( + next_batch, + vec![ + ChunkDrainInstructions { + chunk_instructions: scheduled[2].clone(), + rows_to_take: 1, + rows_to_skip: 2, + preamble_action: PreambleAction::Absent, + }, + ChunkDrainInstructions { + chunk_instructions: scheduled[3].clone(), + rows_to_take: 1, + rows_to_skip: 0, + preamble_action: PreambleAction::Take, + }, + ] + ); + + // Regression case. Need a chunk with preamble, rows, and trailer (the middle chunk here) + let repetition_index = vec![vec![5, 2], vec![3, 3], vec![20, 0]]; + let repetition_index = RepetitionIndex::decode(&repetition_index); + let user_ranges = vec![0..28]; + + // First, schedule the ranges + let scheduled = ChunkInstructions::schedule_instructions(&repetition_index, &user_ranges); + + let mut to_drain = VecDeque::from(scheduled.clone()); + + // Drain first chunk and some of second chunk + + let mut need_preamble = false; + let mut skip_in_chunk = 0; + + let next_batch = + drain_from_instructions(&mut to_drain, 7, &mut need_preamble, &mut skip_in_chunk); + + assert_eq!( + next_batch, + vec![ + ChunkDrainInstructions { + chunk_instructions: scheduled[0].clone(), + rows_to_take: 6, + rows_to_skip: 0, + preamble_action: PreambleAction::Absent, + }, + ChunkDrainInstructions { + chunk_instructions: scheduled[1].clone(), + rows_to_take: 1, + rows_to_skip: 0, + preamble_action: PreambleAction::Take, + }, + ] + ); + + assert!(!need_preamble); + assert_eq!(skip_in_chunk, 1); + + // Now, the tricky part. We drain the second chunk, including the trailer, and need to make sure + // we get a drain task to take the preamble of the third chunk (and nothing else) + let next_batch = + drain_from_instructions(&mut to_drain, 2, &mut need_preamble, &mut skip_in_chunk); + + assert_eq!( + next_batch, + vec![ + ChunkDrainInstructions { + chunk_instructions: scheduled[1].clone(), + rows_to_take: 2, + rows_to_skip: 1, + preamble_action: PreambleAction::Skip, + }, + ChunkDrainInstructions { + chunk_instructions: scheduled[2].clone(), + rows_to_take: 0, + rows_to_skip: 0, + preamble_action: PreambleAction::Take, + }, + ] + ); + + assert!(!need_preamble); + assert_eq!(skip_in_chunk, 0); + } } diff --git a/rust/lance-encoding/src/encodings/logical/struct.rs b/rust/lance-encoding/src/encodings/logical/struct.rs index a4cc44afc71..cd1d9ce29d3 100644 --- a/rust/lance-encoding/src/encodings/logical/struct.rs +++ b/rust/lance-encoding/src/encodings/logical/struct.rs @@ -8,15 +8,16 @@ use std::{ }; use arrow_array::{cast::AsArray, Array, ArrayRef, StructArray}; -use arrow_schema::{DataType, Fields}; +use arrow_schema::{DataType, Field, Fields}; use futures::{ future::BoxFuture, stream::{FuturesOrdered, FuturesUnordered}, FutureExt, StreamExt, TryStreamExt, }; use itertools::Itertools; +use lance_arrow::FieldExt; use log::trace; -use snafu::{location, Location}; +use snafu::location; use crate::{ decoder::{ @@ -31,7 +32,7 @@ use crate::{ }; use lance_core::{Error, Result}; -use super::primitive::StructuralPrimitiveFieldDecoder; +use super::{list::StructuralListDecoder, primitive::StructuralPrimitiveFieldDecoder}; #[derive(Debug)] struct SchedulingJobWithStatus<'a> { @@ -42,27 +43,110 @@ struct SchedulingJobWithStatus<'a> { rows_remaining: u64, } -impl<'a> PartialEq for SchedulingJobWithStatus<'a> { +impl PartialEq for SchedulingJobWithStatus<'_> { fn eq(&self, other: &Self) -> bool { self.col_idx == other.col_idx } } -impl<'a> Eq for SchedulingJobWithStatus<'a> {} +impl Eq for SchedulingJobWithStatus<'_> {} -impl<'a> PartialOrd for SchedulingJobWithStatus<'a> { +impl PartialOrd for SchedulingJobWithStatus<'_> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl<'a> Ord for SchedulingJobWithStatus<'a> { +impl Ord for SchedulingJobWithStatus<'_> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { // Note this is reversed to make it min-heap other.rows_scheduled.cmp(&self.rows_scheduled) } } +#[derive(Debug)] +struct EmptyStructDecodeTask { + num_rows: u64, +} + +impl DecodeArrayTask for EmptyStructDecodeTask { + fn decode(self: Box) -> Result { + Ok(Arc::new(StructArray::new_empty_fields( + self.num_rows as usize, + None, + ))) + } +} + +#[derive(Debug)] +struct EmptyStructDecoder { + num_rows: u64, + rows_drained: u64, + data_type: DataType, +} + +impl EmptyStructDecoder { + fn new(num_rows: u64) -> Self { + Self { + num_rows, + rows_drained: 0, + data_type: DataType::Struct(Fields::from(Vec::::default())), + } + } +} + +impl LogicalPageDecoder for EmptyStructDecoder { + fn wait_for_loaded(&mut self, _loaded_need: u64) -> BoxFuture> { + Box::pin(std::future::ready(Ok(()))) + } + fn rows_loaded(&self) -> u64 { + self.num_rows + } + fn rows_unloaded(&self) -> u64 { + 0 + } + fn num_rows(&self) -> u64 { + self.num_rows + } + fn rows_drained(&self) -> u64 { + self.rows_drained + } + fn drain(&mut self, num_rows: u64) -> Result { + self.rows_drained += num_rows; + Ok(NextDecodeTask { + num_rows, + task: Box::new(EmptyStructDecodeTask { num_rows }), + }) + } + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[derive(Debug)] +struct EmptyStructSchedulerJob { + num_rows: u64, +} + +impl SchedulingJob for EmptyStructSchedulerJob { + fn schedule_next( + &mut self, + context: &mut SchedulerContext, + _priority: &dyn PriorityRange, + ) -> Result { + let empty_decoder = Box::new(EmptyStructDecoder::new(self.num_rows)); + let struct_decoder = context.locate_decoder(empty_decoder); + Ok(ScheduledScanLine { + decoders: vec![MessageType::DecoderReady(struct_decoder)], + rows_scheduled: self.num_rows, + }) + } + + fn num_rows(&self) -> u64 { + self.num_rows + } +} + /// Scheduling job for struct data /// /// The order in which we schedule the children is important. We want to schedule the child @@ -106,7 +190,7 @@ impl<'a> SimpleStructSchedulerJob<'a> { } } -impl<'a> SchedulingJob for SimpleStructSchedulerJob<'a> { +impl SchedulingJob for SimpleStructSchedulerJob<'_> { fn schedule_next( &mut self, mut context: &mut SchedulerContext, @@ -174,9 +258,15 @@ pub struct SimpleStructScheduler { } impl SimpleStructScheduler { - pub fn new(children: Vec>, child_fields: Fields) -> Self { - debug_assert!(!children.is_empty()); - let num_rows = children[0].num_rows(); + pub fn new( + children: Vec>, + child_fields: Fields, + num_rows: u64, + ) -> Self { + let num_rows = children + .first() + .map(|child| child.num_rows()) + .unwrap_or(num_rows); debug_assert!(children.iter().all(|child| child.num_rows() == num_rows)); Self { children, @@ -192,6 +282,11 @@ impl FieldScheduler for SimpleStructScheduler { ranges: &[Range], filter: &FilterExpression, ) -> Result> { + if self.children.is_empty() { + return Ok(Box::new(EmptyStructSchedulerJob { + num_rows: ranges.iter().map(|r| r.end - r.start).sum(), + })); + } let child_schedulers = self .children .iter() @@ -239,21 +334,21 @@ struct StructuralSchedulingJobWithStatus<'a> { rows_remaining: u64, } -impl<'a> PartialEq for StructuralSchedulingJobWithStatus<'a> { +impl PartialEq for StructuralSchedulingJobWithStatus<'_> { fn eq(&self, other: &Self) -> bool { self.col_idx == other.col_idx } } -impl<'a> Eq for StructuralSchedulingJobWithStatus<'a> {} +impl Eq for StructuralSchedulingJobWithStatus<'_> {} -impl<'a> PartialOrd for StructuralSchedulingJobWithStatus<'a> { +impl PartialOrd for StructuralSchedulingJobWithStatus<'_> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl<'a> Ord for StructuralSchedulingJobWithStatus<'a> { +impl Ord for StructuralSchedulingJobWithStatus<'_> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { // Note this is reversed to make it min-heap other.rows_scheduled.cmp(&self.rows_scheduled) @@ -297,7 +392,7 @@ impl<'a> RepDefStructSchedulingJob<'a> { } } -impl<'a> StructuralSchedulingJob for RepDefStructSchedulingJob<'a> { +impl StructuralSchedulingJob for RepDefStructSchedulingJob<'_> { fn schedule_next( &mut self, mut context: &mut SchedulerContext, @@ -583,10 +678,12 @@ pub struct StructuralStructDecoder { children: Vec>, data_type: DataType, child_fields: Fields, + // The root decoder is slightly different because it cannot have nulls + is_root: bool, } impl StructuralStructDecoder { - pub fn new(fields: Fields, should_validate: bool) -> Self { + pub fn new(fields: Fields, should_validate: bool, is_root: bool) -> Self { let children = fields .iter() .map(|field| Self::field_to_decoder(field, should_validate)) @@ -596,6 +693,7 @@ impl StructuralStructDecoder { data_type, children, child_fields: fields, + is_root, } } @@ -604,8 +702,22 @@ impl StructuralStructDecoder { should_validate: bool, ) -> Box { match field.data_type() { - DataType::Struct(fields) => Box::new(Self::new(fields.clone(), should_validate)), - DataType::List(_) | DataType::LargeList(_) => todo!(), + DataType::Struct(fields) => { + if field.is_packed_struct() { + let decoder = + StructuralPrimitiveFieldDecoder::new(&field.clone(), should_validate); + Box::new(decoder) + } else { + Box::new(Self::new(fields.clone(), should_validate, false)) + } + } + DataType::List(child_field) | DataType::LargeList(child_field) => { + let child_decoder = Self::field_to_decoder(child_field, should_validate); + Box::new(StructuralListDecoder::new( + child_decoder, + field.data_type().clone(), + )) + } DataType::RunEndEncoded(_, _) => todo!(), DataType::ListView(_) | DataType::LargeListView(_) => todo!(), DataType::Map(_, _) => todo!(), @@ -613,6 +725,14 @@ impl StructuralStructDecoder { _ => Box::new(StructuralPrimitiveFieldDecoder::new(field, should_validate)), } } + + pub fn drain_batch_task(&mut self, num_rows: u64) -> Result { + let array_drain = self.drain(num_rows)?; + Ok(NextDecodeTask { + num_rows, + task: Box::new(array_drain), + }) + } } impl StructuralFieldDecoder for StructuralStructDecoder { @@ -633,6 +753,7 @@ impl StructuralFieldDecoder for StructuralStructDecoder { Ok(Box::new(RepDefStructDecodeTask { children: child_tasks, child_fields: self.child_fields.clone(), + is_root: self.is_root, })) } @@ -645,6 +766,7 @@ impl StructuralFieldDecoder for StructuralStructDecoder { struct RepDefStructDecodeTask { children: Vec>, child_fields: Fields, + is_root: bool, } impl StructuralDecodeArrayTask for RepDefStructDecodeTask { @@ -657,15 +779,22 @@ impl StructuralDecodeArrayTask for RepDefStructDecodeTask { let mut children = Vec::with_capacity(arrays.len()); let mut arrays_iter = arrays.into_iter(); let first_array = arrays_iter.next().unwrap(); + let length = first_array.array.len(); // The repdef should be identical across all children at this point let mut repdef = first_array.repdef; children.push(first_array.array); + for array in arrays_iter { + debug_assert_eq!(length, array.array.len()); children.push(array.array); } - let validity = repdef.unravel_validity(); + let validity = if self.is_root { + None + } else { + repdef.unravel_validity(length) + }; let array = StructArray::new(self.child_fields, children, validity); Ok(DecodedArray { array: Arc::new(array), @@ -760,16 +889,13 @@ impl LogicalPageDecoder for SimpleStructDecoder { .map(|child| child.drain(num_rows)) .collect::>>()?; let num_rows = child_tasks[0].num_rows; - let has_more = child_tasks[0].has_more; debug_assert!(child_tasks.iter().all(|task| task.num_rows == num_rows)); - debug_assert!(child_tasks.iter().all(|task| task.has_more == has_more)); Ok(NextDecodeTask { task: Box::new(SimpleStructDecodeTask { children: child_tasks, child_fields: self.child_fields.clone(), }), num_rows, - has_more, }) } @@ -836,6 +962,7 @@ impl FieldEncoder for StructStructuralEncoder { external_buffers: &mut OutOfLineBuffers, mut repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { let struct_array = array.as_struct(); if let Some(validity) = struct_array.nulls() { @@ -848,7 +975,13 @@ impl FieldEncoder for StructStructuralEncoder { .iter_mut() .zip(struct_array.columns().iter()) .map(|(encoder, arr)| { - encoder.maybe_encode(arr.clone(), external_buffers, repdef.clone(), row_number) + encoder.maybe_encode( + arr.clone(), + external_buffers, + repdef.clone(), + row_number, + num_rows, + ) }) .collect::>>()?; Ok(child_tasks.into_iter().flatten().collect::>()) @@ -913,6 +1046,7 @@ impl FieldEncoder for StructFieldEncoder { external_buffers: &mut OutOfLineBuffers, repdef: RepDefBuilder, row_number: u64, + num_rows: u64, ) -> Result> { self.num_rows_seen += array.len() as u64; let struct_array = array.as_struct(); @@ -921,7 +1055,13 @@ impl FieldEncoder for StructFieldEncoder { .iter_mut() .zip(struct_array.columns().iter()) .map(|(encoder, arr)| { - encoder.maybe_encode(arr.clone(), external_buffers, repdef.clone(), row_number) + encoder.maybe_encode( + arr.clone(), + external_buffers, + repdef.clone(), + row_number, + num_rows, + ) }) .collect::>>()?; Ok(child_tasks.into_iter().flatten().collect::>()) @@ -1074,6 +1214,15 @@ mod tests { check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; } + #[test_log::test(tokio::test)] + async fn test_empty_struct() { + // It's technically legal for a struct to have 0 children, need to + // make sure we support that + let data_type = DataType::Struct(Fields::from(Vec::::default())); + let field = Field::new("row", data_type, false); + check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + } + #[test_log::test(tokio::test)] async fn test_complicated_struct() { let data_type = DataType::Struct(Fields::from(vec![ diff --git a/rust/lance-encoding/src/encodings/physical.rs b/rust/lance-encoding/src/encodings/physical.rs index a108b679e16..284315b9db4 100644 --- a/rust/lance-encoding/src/encodings/physical.rs +++ b/rust/lance-encoding/src/encodings/physical.rs @@ -12,6 +12,7 @@ use self::{ dictionary::DictionaryPageScheduler, fixed_size_list::FixedListScheduler, value::ValuePageScheduler, }; +use crate::buffer::LanceBuffer; use crate::encodings::physical::block_compress::CompressionScheme; use crate::{ decoder::PageScheduler, @@ -29,6 +30,7 @@ pub mod fixed_size_binary; pub mod fixed_size_list; pub mod fsst; pub mod packed_struct; +pub mod struct_encoding; pub mod value; /// These contain the file buffers shared across the entire file @@ -235,16 +237,32 @@ pub fn decoder_from_array_encoding( let inner = decoder_from_array_encoding(fsst.binary.as_ref().unwrap(), buffers, data_type); - Box::new(FsstPageScheduler::new(inner, fsst.symbol_table.clone())) + Box::new(FsstPageScheduler::new( + inner, + LanceBuffer::from_bytes(fsst.symbol_table.clone(), 1), + )) } pb::array_encoding::ArrayEncoding::Dictionary(dictionary) => { let indices_encoding = dictionary.indices.as_ref().unwrap(); let items_encoding = dictionary.items.as_ref().unwrap(); let num_dictionary_items = dictionary.num_dictionary_items; + // We can get here in 2 ways. The data is dictionary encoded and the user wants a dictionary or + // the data is dictionary encoded, as an optimization, and the user wants the value type. Figure + // out the value type. + let value_type = if let DataType::Dictionary(_, value_type) = data_type { + value_type + } else { + data_type + }; + + // Note: we don't actually know the indices type here, passing down `data_type` works ok because + // the dictionary indices are always integers and we don't need the data_type to figure out how + // to decode integers. let indices_scheduler = decoder_from_array_encoding(indices_encoding, buffers, data_type); - let items_scheduler = decoder_from_array_encoding(items_encoding, buffers, data_type); + + let items_scheduler = decoder_from_array_encoding(items_encoding, buffers, value_type); let should_decode_dict = !data_type.is_dictionary(); @@ -282,11 +300,7 @@ pub fn decoder_from_array_encoding( // This will change in the future when we add support for struct nullability. pb::array_encoding::ArrayEncoding::Struct(_) => unreachable!(), // 2.1 only - pb::array_encoding::ArrayEncoding::Constant(_) => unreachable!(), - pb::array_encoding::ArrayEncoding::Bitpack2(_) => unreachable!(), - pb::array_encoding::ArrayEncoding::BinaryMiniBlock(_) => unreachable!(), - pb::array_encoding::ArrayEncoding::FsstMiniBlock(_) => unreachable!(), - pb::array_encoding::ArrayEncoding::BinaryBlock(_) => unreachable!(), + _ => unreachable!("Unsupported array encoding: {:?}", encoding), } } diff --git a/rust/lance-encoding/src/encodings/physical/binary.rs b/rust/lance-encoding/src/encodings/physical/binary.rs index bdae9557bad..cd1dcef63ef 100644 --- a/rust/lance-encoding/src/encodings/physical/binary.rs +++ b/rust/lance-encoding/src/encodings/physical/binary.rs @@ -12,19 +12,24 @@ use bytemuck::{cast_slice, try_cast_slice}; use byteorder::{ByteOrder, LittleEndian}; use futures::TryFutureExt; use lance_core::utils::bit::pad_bytes; -use snafu::{location, Location}; +use snafu::location; use futures::{future::BoxFuture, FutureExt}; -use crate::decoder::{BlockDecompressor, LogicalPageDecoder, MiniBlockDecompressor}; -use crate::encoder::{BlockCompressor, MiniBlockChunk, MiniBlockCompressed, MiniBlockCompressor}; +use crate::decoder::{ + BlockDecompressor, LogicalPageDecoder, MiniBlockDecompressor, VariablePerValueDecompressor, +}; +use crate::encoder::{ + BlockCompressor, MiniBlockChunk, MiniBlockCompressed, MiniBlockCompressor, PerValueCompressor, + PerValueDataBlock, +}; use crate::encodings::logical::primitive::PrimitiveFieldDecoder; use crate::buffer::LanceBuffer; use crate::data::{ BlockInfo, DataBlock, FixedWidthDataBlock, NullableDataBlock, VariableWidthBlock, }; -use crate::format::ProtobufUtils; +use crate::format::{pb, ProtobufUtils}; use crate::{ decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, EncodedArray}, @@ -607,7 +612,7 @@ impl BinaryMiniBlockEncoder { let this_chunk_size = (num_values_in_this_chunk + 1) * 4 + (offsets[offsets.len() - 1] - offsets[last_offset_in_orig_idx]) as usize; - let padded_chunk_size = ((this_chunk_size + 3) / 4) * 4; + let padded_chunk_size = this_chunk_size.next_multiple_of(4); // the bytes are put after the offsets let this_chunk_bytes_start_offset = (num_values_in_this_chunk + 1) * 4; @@ -619,7 +624,7 @@ impl BinaryMiniBlockEncoder { }); chunks.push(MiniBlockChunk { log_num_values: 0, - num_bytes: padded_chunk_size as u16, + buffer_sizes: vec![padded_chunk_size as u16], }); break; } else { @@ -631,7 +636,7 @@ impl BinaryMiniBlockEncoder { + (offsets[this_last_offset_in_orig_idx] - offsets[last_offset_in_orig_idx]) as usize; - let padded_chunk_size = ((this_chunk_size + 3) / 4) * 4; + let padded_chunk_size = this_chunk_size.next_multiple_of(4); // the bytes are put after the offsets let this_chunk_bytes_start_offset = (num_values_in_this_chunk + 1) * 4; @@ -645,7 +650,7 @@ impl BinaryMiniBlockEncoder { chunks.push(MiniBlockChunk { log_num_values: num_values_in_this_chunk.trailing_zeros() as u8, - num_bytes: padded_chunk_size as u16, + buffer_sizes: vec![padded_chunk_size as u16], }); last_offset_in_orig_idx = this_last_offset_in_orig_idx; @@ -683,20 +688,17 @@ impl BinaryMiniBlockEncoder { ( MiniBlockCompressed { - data: LanceBuffer::reinterpret_vec(output), + data: vec![LanceBuffer::reinterpret_vec(output)], chunks, num_values: data.num_values, }, - ProtobufUtils::binary_miniblock(), + ProtobufUtils::variable(/*bits_per_value=*/ 32), ) } } impl MiniBlockCompressor for BinaryMiniBlockEncoder { - fn compress( - &self, - data: DataBlock, - ) -> Result<(MiniBlockCompressed, crate::format::pb::ArrayEncoding)> { + fn compress(&self, data: DataBlock) -> Result<(MiniBlockCompressed, pb::ArrayEncoding)> { match data { DataBlock::VariableWidth(variable_width) => Ok(self.chunk_data(variable_width)), _ => Err(Error::InvalidInput { @@ -718,7 +720,9 @@ impl MiniBlockDecompressor for BinaryMiniBlockDecompressor { // decompress a MiniBlock of binary data, the num_values must be less than or equal // to the number of values this MiniBlock has, BinaryMiniBlock doesn't store `the number of values` // it has so assertion can not be done here and the caller of `decompress` must ensure `num_values` <= number of values in the chunk. - fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { + fn decompress(&self, data: Vec, num_values: u64) -> Result { + assert_eq!(data.len(), 1); + let data = data.into_iter().next().unwrap(); assert!(data.len() >= 8); let offsets: &[u32] = try_cast_slice(&data) .expect("casting buffer failed during BinaryMiniBlock decompression"); @@ -740,9 +744,11 @@ impl MiniBlockDecompressor for BinaryMiniBlockDecompressor { } } +/// Most basic encoding for variable-width data which does no compression at all #[derive(Debug, Default)] -pub struct BinaryBlockEncoder {} -impl BlockCompressor for BinaryBlockEncoder { +pub struct VariableEncoder {} + +impl BlockCompressor for VariableEncoder { fn compress(&self, data: DataBlock) -> Result { let num_values: u32 = data .num_values() @@ -785,13 +791,33 @@ impl BlockCompressor for BinaryBlockEncoder { } } +impl PerValueCompressor for VariableEncoder { + fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> { + let DataBlock::VariableWidth(variable) = data else { + panic!("BinaryPerValueCompressor can only work with Variable Width DataBlock."); + }; + + let encoding = ProtobufUtils::variable(variable.bits_per_offset); + Ok((PerValueDataBlock::Variable(variable), encoding)) + } +} + +#[derive(Debug, Default)] +pub struct VariableDecoder {} + +impl VariablePerValueDecompressor for VariableDecoder { + fn decompress(&self, data: VariableWidthBlock) -> Result { + Ok(DataBlock::VariableWidth(data)) + } +} + #[derive(Debug, Default)] pub struct BinaryBlockDecompressor {} impl BlockDecompressor for BinaryBlockDecompressor { - fn decompress(&self, data: LanceBuffer) -> Result { + fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { // the first 4 bytes in the BinaryBlock compressed buffer stores the num_values this block has. - let num_values = LittleEndian::read_u32(&data[..4]) as u64; + debug_assert_eq!(num_values, LittleEndian::read_u32(&data[..4]) as u64); // the next 4 bytes in the BinaryBlock compressed buffer stores the bytes_start_offset. let bytes_start_offset = LittleEndian::read_u32(&data[4..8]); @@ -823,6 +849,10 @@ pub mod tests { }; use arrow_schema::{DataType, Field}; + use lance_core::datatypes::{ + COMPRESSION_META_KEY, STRUCTURAL_ENCODING_FULLZIP, STRUCTURAL_ENCODING_META_KEY, + STRUCTURAL_ENCODING_MINIBLOCK, + }; use rstest::rstest; use std::{collections::HashMap, sync::Arc, vec}; @@ -878,11 +908,39 @@ pub mod tests { #[test_log::test(tokio::test)] async fn test_binary( #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + #[values(DataType::Utf8, DataType::Binary)] data_type: DataType, ) { - let field = Field::new("", DataType::Binary, false); + use lance_core::datatypes::STRUCTURAL_ENCODING_META_KEY; + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + + let field = Field::new("", data_type, false).with_metadata(field_metadata); check_round_trip_encoding_random(field, version).await; } + #[rstest] + #[test_log::test(tokio::test)] + async fn test_binary_fsst( + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + ) { + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); + field_metadata.insert(COMPRESSION_META_KEY.to_string(), "fsst".into()); + + let field = Field::new("", DataType::Utf8, true).with_metadata(field_metadata); + check_round_trip_encoding_random(field, LanceFileVersion::V2_1).await; + } + #[test_log::test(tokio::test)] async fn test_large_binary() { let field = Field::new("", DataType::LargeBinary, true); @@ -897,10 +955,22 @@ pub mod tests { #[rstest] #[test_log::test(tokio::test)] - async fn test_simple_utf8_binary( + async fn test_simple_binary( #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + #[values(STRUCTURAL_ENCODING_MINIBLOCK, STRUCTURAL_ENCODING_FULLZIP)] + structural_encoding: &str, + #[values(DataType::Utf8, DataType::Binary)] data_type: DataType, ) { + use lance_core::datatypes::STRUCTURAL_ENCODING_META_KEY; + let string_array = StringArray::from(vec![Some("abc"), None, Some("pqr"), None, Some("m")]); + let string_array = arrow_cast::cast(&string_array, &data_type).unwrap(); + + let mut field_metadata = HashMap::new(); + field_metadata.insert( + STRUCTURAL_ENCODING_META_KEY.to_string(), + structural_encoding.into(), + ); let test_cases = TestCases::default() .with_range(0..2) @@ -911,7 +981,7 @@ pub mod tests { check_round_trip_encoding_of_data( vec![Arc::new(string_array)], &test_cases, - HashMap::new(), + field_metadata, ) .await; } @@ -1038,15 +1108,6 @@ pub mod tests { check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await; } - #[rstest] - #[test_log::test(tokio::test)] - async fn test_binary_miniblock( - #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, - ) { - let field = Field::new("", DataType::Utf8, false); - check_round_trip_encoding_random(field, version).await; - } - #[test_log::test(tokio::test)] async fn test_binary_dictionary_encoding() { let test_cases = TestCases::default().with_file_version(LanceFileVersion::V2_1); diff --git a/rust/lance-encoding/src/encodings/physical/bitmap.rs b/rust/lance-encoding/src/encodings/physical/bitmap.rs index b61616fac6e..3e498452c5a 100644 --- a/rust/lance-encoding/src/encodings/physical/bitmap.rs +++ b/rust/lance-encoding/src/encodings/physical/bitmap.rs @@ -126,20 +126,79 @@ impl PrimitivePageDecoder for BitmapDecoder { #[cfg(test)] mod tests { + use arrow_array::BooleanArray; use arrow_schema::{DataType, Field}; use bytes::Bytes; + use rstest::rstest; + use std::{collections::HashMap, sync::Arc}; use crate::decoder::PrimitivePageDecoder; use crate::encodings::physical::bitmap::BitmapData; - use crate::testing::check_round_trip_encoding_random; + use crate::testing::{ + check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases, + }; use crate::version::LanceFileVersion; use super::BitmapDecoder; + #[rstest] #[test_log::test(tokio::test)] - async fn test_bitmap_boolean() { + async fn test_bitmap_boolean( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { let field = Field::new("", DataType::Boolean, false); - check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + check_round_trip_encoding_random(field, version).await; + } + + #[test_log::test(tokio::test)] + async fn test_fsl_bitmap_boolean() { + let field = Field::new("", DataType::Boolean, true); + let field = Field::new("", DataType::FixedSizeList(Arc::new(field), 3), true); + check_round_trip_encoding_random(field, LanceFileVersion::V2_1).await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_simple_boolean( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { + let array = BooleanArray::from(vec![ + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + None, + None, + ]); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(0..3) + .with_range(1..9) + .with_indices(vec![0, 1, 3, 4]) + .with_file_version(version); + check_round_trip_encoding_of_data(vec![Arc::new(array)], &test_cases, HashMap::default()) + .await; + } + + #[rstest] + #[test_log::test(tokio::test)] + async fn test_tiny_boolean( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { + // Test case for a tiny boolean array that is technically smaller than 1 byte + let array = BooleanArray::from(vec![Some(false), Some(true), None]); + + let test_cases = TestCases::default() + .with_range(0..1) + .with_range(1..3) + .with_indices(vec![0, 2]) + .with_file_version(version); + check_round_trip_encoding_of_data(vec![Arc::new(array)], &test_cases, HashMap::default()) + .await; } #[test] diff --git a/rust/lance-encoding/src/encodings/physical/bitpack.rs b/rust/lance-encoding/src/encodings/physical/bitpack.rs index 8c1ae3502ad..268349aafe5 100644 --- a/rust/lance-encoding/src/encodings/physical/bitpack.rs +++ b/rust/lance-encoding/src/encodings/physical/bitpack.rs @@ -14,7 +14,7 @@ use bytes::Bytes; use futures::future::{BoxFuture, FutureExt}; use log::trace; use num_traits::{AsPrimitive, PrimInt, ToPrimitive}; -use snafu::{location, Location}; +use snafu::location; use lance_arrow::DataTypeExt; use lance_core::{Error, Result}; @@ -530,9 +530,9 @@ enum StartOffset { /// * `buffer_len` - length buf buffer (in bytes) /// * `bits_per_value` - number of bits used to represent a single bitpacked value /// * `buffer_start_bit_offset` - offset of the start of the first value within the -/// buffer's first byte +/// buffer's first byte /// * `buffer_end_bit_offset` - end bit of the last value within the buffer. Can be -/// `None` if the end of the last value is byte aligned with end of buffer. +/// `None` if the end of the last value is byte aligned with end of buffer. fn compute_start_offset( rows_to_skip: u64, buffer_len: usize, diff --git a/rust/lance-encoding/src/encodings/physical/bitpack_fastlanes.rs b/rust/lance-encoding/src/encodings/physical/bitpack_fastlanes.rs index d5bd3dcf827..8f899ced424 100644 --- a/rust/lance-encoding/src/encodings/physical/bitpack_fastlanes.rs +++ b/rust/lance-encoding/src/encodings/physical/bitpack_fastlanes.rs @@ -7,12 +7,13 @@ use arrow::datatypes::{ Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow_array::{Array, PrimitiveArray}; +use arrow_buffer::ArrowNativeType; use arrow_schema::DataType; use byteorder::{ByteOrder, LittleEndian}; use bytes::Bytes; use futures::future::{BoxFuture, FutureExt}; use log::trace; -use snafu::{location, Location}; +use snafu::location; use lance_arrow::DataTypeExt; use lance_core::{Error, Result}; @@ -21,14 +22,19 @@ use crate::buffer::LanceBuffer; use crate::compression_algo::fastlanes::BitPacking; use crate::data::BlockInfo; use crate::data::{DataBlock, FixedWidthDataBlock, NullableDataBlock}; -use crate::decoder::{MiniBlockDecompressor, PageScheduler, PrimitivePageDecoder}; +use crate::decoder::{ + BlockDecompressor, FixedPerValueDecompressor, MiniBlockDecompressor, PageScheduler, + PrimitivePageDecoder, +}; use crate::encoder::{ - ArrayEncoder, EncodedArray, MiniBlockChunk, MiniBlockCompressed, MiniBlockCompressor, + ArrayEncoder, BlockCompressor, EncodedArray, MiniBlockChunk, MiniBlockCompressed, + MiniBlockCompressor, PerValueCompressor, PerValueDataBlock, }; use crate::format::{pb, ProtobufUtils}; use crate::statistics::{GetStat, Stat}; use arrow::array::ArrayRef; -use bytemuck::cast_slice; +use bytemuck::{cast_slice, AnyBitPattern}; + const LOG_ELEMS_PER_CHUNK: u8 = 10; const ELEMS_PER_CHUNK: u64 = 1 << LOG_ELEMS_PER_CHUNK; @@ -204,7 +210,7 @@ pub fn compute_compressed_bit_width_for_non_neg(arrays: &[ArrayRef]) -> u64 { // It outputs an fastlanes bitpacked EncodedArray macro_rules! encode_fixed_width { ($self:expr, $unpacked:expr, $data_type:ty, $buffer_index:expr) => {{ - let num_chunks = ($unpacked.num_values + ELEMS_PER_CHUNK - 1) / ELEMS_PER_CHUNK; + let num_chunks = $unpacked.num_values.div_ceil(ELEMS_PER_CHUNK); let num_full_chunks = $unpacked.num_values / ELEMS_PER_CHUNK; let uncompressed_bit_width = std::mem::size_of::<$data_type>() as u64 * 8; @@ -501,6 +507,7 @@ macro_rules! bitpacked_decode { while chunk_num * packed_chunk_size_in_byte < bytes.len() { // Copy for memory alignment + // TODO: This copy should not be needed let chunk_in_u8: Vec = bytes[chunk_num * packed_chunk_size_in_byte..] [..packed_chunk_size_in_byte] .to_vec(); @@ -597,985 +604,40 @@ fn bitpacked_for_non_neg_decode( } } -#[cfg(test)] -mod tests { - // use super::*; - // use arrow::array::{ - // Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, - // UInt8Array, - // }; - // use arrow::datatypes::DataType; - - // #[test_log::test(tokio::test)] - // async fn test_compute_compressed_bit_width_for_non_neg() {} - - // use std::collections::HashMap; - - // use lance_datagen::RowCount; - - // use crate::testing::{check_round_trip_encoding_of_data, TestCases}; - // use crate::version::LanceFileVersion; - - // async fn check_round_trip_bitpacked(array: Arc) { - // let test_cases = TestCases::default().with_file_version(LanceFileVersion::V2_1); - // check_round_trip_encoding_of_data(vec![array], &test_cases, HashMap::new()).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_u8() { - // let values: Vec = vec![5; 1024]; - // let array = UInt8Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = UInt8Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = UInt8Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = UInt8Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = UInt8Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt8)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_u16() { - // let values: Vec = vec![5; 1024]; - // let array = UInt16Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = UInt16Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = UInt16Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = UInt16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = UInt16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = UInt16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = UInt16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt16)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_u32() { - // let values: Vec = vec![5; 1024]; - // let array = UInt32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![7; 2000]; - // let array = UInt32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = UInt32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![666; 1000]; - // let array = UInt32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = UInt32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![1; 10000]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![3000; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![8000; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![65536; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![655360; 100]; - // let array = UInt32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt32)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_u64() { - // let values: Vec = vec![5; 1024]; - // let array = UInt64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![7; 2000]; - // let array = UInt64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = UInt64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![666; 1000]; - // let array = UInt64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = UInt64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![1; 10000]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![3000; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![8000; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![65536; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![655360; 100]; - // let array = UInt64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::UInt64)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_i8() { - // let values: Vec = vec![-5; 1024]; - // let array = Int8Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = Int8Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = Int8Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = Int8Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = Int8Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-88; 10000]; - // let array = Int8Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int8)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_i16() { - // let values: Vec = vec![-5; 1024]; - // let array = Int16Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = Int16Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = Int16Array::from(values); - // let array: Arc = Arc::new(array); - - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = Int16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = Int16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = Int16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = Int16Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int16)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_i32() { - // let values: Vec = vec![-5; 1024]; - // let array = Int32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = Int32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![-66; 1000]; - // let array = Int32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = Int32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![-77; 2000]; - // let array = Int32Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-88; 10000]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-300; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-800; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![65536; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-65536; 100]; - // let array = Int32Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int32)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } - - // #[test_log::test(tokio::test)] - // async fn test_bitpack_fastlanes_i64() { - // let values: Vec = vec![-5; 1024]; - // let array = Int64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![66; 1000]; - // let array = Int64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![-66; 1000]; - // let array = Int64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![77; 2000]; - // let array = Int64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![-77; 2000]; - // let array = Int64Array::from(values); - // let array: Arc = Arc::new(array); - // check_round_trip_bitpacked(array).await; - - // let values: Vec = vec![0; 10000]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![88; 10000]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-88; 10000]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![300; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-300; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![800; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-800; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![65536; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let values: Vec = vec![-65536; 100]; - // let array = Int64Array::from(values); - // let arr = Arc::new(array) as ArrayRef; - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(1)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(20)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(50)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(100)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(1000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(1024)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(2000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - - // let arr = lance_datagen::gen() - // .anon_col(lance_datagen::array::rand_type(&DataType::Int64)) - // .into_batch_rows(RowCount::from(3000)) - // .unwrap() - // .column(0) - // .clone(); - // check_round_trip_bitpacked(arr).await; - // } +#[derive(Debug, Default)] +pub struct InlineBitpacking { + uncompressed_bit_width: u64, } -// This macro chunks the FixedWidth DataBlock, bitpacks them with 1024 values per chunk, -// it puts the bit-width parameter in front of each chunk, -// and the bit-width parameter has the same bit-width as the uncompressed DataBlock -// for example, if the input DataBlock has `bits_per_value` of `16`, there will be 2 bytes(16 bits) -// in front of each chunk storing the `bit-width` parameter. -macro_rules! chunk_data_impl { - ($data:expr, $data_type:ty) => {{ - let data_buffer = $data.data.borrow_to_typed_slice::<$data_type>(); +impl InlineBitpacking { + pub fn new(uncompressed_bit_width: u64) -> Self { + Self { + uncompressed_bit_width, + } + } + + pub fn from_description(description: &pb::InlineBitpacking) -> Self { + Self { + uncompressed_bit_width: description.uncompressed_bits_per_value, + } + } + + pub fn min_size_bytes(bit_width: u64) -> u64 { + (ELEMS_PER_CHUNK * bit_width).div_ceil(8) + } + + /// Bitpacks a FixedWidthDataBlock into compressed chunks of 1024 values + /// + /// Each chunk can have a different bit width + /// + /// Each chunk has the compressed bit width stored inline in the chunk itself. + fn bitpack_chunked( + mut data: FixedWidthDataBlock, + ) -> MiniBlockCompressed { + let data_buffer = data.data.borrow_to_typed_slice::(); let data_buffer = data_buffer.as_ref(); - let bit_widths = $data - .get_stat(Stat::BitWidth) - .expect("FixedWidthDataBlock should have valid bit width statistics"); + let bit_widths = data.expect_stat(Stat::BitWidth); let bit_widths_array = bit_widths .as_any() .downcast_ref::>() @@ -1585,7 +647,7 @@ macro_rules! chunk_data_impl { .values() .iter() .map(|&bit_width| { - let chunk_size = ((1024 * bit_width) / $data.bits_per_value) as usize; + let chunk_size = ((1024 * bit_width) / data.bits_per_value) as usize; (chunk_size, chunk_size + 1) }) .fold( @@ -1596,86 +658,120 @@ macro_rules! chunk_data_impl { }, ); - let mut output: Vec<$data_type> = Vec::with_capacity(total_size); + let mut output: Vec = Vec::with_capacity(total_size); let mut chunks = Vec::with_capacity(bit_widths_array.len()); - for i in 0..bit_widths_array.len() - 1 { + for (i, packed_chunk_size) in packed_chunk_sizes + .iter() + .enumerate() + .take(bit_widths_array.len() - 1) + { let start_elem = i * ELEMS_PER_CHUNK as usize; - let bit_width = bit_widths_array.value(i) as $data_type; - output.push(bit_width); + let bit_width = bit_widths_array.value(i) as usize; + output.push(T::from_usize(bit_width).unwrap()); let output_len = output.len(); unsafe { - output.set_len(output_len + packed_chunk_sizes[i]); + output.set_len(output_len + *packed_chunk_size); BitPacking::unchecked_pack( - bit_width as usize, + bit_width, &data_buffer[start_elem..][..ELEMS_PER_CHUNK as usize], - &mut output[output_len..][..packed_chunk_sizes[i]], + &mut output[output_len..][..*packed_chunk_size], ); } chunks.push(MiniBlockChunk { - num_bytes: ((1 + packed_chunk_sizes[i]) * std::mem::size_of::<$data_type>()) as u16, + buffer_sizes: vec![((1 + *packed_chunk_size) * std::mem::size_of::()) as u16], log_num_values: LOG_ELEMS_PER_CHUNK, }); } // Handle the last chunk - let last_chunk_elem_num = if $data.num_values % ELEMS_PER_CHUNK == 0 { + let last_chunk_elem_num = if data.num_values % ELEMS_PER_CHUNK == 0 { 1024 } else { - $data.num_values % ELEMS_PER_CHUNK + data.num_values % ELEMS_PER_CHUNK }; - let mut last_chunk = vec![0; ELEMS_PER_CHUNK as usize]; + let mut last_chunk: Vec = vec![T::from_usize(0).unwrap(); ELEMS_PER_CHUNK as usize]; last_chunk[..last_chunk_elem_num as usize].clone_from_slice( - &data_buffer[$data.num_values as usize - last_chunk_elem_num as usize..], + &data_buffer[data.num_values as usize - last_chunk_elem_num as usize..], ); - let bit_width = bit_widths_array.value(bit_widths_array.len() - 1) as $data_type; - output.push(bit_width); + let bit_width = bit_widths_array.value(bit_widths_array.len() - 1) as usize; + output.push(T::from_usize(bit_width).unwrap()); let output_len = output.len(); unsafe { output.set_len(output_len + packed_chunk_sizes[bit_widths_array.len() - 1]); BitPacking::unchecked_pack( - bit_width as usize, + bit_width, &last_chunk, &mut output[output_len..][..packed_chunk_sizes[bit_widths_array.len() - 1]], ); } chunks.push(MiniBlockChunk { - num_bytes: ((1 + packed_chunk_sizes[bit_widths_array.len() - 1]) - * std::mem::size_of::<$data_type>()) as u16, + buffer_sizes: vec![ + ((1 + packed_chunk_sizes[bit_widths_array.len() - 1]) * std::mem::size_of::()) + as u16, + ], log_num_values: 0, }); - ( - MiniBlockCompressed { - data: LanceBuffer::reinterpret_vec(output), - chunks, - num_values: $data.num_values, - }, - ProtobufUtils::bitpack2($data.bits_per_value), - ) - }}; -} - -#[derive(Debug, Default)] -pub struct BitpackMiniBlockEncoder {} + MiniBlockCompressed { + data: vec![LanceBuffer::reinterpret_vec(output)], + chunks, + num_values: data.num_values, + } + } -impl BitpackMiniBlockEncoder { fn chunk_data( &self, - mut data: FixedWidthDataBlock, + data: FixedWidthDataBlock, ) -> (MiniBlockCompressed, crate::format::pb::ArrayEncoding) { assert!(data.bits_per_value % 8 == 0); - match data.bits_per_value { - 8 => chunk_data_impl!(data, u8), - 16 => chunk_data_impl!(data, u16), - 32 => chunk_data_impl!(data, u32), - 64 => chunk_data_impl!(data, u64), + assert_eq!(data.bits_per_value, self.uncompressed_bit_width); + let bits_per_value = data.bits_per_value; + let compressed = match bits_per_value { + 8 => Self::bitpack_chunked::(data), + 16 => Self::bitpack_chunked::(data), + 32 => Self::bitpack_chunked::(data), + 64 => Self::bitpack_chunked::(data), _ => unreachable!(), + }; + (compressed, ProtobufUtils::inline_bitpacking(bits_per_value)) + } + + fn unchunk( + data: LanceBuffer, + num_values: u64, + ) -> Result { + assert!(data.len() >= 8); + assert!(num_values <= ELEMS_PER_CHUNK); + + // This macro decompresses a chunk(1024 values) of bitpacked values. + let uncompressed_bit_width = std::mem::size_of::() * 8; + let mut decompressed = vec![T::from_usize(0).unwrap(); ELEMS_PER_CHUNK as usize]; + + // Copy for memory alignment + let chunk_in_u8: Vec = data.to_vec(); + let bit_width_bytes = &chunk_in_u8[..std::mem::size_of::()]; + let bit_width_value = LittleEndian::read_uint(bit_width_bytes, std::mem::size_of::()); + let chunk = cast_slice(&chunk_in_u8[std::mem::size_of::()..]); + + // The bit-packed chunk should have number of bytes (bit_width_value * ELEMS_PER_CHUNK / 8) + assert!(std::mem::size_of_val(chunk) == (bit_width_value * ELEMS_PER_CHUNK) as usize / 8); + + unsafe { + BitPacking::unchecked_unpack(bit_width_value as usize, chunk, &mut decompressed); } + + decompressed.truncate(num_values as usize); + Ok(DataBlock::FixedWidth(FixedWidthDataBlock { + data: LanceBuffer::reinterpret_vec(decompressed), + bits_per_value: uncompressed_bit_width as u64, + num_values, + block_info: BlockInfo::new(), + })) } } -impl MiniBlockCompressor for BitpackMiniBlockEncoder { +impl MiniBlockCompressor for InlineBitpacking { fn compress( &self, chunk: DataBlock, @@ -1694,67 +790,222 @@ impl MiniBlockCompressor for BitpackMiniBlockEncoder { } } -/// A decompressor for fixed-width data that has -/// been written, as-is, to disk in single contiguous array -#[derive(Debug)] -pub struct BitpackMiniBlockDecompressor { - uncompressed_bit_width: u64, +impl BlockCompressor for InlineBitpacking { + fn compress(&self, data: DataBlock) -> Result { + let fixed_width = data.as_fixed_width().unwrap(); + let (chunked, _) = self.chunk_data(fixed_width); + Ok(chunked.data.into_iter().next().unwrap()) + } } -impl BitpackMiniBlockDecompressor { - pub fn new(description: &pb::Bitpack2) -> Self { - Self { - uncompressed_bit_width: description.uncompressed_bits_per_value, +impl MiniBlockDecompressor for InlineBitpacking { + fn decompress(&self, data: Vec, num_values: u64) -> Result { + assert_eq!(data.len(), 1); + let data = data.into_iter().next().unwrap(); + match self.uncompressed_bit_width { + 8 => Self::unchunk::(data, num_values), + 16 => Self::unchunk::(data, num_values), + 32 => Self::unchunk::(data, num_values), + 64 => Self::unchunk::(data, num_values), + _ => unimplemented!("Bitpacking word size must be 8, 16, 32, or 64"), } } } -impl MiniBlockDecompressor for BitpackMiniBlockDecompressor { +impl BlockDecompressor for InlineBitpacking { fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { - assert!(data.len() >= 8); - assert!(num_values <= ELEMS_PER_CHUNK); + match self.uncompressed_bit_width { + 8 => Self::unchunk::(data, num_values), + 16 => Self::unchunk::(data, num_values), + 32 => Self::unchunk::(data, num_values), + 64 => Self::unchunk::(data, num_values), + _ => unimplemented!("Bitpacking word size must be 8, 16, 32, or 64"), + } + } +} - // This macro decompresses a chunk(1024 values) of bitpacked values. - macro_rules! decompress_impl { - ($type:ty) => {{ - let uncompressed_bit_width = std::mem::size_of::<$type>() * 8; - let mut decompressed = vec![0 as $type; ELEMS_PER_CHUNK as usize]; +/// Bitpacks a FixedWidthDataBlock with a given bit width +/// +/// This function is simpler as it does not do any chunking, but slightly less efficient. +/// The compressed bits per value is constant across the entire buffer. +/// +/// Note: even though we are not strictly "chunking" we are still operating on chunks of +/// 1024 values because that's what the bitpacking primitives expect. They just don't +/// have a unique bit width for each chunk. +fn bitpack_out_of_line( + mut data: FixedWidthDataBlock, + bit_width: usize, +) -> LanceBuffer { + let data_buffer = data.data.borrow_to_typed_slice::(); + let data_buffer = data_buffer.as_ref(); + + let num_chunks = data_buffer.len().div_ceil(ELEMS_PER_CHUNK as usize); + let last_chunk_is_runt = data_buffer.len() % ELEMS_PER_CHUNK as usize != 0; + let words_per_chunk = + (ELEMS_PER_CHUNK as usize * bit_width).div_ceil(data.bits_per_value as usize); + #[allow(clippy::uninit_vec)] + let mut output: Vec = Vec::with_capacity(num_chunks * words_per_chunk); + #[allow(clippy::uninit_vec)] + unsafe { + output.set_len(num_chunks * words_per_chunk); + } - // Copy for memory alignment - let chunk_in_u8: Vec = data.to_vec(); - let bit_width_bytes = &chunk_in_u8[..std::mem::size_of::<$type>()]; - let bit_width_value = LittleEndian::read_uint(bit_width_bytes, std::mem::size_of::<$type>()); - let chunk = cast_slice(&chunk_in_u8[std::mem::size_of::<$type>()..]); + let num_whole_chunks = if last_chunk_is_runt { + num_chunks - 1 + } else { + num_chunks + }; - // The bit-packed chunk should have number of bytes (bit_width_value * ELEMS_PER_CHUNK / 8) - assert!(chunk.len() * std::mem::size_of::<$type>() == (bit_width_value * ELEMS_PER_CHUNK as u64) as usize / 8); + // Simple case for complete chunks + for i in 0..num_whole_chunks { + let input_start = i * ELEMS_PER_CHUNK as usize; + let input_end = input_start + ELEMS_PER_CHUNK as usize; + let output_start = i * words_per_chunk; + let output_end = output_start + words_per_chunk; + unsafe { + BitPacking::unchecked_pack( + bit_width, + &data_buffer[input_start..input_end], + &mut output[output_start..output_end], + ); + } + } - unsafe { - BitPacking::unchecked_unpack( - bit_width_value as usize, - chunk, - &mut decompressed, - ); - } + if !last_chunk_is_runt { + return LanceBuffer::reinterpret_vec(output); + } - decompressed.shrink_to(num_values as usize); - Ok(DataBlock::FixedWidth(FixedWidthDataBlock { - data: LanceBuffer::reinterpret_vec(decompressed), - bits_per_value: uncompressed_bit_width as u64, - num_values, - block_info: BlockInfo::new(), - })) - }}; - } + // If we get here then the last chunk needs to be padded with zeros + let remaining_items = data_buffer.len() % ELEMS_PER_CHUNK as usize; + let last_chunk_start = num_whole_chunks * ELEMS_PER_CHUNK as usize; - match self.uncompressed_bit_width { - 8 => decompress_impl!(u8), - 16 => decompress_impl!(u16), - 32 => decompress_impl!(u32), - 64 => decompress_impl!(u64), - _ => todo!(), + let mut last_chunk: Vec = vec![T::from_usize(0).unwrap(); ELEMS_PER_CHUNK as usize]; + last_chunk[..remaining_items].clone_from_slice(&data_buffer[last_chunk_start..]); + let output_start = num_whole_chunks * words_per_chunk; + unsafe { + BitPacking::unchecked_pack(bit_width, &last_chunk, &mut output[output_start..]); + } + + LanceBuffer::reinterpret_vec(output) +} + +/// Unpacks a FixedWidthDataBlock that has been bitpacked with a constant bit width +fn unpack_out_of_line( + mut data: FixedWidthDataBlock, + num_values: usize, + bits_per_value: usize, +) -> FixedWidthDataBlock { + let words_per_chunk = + (ELEMS_PER_CHUNK as usize * bits_per_value).div_ceil(data.bits_per_value as usize); + let compressed_words = data.data.borrow_to_typed_slice::(); + + let num_chunks = data.num_values as usize / words_per_chunk; + debug_assert_eq!(data.num_values as usize % words_per_chunk, 0); + + // This is slightly larger than num_values because the last chunk has some padding, we will truncate at the end + #[allow(clippy::uninit_vec)] + let mut decompressed = Vec::with_capacity(num_chunks * ELEMS_PER_CHUNK as usize); + #[allow(clippy::uninit_vec)] + unsafe { + decompressed.set_len(num_chunks * ELEMS_PER_CHUNK as usize); + } + + for chunk_idx in 0..num_chunks { + let input_start = chunk_idx * words_per_chunk; + let input_end = input_start + words_per_chunk; + let output_start = chunk_idx * ELEMS_PER_CHUNK as usize; + let output_end = output_start + ELEMS_PER_CHUNK as usize; + unsafe { + BitPacking::unchecked_unpack( + bits_per_value, + &compressed_words[input_start..input_end], + &mut decompressed[output_start..output_end], + ); } } + + decompressed.truncate(num_values); + + FixedWidthDataBlock { + data: LanceBuffer::reinterpret_vec(decompressed), + bits_per_value: data.bits_per_value, + num_values: num_values as u64, + block_info: BlockInfo::new(), + } +} + +/// A transparent compressor that bit packs data +/// +/// In order for the encoding to be transparent we must have a fixed bit width +/// across the entire array. Chunking within the buffer is not supported. This +/// means that we will be slightly less efficient than something like the mini-block +/// approach. +/// +/// WARNING: DO NOT USE YET. +/// +/// This was an interesting experiment but it can't be used as a per-value compressor +/// at the moment. The resulting data IS transparent but it's not quite so simple. We +/// compress in blocks of 1024 and each block has a fixed size but also has some padding. +/// +/// In other words, if we try the simple math to access the item at index `i` we will be +/// out of luck because `bits_per_value * i` is not the location. What we need is something +/// like: +/// +/// ```ignore +/// let chunk_idx = i / 1024; +/// let chunk_offset = i % 1024; +/// bits_per_chunk * chunk_idx + bits_per_value * chunk_offset +/// ``` +/// +/// However, this logic isn't expressible with the per-value traits we have today. We can +/// enhance these traits should we need to support it at some point in the future. +#[derive(Debug)] +pub struct OutOfLineBitpacking { + compressed_bit_width: usize, +} + +impl PerValueCompressor for OutOfLineBitpacking { + fn compress( + &self, + data: DataBlock, + ) -> Result<(crate::encoder::PerValueDataBlock, pb::ArrayEncoding)> { + let fixed_width = data.as_fixed_width().unwrap(); + let num_values = fixed_width.num_values; + let word_size = fixed_width.bits_per_value; + let compressed = match word_size { + 8 => bitpack_out_of_line::(fixed_width, self.compressed_bit_width), + 16 => bitpack_out_of_line::(fixed_width, self.compressed_bit_width), + 32 => bitpack_out_of_line::(fixed_width, self.compressed_bit_width), + 64 => bitpack_out_of_line::(fixed_width, self.compressed_bit_width), + _ => panic!("Bitpacking word size must be 8,16,32,64"), + }; + let compressed = FixedWidthDataBlock { + data: compressed, + bits_per_value: self.compressed_bit_width as u64, + num_values, + block_info: BlockInfo::new(), + }; + let encoding = + ProtobufUtils::out_of_line_bitpacking(word_size, self.compressed_bit_width as u64); + Ok((PerValueDataBlock::Fixed(compressed), encoding)) + } +} + +impl FixedPerValueDecompressor for OutOfLineBitpacking { + fn decompress(&self, data: FixedWidthDataBlock, num_values: u64) -> Result { + let unpacked = match data.bits_per_value { + 8 => unpack_out_of_line::(data, num_values as usize, self.compressed_bit_width), + 16 => unpack_out_of_line::(data, num_values as usize, self.compressed_bit_width), + 32 => unpack_out_of_line::(data, num_values as usize, self.compressed_bit_width), + 64 => unpack_out_of_line::(data, num_values as usize, self.compressed_bit_width), + _ => panic!("Bitpacking word size must be 8,16,32,64"), + }; + Ok(DataBlock::FixedWidth(unpacked)) + } + + fn bits_per_value(&self) -> u64 { + self.compressed_bit_width as u64 + } } #[cfg(test)] diff --git a/rust/lance-encoding/src/encodings/physical/block_compress.rs b/rust/lance-encoding/src/encodings/physical/block_compress.rs index c3ddd3326af..508bdef6e1f 100644 --- a/rust/lance-encoding/src/encodings/physical/block_compress.rs +++ b/rust/lance-encoding/src/encodings/physical/block_compress.rs @@ -1,8 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use arrow_buffer::ArrowNativeType; use arrow_schema::DataType; -use snafu::{location, Location}; +use snafu::location; use std::{ io::{Cursor, Write}, str::FromStr, @@ -11,9 +12,11 @@ use std::{ use lance_core::{Error, Result}; use crate::{ - data::{BlockInfo, DataBlock, OpaqueBlock}, - encoder::{ArrayEncoder, EncodedArray}, - format::ProtobufUtils, + buffer::LanceBuffer, + data::{BlockInfo, DataBlock, OpaqueBlock, VariableWidthBlock}, + decoder::VariablePerValueDecompressor, + encoder::{ArrayEncoder, EncodedArray, PerValueCompressor, PerValueDataBlock}, + format::{pb, ProtobufUtils}, }; #[derive(Debug, Clone, Copy, PartialEq)] @@ -31,7 +34,7 @@ impl CompressionConfig { impl Default for CompressionConfig { fn default() -> Self { Self { - scheme: CompressionScheme::Zstd, + scheme: CompressionScheme::Lz4, level: Some(0), } } @@ -40,14 +43,18 @@ impl Default for CompressionConfig { #[derive(Debug, Clone, Copy, PartialEq)] pub enum CompressionScheme { None, + Fsst, Zstd, + Lz4, } impl std::fmt::Display for CompressionScheme { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let scheme_str = match self { + Self::Fsst => "fsst", Self::Zstd => "zstd", Self::None => "none", + Self::Lz4 => "lz4", }; write!(f, "{}", scheme_str) } @@ -71,6 +78,7 @@ impl FromStr for CompressionScheme { pub trait BufferCompressor: std::fmt::Debug + Send + Sync { fn compress(&self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()>; fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()>; + fn name(&self) -> &str; } #[derive(Debug, Default)] @@ -99,6 +107,37 @@ impl BufferCompressor for ZstdBufferCompressor { zstd::stream::copy_decode(source, output_buf)?; Ok(()) } + + fn name(&self) -> &str { + "zstd" + } +} + +#[derive(Debug, Default)] +pub struct Lz4BufferCompressor {} + +impl BufferCompressor for Lz4BufferCompressor { + fn compress(&self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + lz4::block::compress_to_buffer(input_buf, None, true, output_buf) + .map_err(|err| Error::Internal { + message: format!("LZ4 compression error: {}", err), + location: location!(), + }) + .map(|_| ()) + } + + fn decompress(&self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + lz4::block::decompress_to_buffer(input_buf, None, output_buf) + .map_err(|err| Error::Internal { + message: format!("LZ4 decompression error: {}", err), + location: location!(), + }) + .map(|_| ()) + } + + fn name(&self) -> &str { + "zstd" + } } #[derive(Debug, Default)] @@ -114,6 +153,10 @@ impl BufferCompressor for NoopBufferCompressor { output_buf.extend_from_slice(input_buf); Ok(()) } + + fn name(&self) -> &str { + "none" + } } pub struct GeneralBufferCompressor {} @@ -121,9 +164,12 @@ pub struct GeneralBufferCompressor {} impl GeneralBufferCompressor { pub fn get_compressor(compression_config: CompressionConfig) -> Box { match compression_config.scheme { + // FSST has its own compression path and isn't implemented as a generic buffer compressor + CompressionScheme::Fsst => unimplemented!(), CompressionScheme::Zstd => Box::new(ZstdBufferCompressor::new( compression_config.level.unwrap_or(0), )), + CompressionScheme::Lz4 => Box::new(Lz4BufferCompressor::default()), CompressionScheme::None => Box::new(NoopBufferCompressor {}), } } @@ -151,6 +197,16 @@ impl CompressedBufferEncoder { let compressor = GeneralBufferCompressor::get_compressor(compression_config); Self { compressor } } + + pub fn from_scheme(scheme: &str) -> Result { + let scheme = CompressionScheme::from_str(scheme)?; + Ok(Self { + compressor: GeneralBufferCompressor::get_compressor(CompressionConfig { + scheme, + level: Some(0), + }), + }) + } } impl ArrayEncoder for CompressedBufferEncoder { @@ -188,6 +244,117 @@ impl ArrayEncoder for CompressedBufferEncoder { } } +impl CompressedBufferEncoder { + pub fn per_value_compress( + &self, + data: &[u8], + offsets: &[T], + compressed: &mut Vec, + ) -> Result { + let mut new_offsets: Vec = Vec::with_capacity(offsets.len()); + new_offsets.push(T::from_usize(0).unwrap()); + + for off in offsets.windows(2) { + let start = off[0].as_usize(); + let end = off[1].as_usize(); + self.compressor.compress(&data[start..end], compressed)?; + new_offsets.push(T::from_usize(compressed.len()).unwrap()); + } + + Ok(LanceBuffer::reinterpret_vec(new_offsets)) + } + + pub fn per_value_decompress( + &self, + data: &[u8], + offsets: &[T], + decompressed: &mut Vec, + ) -> Result { + let mut new_offsets: Vec = Vec::with_capacity(offsets.len()); + new_offsets.push(T::from_usize(0).unwrap()); + + for off in offsets.windows(2) { + let start = off[0].as_usize(); + let end = off[1].as_usize(); + self.compressor + .decompress(&data[start..end], decompressed)?; + new_offsets.push(T::from_usize(decompressed.len()).unwrap()); + } + + Ok(LanceBuffer::reinterpret_vec(new_offsets)) + } +} + +impl PerValueCompressor for CompressedBufferEncoder { + fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> { + let data_type = data.name(); + let mut data = data.as_variable_width().ok_or(Error::Internal { + message: format!( + "Attempt to use CompressedBufferEncoder on data of type {}", + data_type + ), + location: location!(), + })?; + + let data_bytes = &data.data; + let mut compressed = Vec::with_capacity(data_bytes.len()); + + let new_offsets = match data.bits_per_offset { + 32 => self.per_value_compress::( + data_bytes, + &data.offsets.borrow_to_typed_slice::(), + &mut compressed, + )?, + 64 => self.per_value_compress::( + data_bytes, + &data.offsets.borrow_to_typed_slice::(), + &mut compressed, + )?, + _ => unreachable!(), + }; + + let compressed = PerValueDataBlock::Variable(VariableWidthBlock { + bits_per_offset: data.bits_per_offset, + data: LanceBuffer::from(compressed), + offsets: new_offsets, + num_values: data.num_values, + block_info: BlockInfo::new(), + }); + + let encoding = ProtobufUtils::block(self.compressor.name()); + + Ok((compressed, encoding)) + } +} + +impl VariablePerValueDecompressor for CompressedBufferEncoder { + fn decompress(&self, mut data: VariableWidthBlock) -> Result { + let data_bytes = &data.data; + let mut decompressed = Vec::with_capacity(data_bytes.len() * 2); + + let new_offsets = match data.bits_per_offset { + 32 => self.per_value_decompress( + data_bytes, + &data.offsets.borrow_to_typed_slice::(), + &mut decompressed, + )?, + 64 => self.per_value_decompress( + data_bytes, + &data.offsets.borrow_to_typed_slice::(), + &mut decompressed, + )?, + _ => unreachable!(), + }; + Ok(DataBlock::VariableWidth(VariableWidthBlock { + bits_per_offset: data.bits_per_offset, + data: LanceBuffer::from(decompressed), + offsets: new_offsets, + num_values: data.num_values, + block_info: BlockInfo::new(), + })) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance-encoding/src/encodings/physical/dictionary.rs b/rust/lance-encoding/src/encodings/physical/dictionary.rs index 89cd6a046e3..2ed9a09a70e 100644 --- a/rust/lance-encoding/src/encodings/physical/dictionary.rs +++ b/rust/lance-encoding/src/encodings/physical/dictionary.rs @@ -14,7 +14,7 @@ use arrow_schema::DataType; use futures::{future::BoxFuture, FutureExt}; use lance_arrow::DataTypeExt; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use std::collections::HashMap; use crate::buffer::LanceBuffer; diff --git a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs index 6f8ba702fd7..eb06f5188f9 100644 --- a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs +++ b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs @@ -9,10 +9,10 @@ use lance_core::Result; use log::trace; use crate::{ - data::{BlockInfo, DataBlock, FixedSizeListBlock, FixedWidthDataBlock}, - decoder::{PageScheduler, PerValueDecompressor, PrimitivePageDecoder}, - encoder::{ArrayEncoder, EncodedArray, PerValueCompressor, PerValueDataBlock}, - format::{pb, ProtobufUtils}, + data::{DataBlock, FixedSizeListBlock}, + decoder::{PageScheduler, PrimitivePageDecoder}, + encoder::{ArrayEncoder, EncodedArray}, + format::ProtobufUtils, EncodingsIo, }; @@ -123,121 +123,150 @@ impl ArrayEncoder for FslEncoder { dimension: self.dimension as u64, }); - let encoding = ProtobufUtils::fixed_size_list(encoded_data.encoding, self.dimension as u64); + let encoding = + ProtobufUtils::fsl_encoding(self.dimension as u64, encoded_data.encoding, false); Ok(EncodedArray { data, encoding }) } } -/// A compressor for primitive FSLs that flattens each list into a -/// single value. If the inner list has validity then the validity -/// is zipped in with the values. -/// -/// In other words, if the list is FSL [[0, NULL], [4, 10]] then the -/// two buffers start as: -/// -/// values: 0x00 0x?? 0x04 0x0A -/// validity: 0b1011 -/// -/// The output will be: -/// -/// zipped: 0x01 0x00 0x00 0x?? 0x01 0x04 0x01 0x0A -/// -/// Note that we expand validity to be at least a byte per value so this -/// approach is not ideal for small lists, though we should be using mini-block -/// for small lists anyways. -#[derive(Debug)] -pub struct FslPerValueCompressor { - items_compressor: Box, - dimension: u64, -} - -impl FslPerValueCompressor { - pub fn new(items_compressor: Box, dimension: u64) -> Self { - Self { - items_compressor, - dimension, - } - } -} - -impl PerValueCompressor for FslPerValueCompressor { - fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> { - let mut data = data.as_fixed_size_list().unwrap(); - let flattened = match data.child.as_mut() { - DataBlock::FixedWidth(fixed_width) => DataBlock::FixedWidth(FixedWidthDataBlock { - bits_per_value: fixed_width.bits_per_value * self.dimension, - data: fixed_width.data.borrow_and_clone(), - block_info: BlockInfo::new(), - num_values: fixed_width.num_values / self.dimension, - }), - DataBlock::VariableWidth(_) => todo!("GH-3111: FSL with variable inner type"), - DataBlock::Nullable(_) => todo!("GH-3112: FSL with nullable inner type"), - DataBlock::FixedSizeList(_) => todo!("GH-3113: Nested FSLs"), - _ => unreachable!(), - }; - let (compressed, encoding) = self.items_compressor.compress(flattened)?; - let wrapped_encoding = ProtobufUtils::fixed_size_list(encoding, self.dimension); - - Ok((compressed, wrapped_encoding)) - } -} - -/// Reversed the process described in [`FslPerValueCompressor`] -#[derive(Debug)] -pub struct FslPerValueDecompressor { - items_decompressor: Box, - dimension: u64, -} - -impl FslPerValueDecompressor { - pub fn new(items_decompressor: Box, dimension: u64) -> Self { - Self { - items_decompressor, - dimension, - } - } -} - -impl PerValueDecompressor for FslPerValueDecompressor { - fn decompress(&self, data: crate::buffer::LanceBuffer, num_values: u64) -> Result { - let decompressed = self.items_decompressor.decompress(data, num_values)?; - let unflattened = match decompressed { - DataBlock::FixedWidth(fixed_width) => DataBlock::FixedWidth(FixedWidthDataBlock { - bits_per_value: fixed_width.bits_per_value / self.dimension, - data: fixed_width.data, - block_info: BlockInfo::new(), - num_values: fixed_width.num_values * self.dimension, - }), - _ => todo!(), - }; - Ok(DataBlock::FixedSizeList(FixedSizeListBlock { - child: Box::new(unflattened), - dimension: self.dimension, - })) - } - - fn bits_per_value(&self) -> u64 { - self.items_decompressor.bits_per_value() - } -} - #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; + use arrow::datatypes::Int32Type; + use arrow_array::{FixedSizeListArray, Int32Array}; + use arrow_buffer::{BooleanBuffer, NullBuffer}; use arrow_schema::{DataType, Field}; + use lance_datagen::{array, gen_array, ArrayGeneratorExt, RowCount}; + use rstest::rstest; - use crate::{testing::check_round_trip_encoding_random, version::LanceFileVersion}; + use crate::{ + testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases}, + version::LanceFileVersion, + }; const PRIMITIVE_TYPES: &[DataType] = &[DataType::Int8, DataType::Float32, DataType::Float64]; + #[rstest] #[test_log::test(tokio::test)] - async fn test_value_fsl_primitive() { + async fn test_value_fsl_primitive( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { for data_type in PRIMITIVE_TYPES { let inner_field = Field::new("item", data_type.clone(), true); let data_type = DataType::FixedSizeList(Arc::new(inner_field), 16); let field = Field::new("", data_type, false); - check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + check_round_trip_encoding_random(field, version).await; } } + + #[test_log::test(tokio::test)] + async fn test_simple_fsl() { + // [0, NULL], NULL, [4, 5] + let items = Arc::new(Int32Array::from(vec![ + Some(0), + None, + Some(2), + Some(3), + Some(4), + Some(5), + ])); + let items_field = Arc::new(Field::new("item", DataType::Int32, true)); + let list_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let list = Arc::new(FixedSizeListArray::new( + items_field, + 2, + items, + Some(list_nulls), + )); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_range(0..2) + .with_range(1..3) + .with_indices(vec![0, 1, 2]) + .with_indices(vec![1]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + + check_round_trip_encoding_of_data(vec![list], &test_cases, HashMap::default()).await; + } + + #[test_log::test(tokio::test)] + #[ignore] + async fn test_simple_wide_fsl() { + let items = gen_array(array::rand::().with_random_nulls(0.1)) + .into_array_rows(RowCount::from(4096)) + .unwrap(); + let items_field = Arc::new(Field::new("item", DataType::Int32, true)); + let list_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false])); + let list = Arc::new(FixedSizeListArray::new( + items_field, + 1024, + items, + Some(list_nulls), + )); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_range(0..2) + .with_range(1..3) + .with_indices(vec![0, 1, 2]) + .with_indices(vec![1]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + + check_round_trip_encoding_of_data(vec![list], &test_cases, HashMap::default()).await; + } + + #[test_log::test(tokio::test)] + async fn test_nested_fsl() { + // [[0, 1], NULL], NULL, [[8, 9], [NULL, 11]] + let items = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + None, + None, + None, + None, + None, + None, + Some(8), + Some(9), + None, + Some(11), + ])); + let items_field = Arc::new(Field::new("item", DataType::Int32, true)); + let inner_list_nulls = NullBuffer::new(BooleanBuffer::from(vec![ + true, false, false, false, true, true, + ])); + let inner_list = Arc::new(FixedSizeListArray::new( + items_field.clone(), + 2, + items, + Some(inner_list_nulls), + )); + let inner_list_field = Arc::new(Field::new( + "item", + DataType::FixedSizeList(items_field, 2), + true, + )); + let outer_list_nulls = NullBuffer::new(BooleanBuffer::from(vec![true, false, true])); + let outer_list = Arc::new(FixedSizeListArray::new( + inner_list_field, + 2, + inner_list, + Some(outer_list_nulls), + )); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_range(0..2) + .with_range(1..3) + .with_indices(vec![0, 1, 2]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + + check_round_trip_encoding_of_data(vec![outer_list], &test_cases, HashMap::default()).await; + } } diff --git a/rust/lance-encoding/src/encodings/physical/fsst.rs b/rust/lance-encoding/src/encodings/physical/fsst.rs index b247b8c290a..04827bedd94 100644 --- a/rust/lance-encoding/src/encodings/physical/fsst.rs +++ b/rust/lance-encoding/src/encodings/physical/fsst.rs @@ -8,16 +8,22 @@ use arrow_schema::DataType; use futures::{future::BoxFuture, FutureExt}; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use crate::{ buffer::LanceBuffer, data::{BlockInfo, DataBlock, NullableDataBlock, VariableWidthBlock}, - decoder::{MiniBlockDecompressor, PageScheduler, PrimitivePageDecoder}, - encoder::{ArrayEncoder, EncodedArray}, - encoder::{MiniBlockCompressed, MiniBlockCompressor}, - format::pb::{self}, - format::ProtobufUtils, + decoder::{ + MiniBlockDecompressor, PageScheduler, PrimitivePageDecoder, VariablePerValueDecompressor, + }, + encoder::{ + ArrayEncoder, EncodedArray, MiniBlockCompressed, MiniBlockCompressor, PerValueCompressor, + PerValueDataBlock, + }, + format::{ + pb::{self}, + ProtobufUtils, + }, EncodingsIo, }; @@ -26,11 +32,11 @@ use super::binary::{BinaryMiniBlockDecompressor, BinaryMiniBlockEncoder}; #[derive(Debug)] pub struct FsstPageScheduler { inner_scheduler: Box, - symbol_table: Vec, + symbol_table: LanceBuffer, } impl FsstPageScheduler { - pub fn new(inner_scheduler: Box, symbol_table: Vec) -> Self { + pub fn new(inner_scheduler: Box, symbol_table: LanceBuffer) -> Self { Self { inner_scheduler, symbol_table, @@ -48,7 +54,7 @@ impl PageScheduler for FsstPageScheduler { let inner_decoder = self .inner_scheduler .schedule_ranges(ranges, scheduler, top_level_row); - let symbol_table = self.symbol_table.clone(); + let symbol_table = self.symbol_table.try_clone().unwrap(); async move { let inner_decoder = inner_decoder.await?; @@ -63,7 +69,7 @@ impl PageScheduler for FsstPageScheduler { struct FsstPageDecoder { inner_decoder: Box, - symbol_table: Vec, + symbol_table: LanceBuffer, } impl PrimitivePageDecoder for FsstPageDecoder { @@ -202,14 +208,13 @@ impl ArrayEncoder for FsstArrayEncoder { } } -#[derive(Debug, Default)] -pub struct FsstMiniBlockEncoder {} +struct FsstCompressed { + data: VariableWidthBlock, + symbol_table: Vec, +} -impl MiniBlockCompressor for FsstMiniBlockEncoder { - fn compress( - &self, - data: DataBlock, - ) -> Result<(MiniBlockCompressed, crate::format::pb::ArrayEncoding)> { +impl FsstCompressed { + fn fsst_compress(data: DataBlock) -> Result { match data { DataBlock::VariableWidth(mut variable_width) => { let offsets = variable_width.offsets.borrow_to_typed_slice::(); @@ -231,29 +236,22 @@ impl MiniBlockCompressor for FsstMiniBlockEncoder { )?; // construct `DataBlock` for BinaryMiniBlockEncoder, we may want some `DataBlock` construct methods later - let data_block = DataBlock::VariableWidth(VariableWidthBlock { + let compressed = VariableWidthBlock { data: LanceBuffer::reinterpret_vec(dest_values), bits_per_offset: 32, offsets: LanceBuffer::reinterpret_vec(dest_offsets), num_values: variable_width.num_values, block_info: BlockInfo::new(), - }); - - // compress the fsst compressed data using `BinaryMiniBlockEncoder` - let binary_compressor = - Box::new(BinaryMiniBlockEncoder::default()) as Box; - - let (binary_miniblock_compressed, binary_array_encoding) = - binary_compressor.compress(data_block)?; + }; - Ok(( - binary_miniblock_compressed, - ProtobufUtils::fsst_mini_block(binary_array_encoding, symbol_table), - )) + Ok(Self { + data: compressed, + symbol_table, + }) } _ => Err(Error::InvalidInput { source: format!( - "Cannot compress a data block of type {} with BinaryMiniBlockEncoder", + "Cannot compress a data block of type {} with FsstEncoder", data.name() ) .into(), @@ -263,21 +261,131 @@ impl MiniBlockCompressor for FsstMiniBlockEncoder { } } +#[derive(Debug, Default)] +pub struct FsstMiniBlockEncoder {} + +impl MiniBlockCompressor for FsstMiniBlockEncoder { + fn compress( + &self, + data: DataBlock, + ) -> Result<(MiniBlockCompressed, crate::format::pb::ArrayEncoding)> { + let compressed = FsstCompressed::fsst_compress(data)?; + + let data_block = DataBlock::VariableWidth(compressed.data); + + // compress the fsst compressed data using `BinaryMiniBlockEncoder` + let binary_compressor = + Box::new(BinaryMiniBlockEncoder::default()) as Box; + + let (binary_miniblock_compressed, binary_array_encoding) = + binary_compressor.compress(data_block)?; + + Ok(( + binary_miniblock_compressed, + ProtobufUtils::fsst(binary_array_encoding, compressed.symbol_table), + )) + } +} + +#[derive(Debug)] +pub struct FsstPerValueEncoder { + inner: Box, +} + +impl FsstPerValueEncoder { + pub fn new(inner: Box) -> Self { + Self { inner } + } +} + +impl PerValueCompressor for FsstPerValueEncoder { + fn compress(&self, data: DataBlock) -> Result<(PerValueDataBlock, pb::ArrayEncoding)> { + let compressed = FsstCompressed::fsst_compress(data)?; + + let data_block = DataBlock::VariableWidth(compressed.data); + + let (binary_compressed, binary_array_encoding) = self.inner.compress(data_block)?; + + Ok(( + binary_compressed, + ProtobufUtils::fsst(binary_array_encoding, compressed.symbol_table), + )) + } +} + +#[derive(Debug)] +pub struct FsstPerValueDecompressor { + symbol_table: LanceBuffer, + inner_decompressor: Box, +} + +impl FsstPerValueDecompressor { + pub fn new( + symbol_table: LanceBuffer, + inner_decompressor: Box, + ) -> Self { + Self { + symbol_table, + inner_decompressor, + } + } +} + +impl VariablePerValueDecompressor for FsstPerValueDecompressor { + fn decompress(&self, data: VariableWidthBlock) -> Result { + // Step 1. Run inner decompressor + let mut compressed_variable_data = self + .inner_decompressor + .decompress(data)? + .as_variable_width() + .unwrap(); + + // Step 2. FSST decompress + let bytes = compressed_variable_data.data.borrow_to_typed_slice::(); + let bytes = bytes.as_ref(); + let offsets = compressed_variable_data + .offsets + .borrow_to_typed_slice::(); + let offsets = offsets.as_ref(); + let num_values = compressed_variable_data.num_values; + + // The data will expand at most 8 times + // The offsets will be the same size because we have the same # of strings + let mut decompress_bytes_buf = vec![0u8; bytes.len() * 8]; + let mut decompress_offset_buf = vec![0i32; offsets.len()]; + fsst::fsst::decompress( + &self.symbol_table, + bytes, + offsets, + &mut decompress_bytes_buf, + &mut decompress_offset_buf, + )?; + + Ok(DataBlock::VariableWidth(VariableWidthBlock { + data: LanceBuffer::Owned(decompress_bytes_buf), + offsets: LanceBuffer::reinterpret_vec(decompress_offset_buf), + bits_per_offset: 32, + num_values, + block_info: BlockInfo::new(), + })) + } +} + #[derive(Debug)] pub struct FsstMiniBlockDecompressor { - symbol_table: Vec, + symbol_table: LanceBuffer, } impl FsstMiniBlockDecompressor { - pub fn new(description: &pb::FsstMiniBlock) -> Self { + pub fn new(description: &pb::Fsst) -> Self { Self { - symbol_table: description.symbol_table.clone(), + symbol_table: LanceBuffer::from_bytes(description.symbol_table.clone(), 1), } } } impl MiniBlockDecompressor for FsstMiniBlockDecompressor { - fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { + fn decompress(&self, data: Vec, num_values: u64) -> Result { // Step 1. decompress data use `BinaryMiniBlockDecompressor` let binary_decompressor = Box::new(BinaryMiniBlockDecompressor::default()) as Box; diff --git a/rust/lance-encoding/src/encodings/physical/packed_struct.rs b/rust/lance-encoding/src/encodings/physical/packed_struct.rs index 4feca6d9c4c..84c9c6a6874 100644 --- a/rust/lance-encoding/src/encodings/physical/packed_struct.rs +++ b/rust/lance-encoding/src/encodings/physical/packed_struct.rs @@ -9,7 +9,7 @@ use bytes::BytesMut; use futures::{future::BoxFuture, FutureExt}; use lance_arrow::DataTypeExt; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use crate::data::BlockInfo; use crate::data::FixedSizeListBlock; @@ -151,7 +151,10 @@ impl PrimitivePageDecoder for PackedStructPageDecoder { let child_block = FixedSizeListBlock::from_flat(child_block, field.data_type()); children.push(child_block); } - Ok(DataBlock::Struct(StructDataBlock { children })) + Ok(DataBlock::Struct(StructDataBlock { + children, + block_info: BlockInfo::default(), + })) } } @@ -266,9 +269,13 @@ pub mod tests { testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases}, version::LanceFileVersion, }; + use rstest::rstest; + #[rstest] #[test_log::test(tokio::test)] - async fn test_random_packed_struct() { + async fn test_random_packed_struct( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { let data_type = DataType::Struct(Fields::from(vec![ Field::new("a", DataType::UInt64, false), Field::new("b", DataType::UInt32, false), @@ -278,11 +285,14 @@ pub mod tests { let field = Field::new("", data_type, false).with_metadata(metadata); - check_round_trip_encoding_random(field, LanceFileVersion::V2_0).await; + check_round_trip_encoding_random(field, version).await; } + #[rstest] #[test_log::test(tokio::test)] - async fn test_specific_packed_struct() { + async fn test_specific_packed_struct( + #[values(LanceFileVersion::V2_0, LanceFileVersion::V2_1)] version: LanceFileVersion, + ) { let array1 = Arc::new(UInt64Array::from(vec![1, 2, 3, 4])); let array2 = Arc::new(Int32Array::from(vec![5, 6, 7, 8])); let array3 = Arc::new(UInt8Array::from(vec![9, 10, 11, 12])); @@ -325,7 +335,8 @@ pub mod tests { .with_range(0..2) .with_range(0..6) .with_range(1..4) - .with_indices(vec![1, 3, 7]); + .with_indices(vec![1, 3, 7]) + .with_file_version(version); let mut metadata = HashMap::new(); metadata.insert("packed".to_string(), "true".to_string()); @@ -338,8 +349,14 @@ pub mod tests { .await; } + // the current Lance V2.1 `packed-struct encoding` doesn't support `fixed size list`. + // the current Lance V2.0 test is disabled for now as we don't have statistics for `FixedSizeList` + #[rstest] #[test_log::test(tokio::test)] - async fn test_fsl_packed_struct() { + async fn test_fsl_packed_struct( + #[values(/*LanceFileVersion::V2_0,*/ /*LanceFileVersion::V2_1)*/)] + version: LanceFileVersion, + ) { let int_array = Arc::new(Int32Array::from(vec![12, 13, 14, 15])); let list_data_type = @@ -367,7 +384,8 @@ pub mod tests { .with_range(1..3) .with_range(0..1) .with_range(2..4) - .with_indices(vec![0, 2, 3]); + .with_indices(vec![0, 2, 3]) + .with_file_version(version); let mut metadata = HashMap::new(); metadata.insert("packed".to_string(), "true".to_string()); diff --git a/rust/lance-encoding/src/encodings/physical/struct_encoding.rs b/rust/lance-encoding/src/encodings/physical/struct_encoding.rs new file mode 100644 index 00000000000..2e241e63ca6 --- /dev/null +++ b/rust/lance-encoding/src/encodings/physical/struct_encoding.rs @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use arrow::datatypes::UInt64Type; + +use lance_core::{Error, Result}; +use snafu::location; + +use crate::{ + buffer::LanceBuffer, + data::{BlockInfo, DataBlock, FixedWidthDataBlock, StructDataBlock}, + decoder::MiniBlockDecompressor, + encoder::{MiniBlockCompressed, MiniBlockCompressor}, + format::{ + pb::{self}, + ProtobufUtils, + }, + statistics::{GetStat, Stat}, +}; + +use super::value::{ValueDecompressor, ValueEncoder}; + +// Transforms a `StructDataBlock` into a row major `FixedWidthDataBlock`. +// Only fields with fixed-width fields are supported for now, and the +// assumption that all fields has `bits_per_value % 8 == 0` is made. +fn struct_data_block_to_fixed_width_data_block( + struct_data_block: StructDataBlock, + bits_per_values: &[u32], +) -> DataBlock { + let data_size = struct_data_block.expect_single_stat::(Stat::DataSize); + let mut output = Vec::with_capacity(data_size as usize); + let num_values = struct_data_block.children[0].num_values(); + + for i in 0..num_values as usize { + for (j, child) in struct_data_block.children.iter().enumerate() { + let bytes_per_value = (bits_per_values[j] / 8) as usize; + let this_data = child + .as_fixed_width_ref() + .unwrap() + .data + .slice_with_length(bytes_per_value * i, bytes_per_value); + output.extend_from_slice(&this_data); + } + } + + DataBlock::FixedWidth(FixedWidthDataBlock { + bits_per_value: bits_per_values + .iter() + .map(|bits_per_value| *bits_per_value as u64) + .sum(), + data: LanceBuffer::Owned(output), + num_values, + block_info: BlockInfo::default(), + }) +} + +#[derive(Debug, Default)] +pub struct PackedStructFixedWidthMiniBlockEncoder {} + +impl MiniBlockCompressor for PackedStructFixedWidthMiniBlockEncoder { + fn compress( + &self, + data: DataBlock, + ) -> Result<(MiniBlockCompressed, crate::format::pb::ArrayEncoding)> { + match data { + DataBlock::Struct(struct_data_block) => { + let bits_per_values = struct_data_block.children.iter().map(|data_block| data_block.as_fixed_width_ref().unwrap().bits_per_value as u32).collect::>(); + + // transform struct datablock to fixed-width data block. + let data_block = struct_data_block_to_fixed_width_data_block(struct_data_block, &bits_per_values); + + // store and transformed fixed-width data block. + let value_miniblock_compressor = Box::new(ValueEncoder::default()) as Box; + let (value_miniblock_compressed, value_array_encoding) = + value_miniblock_compressor.compress(data_block)?; + + Ok(( + value_miniblock_compressed, + ProtobufUtils::packed_struct_fixed_width_mini_block(value_array_encoding, bits_per_values), + )) + } + _ => Err(Error::InvalidInput { + source: format!( + "Cannot compress a data block of type {} with PackedStructFixedWidthBlockEncoder", + data.name() + ) + .into(), + location: location!(), + }), + } + } +} + +#[derive(Debug)] +pub struct PackedStructFixedWidthMiniBlockDecompressor { + bits_per_values: Vec, + array_encoding: Box, +} + +impl PackedStructFixedWidthMiniBlockDecompressor { + pub fn new(description: &pb::PackedStructFixedWidthMiniBlock) -> Self { + let array_encoding: Box = match description + .flat + .as_ref() + .unwrap() + .array_encoding + .as_ref() + .unwrap() + { + pb::array_encoding::ArrayEncoding::Flat(flat) => Box::new(ValueDecompressor::from_flat(flat)), + _ => panic!("Currently only `ArrayEncoding::Flat` is supported in packed struct encoding in Lance 2.1."), + }; + Self { + bits_per_values: description.bits_per_values.clone(), + array_encoding, + } + } +} + +impl MiniBlockDecompressor for PackedStructFixedWidthMiniBlockDecompressor { + fn decompress(&self, data: Vec, num_values: u64) -> Result { + assert_eq!(data.len(), 1); + let encoded_data_block = self.array_encoding.decompress(data, num_values)?; + let DataBlock::FixedWidth(encoded_data_block) = encoded_data_block else { + panic!("ValueDecompressor should output FixedWidth DataBlock") + }; + + let bytes_per_values = self + .bits_per_values + .iter() + .map(|bits_per_value| *bits_per_value as usize / 8) + .collect::>(); + + assert!(encoded_data_block.bits_per_value % 8 == 0); + let encoded_bytes_per_row = (encoded_data_block.bits_per_value / 8) as usize; + + // use a prefix_sum vector as a helper to reconstruct to `StructDataBlock`. + let mut prefix_sum = vec![0; self.bits_per_values.len()]; + for i in 0..(self.bits_per_values.len() - 1) { + prefix_sum[i + 1] = prefix_sum[i] + bytes_per_values[i]; + } + + let mut children_data_block = vec![]; + for i in 0..self.bits_per_values.len() { + let child_buf_size = bytes_per_values[i] * num_values as usize; + let mut child_buf: Vec = Vec::with_capacity(child_buf_size); + + for j in 0..num_values as usize { + // the start of the data at this row is `j * encoded_bytes_per_row`, and the offset for this field is `prefix_sum[i]`, this field has length `bytes_per_values[i]`. + let this_value = encoded_data_block.data.slice_with_length( + prefix_sum[i] + (j * encoded_bytes_per_row), + bytes_per_values[i], + ); + + child_buf.extend_from_slice(&this_value); + } + + let child = DataBlock::FixedWidth(FixedWidthDataBlock { + data: LanceBuffer::Owned(child_buf), + bits_per_value: self.bits_per_values[i] as u64, + num_values, + block_info: BlockInfo::default(), + }); + children_data_block.push(child); + } + Ok(DataBlock::Struct(StructDataBlock { + children: children_data_block, + block_info: BlockInfo::default(), + })) + } +} diff --git a/rust/lance-encoding/src/encodings/physical/value.rs b/rust/lance-encoding/src/encodings/physical/value.rs index f68a6b94a63..a7ea3c1c3a7 100644 --- a/rust/lance-encoding/src/encodings/physical/value.rs +++ b/rust/lance-encoding/src/encodings/physical/value.rs @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use arrow_buffer::bit_util; +use arrow_buffer::{bit_util, BooleanBufferBuilder}; use arrow_schema::DataType; use bytes::Bytes; use futures::{future::BoxFuture, FutureExt}; use log::trace; -use snafu::{location, Location}; +use snafu::location; use std::ops::Range; use std::sync::{Arc, Mutex}; use crate::buffer::LanceBuffer; -use crate::data::{BlockInfo, ConstantDataBlock, DataBlock, FixedWidthDataBlock}; -use crate::decoder::PerValueDecompressor; -use crate::decoder::{BlockDecompressor, MiniBlockDecompressor}; +use crate::data::{ + BlockInfo, ConstantDataBlock, DataBlock, FixedSizeListBlock, FixedWidthDataBlock, + NullableDataBlock, +}; +use crate::decoder::{BlockDecompressor, FixedPerValueDecompressor, MiniBlockDecompressor}; use crate::encoder::{ BlockCompressor, MiniBlockChunk, MiniBlockCompressed, MiniBlockCompressor, PerValueCompressor, PerValueDataBlock, MAX_MINIBLOCK_BYTES, MAX_MINIBLOCK_VALUES, @@ -227,15 +229,18 @@ pub struct ValueEncoder {} impl ValueEncoder { /// Use the largest chunk we can smaller than 4KiB - fn find_log_vals_per_chunk(bytes_per_value: u64) -> (u64, u64) { - let mut size_bytes = 2 * bytes_per_value; - let mut log_num_vals = 1; - let mut num_vals = 2; + fn find_log_vals_per_chunk(bytes_per_word: u64, values_per_word: u64) -> (u64, u64) { + let mut size_bytes = 2 * bytes_per_word; + let (mut log_num_vals, mut num_vals) = match values_per_word { + 1 => (1, 2), + 8 => (3, 8), + _ => unreachable!(), + }; // If the type is so wide that we can't even fit 2 values we shouldn't be here assert!(size_bytes < MAX_MINIBLOCK_BYTES); - while 2 * size_bytes < MAX_MINIBLOCK_BYTES && 2 * num_vals < MAX_MINIBLOCK_VALUES { + while 2 * size_bytes < MAX_MINIBLOCK_BYTES && 2 * num_vals <= MAX_MINIBLOCK_VALUES { log_num_vals += 1; size_bytes *= 2; num_vals *= 2; @@ -245,14 +250,22 @@ impl ValueEncoder { } fn chunk_data(data: FixedWidthDataBlock) -> MiniBlockCompressed { - // For now, only support byte-sized data - debug_assert!(data.bits_per_value % 8 == 0); - let bytes_per_value = data.bits_per_value / 8; + // Usually there are X bytes per value. However, when working with boolean + // or FSL we might have some number of bits per value that isn't + // divisible by 8. In this case, to avoid chunking in the middle of a byte + // we calculate how many 8-value words we can fit in a chunk. + let (bytes_per_word, values_per_word) = if data.bits_per_value % 8 == 0 { + (data.bits_per_value / 8, 1) + } else { + (data.bits_per_value, 8) + }; // Aim for 4KiB chunks - let (log_vals_per_chunk, vals_per_chunk) = Self::find_log_vals_per_chunk(bytes_per_value); + let (log_vals_per_chunk, vals_per_chunk) = + Self::find_log_vals_per_chunk(bytes_per_word, values_per_word); let num_chunks = bit_util::ceil(data.num_values as usize, vals_per_chunk as usize); - let bytes_per_chunk = bytes_per_value * vals_per_chunk; + debug_assert_eq!(vals_per_chunk % values_per_word, 0); + let bytes_per_chunk = bytes_per_word * (vals_per_chunk / values_per_word); let bytes_per_chunk = u16::try_from(bytes_per_chunk).unwrap(); let data_buffer = data.data; @@ -265,7 +278,7 @@ impl ValueEncoder { if row_offset + vals_per_chunk <= data.num_values { chunks.push(MiniBlockChunk { log_num_values: log_vals_per_chunk as u8, - num_bytes: bytes_per_chunk, + buffer_sizes: vec![bytes_per_chunk], }); row_offset += vals_per_chunk; bytes_counter += bytes_per_chunk as u64; @@ -275,7 +288,7 @@ impl ValueEncoder { let num_bytes = u16::try_from(num_bytes).unwrap(); chunks.push(MiniBlockChunk { log_num_values: 0, - num_bytes, + buffer_sizes: vec![num_bytes], }); break; } @@ -283,12 +296,345 @@ impl ValueEncoder { MiniBlockCompressed { chunks, - data: data_buffer, + data: vec![data_buffer], num_values: data.num_values, } } } +#[derive(Debug)] +struct MiniblockFslLayer { + validity: Option, + dimension: u64, +} + +/// This impl deals with encoding FSL>>> data as a mini-block compressor. +/// The tricky part of FSL data is that we want to include inner validity buffers (we don't want these +/// to be part of the rep-def because that usually ends up being more expensive). +/// +/// The resulting mini-block will, instead of having a single buffer, have X + 1 buffers where X is +/// the number of FSL layers that contain validity. +/// +/// In the simple case where there is no validity inside the FSL layers, all we are doing here is flattening +/// the FSL layers into a single buffer. +/// +/// Also: We don't allow a row to be broken across chunks. This typically isn't too big of a deal since we +/// are usually dealing with relatively small vectors if we are using mini-block. +/// +/// Note: when we do have validity we have to make copies of the validity buffers because they are bit buffers +/// and we need to bit slice them which requires copies or offsets. Paying the price at write time to make +/// the copies is better than paying the price at read time to do the bit slicing. +impl ValueEncoder { + fn make_fsl_encoding(layers: &[MiniblockFslLayer], bits_per_value: u64) -> ArrayEncoding { + let mut encoding = ProtobufUtils::flat_encoding(bits_per_value, 0, None); + for layer in layers.iter().rev() { + let has_validity = layer.validity.is_some(); + let dimension = layer.dimension; + encoding = ProtobufUtils::fsl_encoding(dimension, encoding, has_validity); + } + encoding + } + + fn extract_fsl_chunk( + data: &FixedWidthDataBlock, + layers: &[MiniblockFslLayer], + row_offset: usize, + num_rows: usize, + validity_buffers: &mut [Vec], + ) -> Vec { + let mut row_offset = row_offset; + let mut num_values = num_rows; + let mut buffer_counter = 0; + let mut buffer_sizes = Vec::with_capacity(validity_buffers.len() + 1); + for layer in layers { + row_offset *= layer.dimension as usize; + num_values *= layer.dimension as usize; + if let Some(validity) = &layer.validity { + let validity_slice = validity + .try_clone() + .unwrap() + .bit_slice_le_with_length(row_offset, num_values); + validity_buffers[buffer_counter].extend_from_slice(&validity_slice); + buffer_sizes.push(validity_slice.len() as u16); + buffer_counter += 1; + } + } + + let bits_in_chunk = data.bits_per_value * num_values as u64; + let bytes_in_chunk = bits_in_chunk.div_ceil(8); + let bytes_in_chunk = u16::try_from(bytes_in_chunk).unwrap(); + buffer_sizes.push(bytes_in_chunk); + + buffer_sizes + } + + fn chunk_fsl( + data: FixedWidthDataBlock, + layers: Vec, + num_rows: u64, + ) -> (MiniBlockCompressed, ArrayEncoding) { + // Count size to calculate rows per chunk + let mut ceil_bytes_validity = 0; + let mut cum_dim = 1; + let mut num_validity_buffers = 0; + for layer in &layers { + cum_dim *= layer.dimension; + if layer.validity.is_some() { + ceil_bytes_validity += cum_dim.div_ceil(8); + num_validity_buffers += 1; + } + } + // It's an estimate because validity buffers may have some padding bits + let cum_bits_per_value = data.bits_per_value * cum_dim; + let (cum_bytes_per_word, vals_per_word) = if cum_bits_per_value % 8 == 0 { + (cum_bits_per_value / 8, 1) + } else { + (cum_bits_per_value, 8) + }; + let est_bytes_per_word = (ceil_bytes_validity * vals_per_word) + cum_bytes_per_word; + let (log_rows_per_chunk, rows_per_chunk) = + Self::find_log_vals_per_chunk(est_bytes_per_word, vals_per_word); + + let num_chunks = num_rows.div_ceil(rows_per_chunk) as usize; + + // Allocate buffers for validity, these will be slightly bigger than the input validity buffers + let mut chunks = Vec::with_capacity(num_chunks); + let mut validity_buffers: Vec> = Vec::with_capacity(num_validity_buffers); + cum_dim = 1; + for layer in &layers { + cum_dim *= layer.dimension; + if let Some(validity) = &layer.validity { + let layer_bytes_validity = cum_dim.div_ceil(8); + let validity_with_padding = + layer_bytes_validity as usize * num_chunks * rows_per_chunk as usize; + debug_assert!(validity_with_padding >= validity.len()); + validity_buffers.push(Vec::with_capacity( + layer_bytes_validity as usize * num_chunks, + )); + } + } + + // Now go through and extract validity buffers + let mut row_offset = 0; + while row_offset + rows_per_chunk <= num_rows { + let buffer_sizes = Self::extract_fsl_chunk( + &data, + &layers, + row_offset as usize, + rows_per_chunk as usize, + &mut validity_buffers, + ); + row_offset += rows_per_chunk; + chunks.push(MiniBlockChunk { + log_num_values: log_rows_per_chunk as u8, + buffer_sizes, + }) + } + let rows_in_chunk = num_rows - row_offset; + if rows_in_chunk > 0 { + let buffer_sizes = Self::extract_fsl_chunk( + &data, + &layers, + row_offset as usize, + rows_in_chunk as usize, + &mut validity_buffers, + ); + chunks.push(MiniBlockChunk { + log_num_values: 0, + buffer_sizes, + }); + } + + let encoding = Self::make_fsl_encoding(&layers, data.bits_per_value); + // Finally, add the data buffer + let buffers = validity_buffers + .into_iter() + .map(LanceBuffer::Owned) + .chain(std::iter::once(data.data)) + .collect::>(); + + ( + MiniBlockCompressed { + chunks, + data: buffers, + num_values: num_rows, + }, + encoding, + ) + } + + fn miniblock_fsl(data: DataBlock) -> (MiniBlockCompressed, ArrayEncoding) { + let num_rows = data.num_values(); + let fsl = data.as_fixed_size_list().unwrap(); + let mut layers = Vec::new(); + let mut child = *fsl.child; + let mut cur_layer = MiniblockFslLayer { + validity: None, + dimension: fsl.dimension, + }; + loop { + if let DataBlock::Nullable(nullable) = child { + cur_layer.validity = Some(nullable.nulls); + child = *nullable.data; + } + match child { + DataBlock::FixedSizeList(inner) => { + layers.push(cur_layer); + cur_layer = MiniblockFslLayer { + validity: None, + dimension: inner.dimension, + }; + child = *inner.child; + } + DataBlock::FixedWidth(inner) => { + layers.push(cur_layer); + return Self::chunk_fsl(inner, layers, num_rows); + } + _ => unreachable!("Unexpected data block type in value encoder's miniblock_fsl"), + } + } + } +} + +struct PerValueFslValidityIter { + buffer: LanceBuffer, + bits_per_row: usize, + offset: usize, +} + +/// In this section we deal with per-value encoding of FSL>>> data. +/// +/// It's easier than mini-block. All we need to do is flatten the FSL layers into a single buffer. +/// This includes any validity buffers we encounter on the way. +impl ValueEncoder { + fn fsl_to_encoding(fsl: &FixedSizeListBlock) -> ArrayEncoding { + let mut inner = fsl.child.as_ref(); + let mut has_validity = false; + inner = match inner { + DataBlock::Nullable(nullable) => { + has_validity = true; + nullable.data.as_ref() + } + _ => inner, + }; + let inner_encoding = match inner { + DataBlock::FixedWidth(fixed_width) => { + ProtobufUtils::flat_encoding(fixed_width.bits_per_value, 0, None) + } + DataBlock::FixedSizeList(inner) => Self::fsl_to_encoding(inner), + _ => unreachable!("Unexpected data block type in value encoder's fsl_to_encoding"), + }; + ProtobufUtils::fsl_encoding(fsl.dimension, inner_encoding, has_validity) + } + + fn simple_per_value_fsl(fsl: FixedSizeListBlock) -> (PerValueDataBlock, ArrayEncoding) { + // The simple case is zero-copy, we just return the flattened inner buffer + let encoding = Self::fsl_to_encoding(&fsl); + let num_values = fsl.num_values(); + let mut child = *fsl.child; + let mut cum_dim = 1; + loop { + cum_dim *= fsl.dimension; + match child { + DataBlock::Nullable(nullable) => { + child = *nullable.data; + } + DataBlock::FixedSizeList(inner) => { + child = *inner.child; + } + DataBlock::FixedWidth(inner) => { + let data = FixedWidthDataBlock { + bits_per_value: inner.bits_per_value * cum_dim, + num_values, + data: inner.data, + block_info: BlockInfo::new(), + }; + return (PerValueDataBlock::Fixed(data), encoding); + } + _ => unreachable!( + "Unexpected data block type in value encoder's simple_per_value_fsl" + ), + } + } + } + + fn nullable_per_value_fsl(fsl: FixedSizeListBlock) -> (PerValueDataBlock, ArrayEncoding) { + // If there are nullable inner values then we need to zip the validity with the values + let encoding = Self::fsl_to_encoding(&fsl); + let num_values = fsl.num_values(); + let mut bytes_per_row = 0; + let mut cum_dim = 1; + let mut current = fsl; + let mut validity_iters: Vec = Vec::new(); + let data_bytes_per_row: usize; + let data_buffer: LanceBuffer; + loop { + cum_dim *= current.dimension; + let mut child = *current.child; + if let DataBlock::Nullable(nullable) = child { + // Each item will need this many bytes of validity prepended to it + bytes_per_row += cum_dim.div_ceil(8) as usize; + validity_iters.push(PerValueFslValidityIter { + buffer: nullable.nulls, + bits_per_row: cum_dim as usize, + offset: 0, + }); + child = *nullable.data; + }; + match child { + DataBlock::FixedSizeList(inner) => { + current = inner; + } + DataBlock::FixedWidth(fixed_width) => { + data_bytes_per_row = + (fixed_width.bits_per_value.div_ceil(8) * cum_dim) as usize; + bytes_per_row += data_bytes_per_row; + data_buffer = fixed_width.data; + break; + } + _ => unreachable!( + "Unexpected data block type in value encoder's nullable_per_value_fsl: {:?}", + child + ), + } + } + + let bytes_needed = bytes_per_row * num_values as usize; + let mut zipped = Vec::with_capacity(bytes_needed); + let data_slice = &data_buffer; + // Hopefully values are pretty large so we don't iterate this loop _too_ many times + for i in 0..num_values as usize { + for validity in validity_iters.iter_mut() { + let validity_slice = validity + .buffer + .bit_slice_le_with_length(validity.offset, validity.bits_per_row); + zipped.extend_from_slice(&validity_slice); + validity.offset += validity.bits_per_row; + } + let start = i * data_bytes_per_row; + let end = start + data_bytes_per_row; + zipped.extend_from_slice(&data_slice[start..end]); + } + + let zipped = LanceBuffer::Owned(zipped); + let data = PerValueDataBlock::Fixed(FixedWidthDataBlock { + bits_per_value: bytes_per_row as u64 * 8, + num_values, + data: zipped, + block_info: BlockInfo::new(), + }); + (data, encoding) + } + + fn per_value_fsl(fsl: FixedSizeListBlock) -> (PerValueDataBlock, ArrayEncoding) { + if !fsl.child.is_nullable() { + Self::simple_per_value_fsl(fsl) + } else { + Self::nullable_per_value_fsl(fsl) + } + } +} + impl BlockCompressor for ValueEncoder { fn compress(&self, data: DataBlock) -> Result { let data = match data { @@ -344,6 +690,7 @@ impl MiniBlockCompressor for ValueEncoder { let encoding = ProtobufUtils::flat_encoding(fixed_width.bits_per_value, 0, None); Ok((Self::chunk_data(fixed_width), encoding)) } + DataBlock::FixedSizeList(_) => Ok(Self::miniblock_fsl(chunk)), _ => Err(Error::InvalidInput { source: format!( "Cannot compress a data block of type {} with ValueEncoder", @@ -360,76 +707,269 @@ impl MiniBlockCompressor for ValueEncoder { #[derive(Debug)] pub struct ConstantDecompressor { scalar: LanceBuffer, - num_values: u64, } impl ConstantDecompressor { - pub fn new(scalar: LanceBuffer, num_values: u64) -> Self { + pub fn new(scalar: LanceBuffer) -> Self { Self { scalar: scalar.into_borrowed(), - num_values, } } } impl BlockDecompressor for ConstantDecompressor { - fn decompress(&self, _data: LanceBuffer) -> Result { + fn decompress(&self, _data: LanceBuffer, num_values: u64) -> Result { Ok(DataBlock::Constant(ConstantDataBlock { data: self.scalar.try_clone().unwrap(), - num_values: self.num_values, + num_values, })) } } +#[derive(Debug)] +struct ValueFslDesc { + dimension: u64, + has_validity: bool, +} + /// A decompressor for fixed-width data that has /// been written, as-is, to disk in single contiguous array #[derive(Debug)] pub struct ValueDecompressor { - bytes_per_value: u64, + /// How many bits are in each inner-most item (e.g. FSL would be 32) + bits_per_item: u64, + /// How many bits are in each value (e.g. FSL would be 3200) + /// + /// This number is a little trickier to compute because we also have to include bytes + /// of any inner validity + bits_per_value: u64, + /// How many items are in each value (e.g. FSL would be 100) + items_per_value: u64, + layers: Vec, } impl ValueDecompressor { - pub fn new(description: &pb::Flat) -> Self { - assert!(description.bits_per_value % 8 == 0); + pub fn from_flat(description: &pb::Flat) -> Self { Self { - bytes_per_value: description.bits_per_value / 8, + bits_per_item: description.bits_per_value, + bits_per_value: description.bits_per_value, + items_per_value: 1, + layers: Vec::default(), } } -} -impl BlockDecompressor for ValueDecompressor { - fn decompress(&self, data: LanceBuffer) -> Result { - let num_values = data.len() as u64 / self.bytes_per_value; - assert_eq!(data.len() as u64 % self.bytes_per_value, 0); - Ok(DataBlock::FixedWidth(FixedWidthDataBlock { - bits_per_value: self.bytes_per_value * 8, - data, + pub fn from_fsl(mut description: &pb::FixedSizeList) -> Self { + let mut layers = Vec::new(); + let mut cum_dim = 1; + let mut bytes_per_value = 0; + loop { + layers.push(ValueFslDesc { + has_validity: description.has_validity, + dimension: description.dimension as u64, + }); + cum_dim *= description.dimension as u64; + if description.has_validity { + bytes_per_value += cum_dim.div_ceil(8); + } + match description + .items + .as_ref() + .unwrap() + .array_encoding + .as_ref() + .unwrap() + { + pb::array_encoding::ArrayEncoding::FixedSizeList(inner) => { + description = inner; + } + pb::array_encoding::ArrayEncoding::Flat(flat) => { + let mut bits_per_value = bytes_per_value * 8; + bits_per_value += flat.bits_per_value * cum_dim; + return Self { + bits_per_item: flat.bits_per_value, + bits_per_value, + items_per_value: cum_dim, + layers, + }; + } + _ => unreachable!(), + } + } + } + + fn buffer_to_block(&self, data: LanceBuffer, num_values: u64) -> DataBlock { + DataBlock::FixedWidth(FixedWidthDataBlock { + bits_per_value: self.bits_per_item, num_values, + data, block_info: BlockInfo::new(), - })) + }) } } -impl MiniBlockDecompressor for ValueDecompressor { +impl BlockDecompressor for ValueDecompressor { fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { - debug_assert!(data.len() as u64 >= num_values * self.bytes_per_value); + let block = self.buffer_to_block(data, num_values); + assert_eq!(block.num_values(), num_values); + Ok(block) + } +} - Ok(DataBlock::FixedWidth(FixedWidthDataBlock { - data, - bits_per_value: self.bytes_per_value * 8, - num_values, +impl MiniBlockDecompressor for ValueDecompressor { + fn decompress(&self, data: Vec, num_values: u64) -> Result { + let num_items = num_values * self.items_per_value; + let mut buffer_iter = data.into_iter().rev(); + + // Always at least 1 buffer + let data_buf = buffer_iter.next().unwrap(); + let items = self.buffer_to_block(data_buf, num_items); + let mut lists = items; + + for layer in self.layers.iter().rev() { + if layer.has_validity { + let validity_buf = buffer_iter.next().unwrap(); + lists = DataBlock::Nullable(NullableDataBlock { + data: Box::new(lists), + nulls: validity_buf, + block_info: BlockInfo::default(), + }); + } + lists = DataBlock::FixedSizeList(FixedSizeListBlock { + child: Box::new(lists), + dimension: layer.dimension, + }) + } + + assert_eq!(lists.num_values(), num_values); + Ok(lists) + } +} + +struct FslDecompressorValidityBuilder { + buffer: BooleanBufferBuilder, + bits_per_row: usize, + bytes_per_row: usize, +} + +// Helper methods for per-value decompression +impl ValueDecompressor { + fn has_validity(&self) -> bool { + self.layers.iter().any(|layer| layer.has_validity) + } + + // If there is no validity then decompression is zero-copy, we just need to restore any FSL layers + fn simple_decompress(&self, data: FixedWidthDataBlock, num_rows: u64) -> DataBlock { + let mut cum_dim = 1; + for layer in &self.layers { + cum_dim *= layer.dimension; + } + debug_assert_eq!(self.bits_per_item, data.bits_per_value / cum_dim); + let mut block = DataBlock::FixedWidth(FixedWidthDataBlock { + bits_per_value: self.bits_per_item, + num_values: num_rows * cum_dim, + data: data.data, block_info: BlockInfo::new(), - })) + }); + for layer in self.layers.iter().rev() { + block = DataBlock::FixedSizeList(FixedSizeListBlock { + child: Box::new(block), + dimension: layer.dimension, + }); + } + debug_assert_eq!(num_rows, block.num_values()); + block + } + + // If there is validity then it has been zipped in with the values and we must unzip it + fn unzip_decompress(&self, data: FixedWidthDataBlock, num_rows: usize) -> DataBlock { + // No support for full-zip on per-value encodings + assert_eq!(self.bits_per_item % 8, 0); + let bytes_per_item = self.bits_per_item / 8; + let mut buffer_builders = Vec::with_capacity(self.layers.len()); + let mut cum_dim = 1; + let mut total_size_bytes = 0; + // First, go through the layers, setup our builders, allocate space + for layer in &self.layers { + cum_dim *= layer.dimension as usize; + if layer.has_validity { + let validity_size_bits = cum_dim; + let validity_size_bytes = validity_size_bits.div_ceil(8); + total_size_bytes += num_rows * validity_size_bytes; + buffer_builders.push(FslDecompressorValidityBuilder { + buffer: BooleanBufferBuilder::new(validity_size_bits * num_rows), + bits_per_row: cum_dim, + bytes_per_row: validity_size_bytes, + }) + } + } + let num_items = num_rows * cum_dim; + let data_size = num_items * bytes_per_item as usize; + total_size_bytes += data_size; + let mut data_buffer = Vec::with_capacity(data_size); + + assert_eq!(data.data.len(), total_size_bytes); + + let bytes_per_value = bytes_per_item as usize; + let data_bytes_per_row = bytes_per_value * cum_dim; + + // Next, unzip + let mut data_offset = 0; + while data_offset < total_size_bytes { + for builder in buffer_builders.iter_mut() { + let start = data_offset * 8; + let end = start + builder.bits_per_row; + builder.buffer.append_packed_range(start..end, &data.data); + data_offset += builder.bytes_per_row; + } + let end = data_offset + data_bytes_per_row; + data_buffer.extend_from_slice(&data.data[data_offset..end]); + data_offset += data_bytes_per_row; + } + + // Finally, restore the structure + let mut block = DataBlock::FixedWidth(FixedWidthDataBlock { + bits_per_value: self.bits_per_value, + num_values: num_items as u64, + data: LanceBuffer::Owned(data_buffer), + block_info: BlockInfo::new(), + }); + + let mut validity_bufs = buffer_builders + .into_iter() + .rev() + .map(|mut b| LanceBuffer::Borrowed(b.buffer.finish().into_inner())); + for layer in self.layers.iter().rev() { + if layer.has_validity { + let nullable = NullableDataBlock { + data: Box::new(block), + nulls: validity_bufs.next().unwrap(), + block_info: BlockInfo::new(), + }; + block = DataBlock::Nullable(nullable); + } + block = DataBlock::FixedSizeList(FixedSizeListBlock { + child: Box::new(block), + dimension: layer.dimension, + }); + } + + assert_eq!(num_rows, block.num_values() as usize); + + block } } -impl PerValueDecompressor for ValueDecompressor { - fn decompress(&self, data: LanceBuffer, num_values: u64) -> Result { - MiniBlockDecompressor::decompress(self, data, num_values) +impl FixedPerValueDecompressor for ValueDecompressor { + fn decompress(&self, data: FixedWidthDataBlock, num_rows: u64) -> Result { + if self.has_validity() { + Ok(self.unzip_decompress(data, num_rows as usize)) + } else { + Ok(self.simple_decompress(data, num_rows)) + } } fn bits_per_value(&self) -> u64 { - self.bytes_per_value * 8 + self.bits_per_value } } @@ -440,6 +980,7 @@ impl PerValueCompressor for ValueEncoder { let encoding = ProtobufUtils::flat_encoding(fixed_width.bits_per_value, 0, None); (PerValueDataBlock::Fixed(fixed_width), encoding) } + DataBlock::FixedSizeList(fixed_size_list) => Self::per_value_fsl(fixed_size_list), _ => unimplemented!( "Cannot compress block of type {} with ValueEncoder", data.name() @@ -454,15 +995,26 @@ impl PerValueCompressor for ValueEncoder { pub(crate) mod tests { use std::{collections::HashMap, sync::Arc}; - use arrow_array::{Array, ArrayRef, Decimal128Array, Int32Array}; + use arrow_array::{ + make_array, Array, ArrayRef, Decimal128Array, FixedSizeListArray, Int32Array, + }; + use arrow_buffer::{BooleanBuffer, NullBuffer}; use arrow_schema::{DataType, Field, TimeUnit}; + use lance_datagen::{array, gen, ArrayGeneratorExt, Dimension, RowCount}; use rstest::rstest; use crate::{ + data::DataBlock, + decoder::{FixedPerValueDecompressor, MiniBlockDecompressor}, + encoder::{MiniBlockCompressor, PerValueCompressor, PerValueDataBlock}, + encodings::physical::value::ValueDecompressor, + format::pb, testing::{check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases}, version::LanceFileVersion, }; + use super::ValueEncoder; + const PRIMITIVE_TYPES: &[DataType] = &[ DataType::Null, DataType::FixedSizeBinary(2), @@ -490,6 +1042,29 @@ pub(crate) mod tests { // DataType::Interval(IntervalUnit::DayTime), ]; + #[test_log::test(tokio::test)] + async fn test_simple_value() { + let items = Arc::new(Int32Array::from(vec![ + Some(0), + None, + Some(2), + Some(3), + Some(4), + Some(5), + ])); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_range(0..2) + .with_range(1..3) + .with_indices(vec![0, 1, 2]) + .with_indices(vec![1]) + .with_indices(vec![2]) + .with_file_version(LanceFileVersion::V2_1); + + check_round_trip_encoding_of_data(vec![items], &test_cases, HashMap::default()).await; + } + #[rstest] #[test_log::test(tokio::test)] async fn test_value_primitive( @@ -583,4 +1158,189 @@ pub(crate) mod tests { } } } + + fn create_simple_fsl() -> FixedSizeListArray { + // [[0, 1], NULL], [NULL, NULL], [[8, 9], [NULL, 11]] + let items = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + None, + None, + None, + Some(8), + Some(9), + None, + Some(11), + ])); + let items_field = Arc::new(Field::new("item", DataType::Int32, true)); + let inner_list_nulls = BooleanBuffer::from(vec![true, false, false, false, true, true]); + let inner_list = Arc::new(FixedSizeListArray::new( + items_field.clone(), + 2, + items, + Some(NullBuffer::new(inner_list_nulls)), + )); + let inner_list_field = Arc::new(Field::new( + "item", + DataType::FixedSizeList(items_field, 2), + true, + )); + FixedSizeListArray::new(inner_list_field, 2, inner_list, None) + } + + #[test] + fn test_fsl_value_compression_miniblock() { + let sample_list = create_simple_fsl(); + + let starting_data = DataBlock::from_array(sample_list.clone()); + + let encoder = ValueEncoder::default(); + let (data, compression) = MiniBlockCompressor::compress(&encoder, starting_data).unwrap(); + + assert_eq!(data.num_values, 3); + assert_eq!(data.data.len(), 3); + assert_eq!(data.chunks.len(), 1); + assert_eq!(data.chunks[0].buffer_sizes, vec![1, 2, 48]); + assert_eq!(data.chunks[0].log_num_values, 0); + + let pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) = + compression.array_encoding.unwrap() + else { + panic!() + }; + + let decompressor = ValueDecompressor::from_fsl(fsl.as_ref()); + + let decompressed = + MiniBlockDecompressor::decompress(&decompressor, data.data, data.num_values).unwrap(); + + let decompressed = make_array( + decompressed + .into_arrow(sample_list.data_type().clone(), true) + .unwrap(), + ); + + assert_eq!(decompressed.as_ref(), &sample_list); + } + + #[test] + fn test_fsl_value_compression_per_value() { + let sample_list = create_simple_fsl(); + + let starting_data = DataBlock::from_array(sample_list.clone()); + + let encoder = ValueEncoder::default(); + let (data, compression) = PerValueCompressor::compress(&encoder, starting_data).unwrap(); + + let PerValueDataBlock::Fixed(data) = data else { + panic!() + }; + + assert_eq!(data.bits_per_value, 144); + assert_eq!(data.num_values, 3); + assert_eq!(data.data.len(), 18 * 3); + + let pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) = + compression.array_encoding.unwrap() + else { + panic!() + }; + + let decompressor = ValueDecompressor::from_fsl(fsl.as_ref()); + + let num_values = data.num_values; + let decompressed = + FixedPerValueDecompressor::decompress(&decompressor, data, num_values).unwrap(); + + let decompressed = make_array( + decompressed + .into_arrow(sample_list.data_type().clone(), true) + .unwrap(), + ); + + assert_eq!(decompressed.as_ref(), &sample_list); + } + + fn create_random_fsl() -> Arc { + // Several levels of def and multiple pages + let inner = array::rand_type(&DataType::Int32).with_random_nulls(0.1); + let list_one = array::cycle_vec(inner, Dimension::from(4)).with_random_nulls(0.1); + let list_two = array::cycle_vec(list_one, Dimension::from(4)).with_random_nulls(0.1); + let list_three = array::cycle_vec(list_two, Dimension::from(2)); + + // Should be 256Ki rows ~ 1MiB of data + let batch = gen() + .anon_col(list_three) + .into_batch_rows(RowCount::from(8 * 1024)) + .unwrap(); + batch.column(0).clone() + } + + #[test] + fn fsl_value_miniblock_stress() { + let sample_array = create_random_fsl(); + + let starting_data = + DataBlock::from_arrays(&[sample_array.clone()], sample_array.len() as u64); + + let encoder = ValueEncoder::default(); + let (data, compression) = MiniBlockCompressor::compress(&encoder, starting_data).unwrap(); + + let pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) = + compression.array_encoding.unwrap() + else { + panic!() + }; + + let decompressor = ValueDecompressor::from_fsl(fsl.as_ref()); + + let decompressed = + MiniBlockDecompressor::decompress(&decompressor, data.data, data.num_values).unwrap(); + + let decompressed = make_array( + decompressed + .into_arrow(sample_array.data_type().clone(), true) + .unwrap(), + ); + + assert_eq!(decompressed.as_ref(), sample_array.as_ref()); + } + + #[test] + fn fsl_value_per_value_stress() { + let sample_array = create_random_fsl(); + + let starting_data = + DataBlock::from_arrays(&[sample_array.clone()], sample_array.len() as u64); + + let encoder = ValueEncoder::default(); + let (data, compression) = PerValueCompressor::compress(&encoder, starting_data).unwrap(); + + let pb::array_encoding::ArrayEncoding::FixedSizeList(fsl) = + compression.array_encoding.unwrap() + else { + panic!() + }; + + let decompressor = ValueDecompressor::from_fsl(fsl.as_ref()); + + let PerValueDataBlock::Fixed(data) = data else { + panic!() + }; + + let num_values = data.num_values; + let decompressed = + FixedPerValueDecompressor::decompress(&decompressor, data, num_values).unwrap(); + + let decompressed = make_array( + decompressed + .into_arrow(sample_array.data_type().clone(), true) + .unwrap(), + ); + + assert_eq!(decompressed.as_ref(), sample_array.as_ref()); + } } diff --git a/rust/lance-encoding/src/format.rs b/rust/lance-encoding/src/format.rs index cf90c49ab2f..b33be64dde2 100644 --- a/rust/lance-encoding/src/format.rs +++ b/rust/lance-encoding/src/format.rs @@ -17,14 +17,18 @@ pub mod pb { use pb::{ array_encoding::ArrayEncoding as ArrayEncodingEnum, buffer::BufferType, + full_zip_layout, nullable::{AllNull, NoNull, Nullability, SomeNull}, page_layout::Layout, - AllNullLayout, ArrayEncoding, Binary, BinaryBlock, BinaryMiniBlock, Bitpack2, Bitpacked, - BitpackedForNonNeg, Dictionary, FixedSizeBinary, FixedSizeList, Flat, Fsst, FsstMiniBlock, - MiniBlockLayout, Nullable, PackedStruct, PageLayout, + AllNullLayout, ArrayEncoding, Binary, Bitpacked, BitpackedForNonNeg, Block, Dictionary, + FixedSizeBinary, FixedSizeList, Flat, Fsst, InlineBitpacking, MiniBlockLayout, Nullable, + OutOfLineBitpacking, PackedStruct, PackedStructFixedWidthMiniBlock, PageLayout, RepDefLayer, + Variable, }; -use crate::encodings::physical::block_compress::CompressionConfig; +use crate::{ + encodings::physical::block_compress::CompressionConfig, repdef::DefinitionInterpretation, +}; use self::pb::Constant; @@ -32,9 +36,11 @@ use self::pb::Constant; pub struct ProtobufUtils {} impl ProtobufUtils { - pub fn constant(value: Vec, num_values: u64) -> ArrayEncoding { + pub fn constant(value: Vec) -> ArrayEncoding { ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::Constant(Constant { value, num_values })), + array_encoding: Some(ArrayEncodingEnum::Constant(Constant { + value: value.into(), + })), } } @@ -70,6 +76,14 @@ impl ProtobufUtils { } } + pub fn block(scheme: &str) -> ArrayEncoding { + ArrayEncoding { + array_encoding: Some(ArrayEncodingEnum::Block(Block { + scheme: scheme.to_string(), + })), + } + } + pub fn flat_encoding( bits_per_value: u64, buffer_index: u32, @@ -90,6 +104,16 @@ impl ProtobufUtils { } } + pub fn fsl_encoding(dimension: u64, items: ArrayEncoding, has_validity: bool) -> ArrayEncoding { + ArrayEncoding { + array_encoding: Some(ArrayEncodingEnum::FixedSizeList(Box::new(FixedSizeList { + dimension: dimension.try_into().unwrap(), + items: Some(Box::new(items)), + has_validity, + }))), + } + } + pub fn bitpacked_encoding( compressed_bits_per_value: u64, uncompressed_bits_per_value: u64, @@ -125,34 +149,43 @@ impl ProtobufUtils { })), } } - pub fn bitpack2(uncompressed_bits_per_value: u64) -> ArrayEncoding { + pub fn inline_bitpacking(uncompressed_bits_per_value: u64) -> ArrayEncoding { ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::Bitpack2(Bitpack2 { + array_encoding: Some(ArrayEncodingEnum::InlineBitpacking(InlineBitpacking { uncompressed_bits_per_value, })), } } - - pub fn binary_miniblock() -> ArrayEncoding { + pub fn out_of_line_bitpacking( + uncompressed_bits_per_value: u64, + compressed_bits_per_value: u64, + ) -> ArrayEncoding { ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::BinaryMiniBlock(BinaryMiniBlock {})), + array_encoding: Some(ArrayEncodingEnum::OutOfLineBitpacking( + OutOfLineBitpacking { + uncompressed_bits_per_value, + compressed_bits_per_value, + }, + )), } } - pub fn binary_block() -> ArrayEncoding { + pub fn variable(bits_per_offset: u8) -> ArrayEncoding { ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::BinaryBlock(BinaryBlock {})), + array_encoding: Some(ArrayEncodingEnum::Variable(Variable { + bits_per_offset: bits_per_offset as u32, + })), } } // Construct a `FsstMiniBlock` ArrayEncoding, the inner `binary_mini_block` encoding is actually // not used and `FsstMiniBlockDecompressor` constructs a `binary_mini_block` in a `hard-coded` fashion. // This can be an optimization later. - pub fn fsst_mini_block(data: ArrayEncoding, symbol_table: Vec) -> ArrayEncoding { + pub fn fsst(data: ArrayEncoding, symbol_table: Vec) -> ArrayEncoding { ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::FsstMiniBlock(Box::new(FsstMiniBlock { - binary_mini_block: Some(Box::new(data)), - symbol_table, + array_encoding: Some(ArrayEncodingEnum::Fsst(Box::new(Fsst { + binary: Some(Box::new(data)), + symbol_table: symbol_table.into(), }))), } } @@ -172,6 +205,20 @@ impl ProtobufUtils { } } + pub fn packed_struct_fixed_width_mini_block( + data: ArrayEncoding, + bits_per_values: Vec, + ) -> ArrayEncoding { + ArrayEncoding { + array_encoding: Some(ArrayEncodingEnum::PackedStructFixedWidthMiniBlock( + Box::new(PackedStructFixedWidthMiniBlock { + flat: Some(Box::new(data)), + bits_per_values, + }), + )), + } + } + pub fn binary( indices_encoding: ArrayEncoding, bytes_encoding: ArrayEncoding, @@ -211,57 +258,144 @@ impl ProtobufUtils { } } - pub fn fixed_size_list(data: ArrayEncoding, dimension: u64) -> ArrayEncoding { - ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::FixedSizeList(Box::new(FixedSizeList { - dimension: dimension.try_into().unwrap(), - items: Some(Box::new(data)), - }))), + fn def_inter_to_repdef_layer(def: DefinitionInterpretation) -> i32 { + match def { + DefinitionInterpretation::AllValidItem => RepDefLayer::RepdefAllValidItem as i32, + DefinitionInterpretation::AllValidList => RepDefLayer::RepdefAllValidList as i32, + DefinitionInterpretation::NullableItem => RepDefLayer::RepdefNullableItem as i32, + DefinitionInterpretation::NullableList => RepDefLayer::RepdefNullableList as i32, + DefinitionInterpretation::EmptyableList => RepDefLayer::RepdefEmptyableList as i32, + DefinitionInterpretation::NullableAndEmptyableList => { + RepDefLayer::RepdefNullAndEmptyList as i32 + } } } - pub fn fsst(data: ArrayEncoding, symbol_table: Vec) -> ArrayEncoding { - ArrayEncoding { - array_encoding: Some(ArrayEncodingEnum::Fsst(Box::new(Fsst { - binary: Some(Box::new(data)), - symbol_table, - }))), + pub fn repdef_layer_to_def_interp(layer: i32) -> DefinitionInterpretation { + let layer = RepDefLayer::try_from(layer).unwrap(); + match layer { + RepDefLayer::RepdefAllValidItem => DefinitionInterpretation::AllValidItem, + RepDefLayer::RepdefAllValidList => DefinitionInterpretation::AllValidList, + RepDefLayer::RepdefNullableItem => DefinitionInterpretation::NullableItem, + RepDefLayer::RepdefNullableList => DefinitionInterpretation::NullableList, + RepDefLayer::RepdefEmptyableList => DefinitionInterpretation::EmptyableList, + RepDefLayer::RepdefNullAndEmptyList => { + DefinitionInterpretation::NullableAndEmptyableList + } + RepDefLayer::RepdefUnspecified => panic!("Unspecified repdef layer"), } } + #[allow(clippy::too_many_arguments)] pub fn miniblock_layout( - rep_encoding: ArrayEncoding, - def_encoding: ArrayEncoding, + rep_encoding: Option, + def_encoding: Option, value_encoding: ArrayEncoding, - dictionary_encoding: Option, + repetition_index_depth: u32, + num_buffers: u64, + dictionary_encoding: Option<(ArrayEncoding, u64)>, + def_meaning: &[DefinitionInterpretation], + num_items: u64, ) -> PageLayout { + assert!(!def_meaning.is_empty()); + let (dictionary, num_dictionary_items) = dictionary_encoding + .map(|(d, i)| (Some(d), i)) + .unwrap_or((None, 0)); PageLayout { layout: Some(Layout::MiniBlockLayout(MiniBlockLayout { - def_compression: Some(def_encoding), - rep_compression: Some(rep_encoding), + def_compression: def_encoding, + rep_compression: rep_encoding, value_compression: Some(value_encoding), - dictionary: dictionary_encoding, + repetition_index_depth, + num_buffers, + dictionary, + num_dictionary_items, + layers: def_meaning + .iter() + .map(|&def| Self::def_inter_to_repdef_layer(def)) + .collect(), + num_items, })), } } - pub fn full_zip_layout( + fn full_zip_layout( bits_rep: u8, bits_def: u8, + details: full_zip_layout::Details, value_encoding: ArrayEncoding, + def_meaning: &[DefinitionInterpretation], + num_items: u32, + num_visible_items: u32, ) -> PageLayout { PageLayout { layout: Some(Layout::FullZipLayout(pb::FullZipLayout { bits_rep: bits_rep as u32, bits_def: bits_def as u32, + details: Some(details), value_compression: Some(value_encoding), + num_items, + num_visible_items, + layers: def_meaning + .iter() + .map(|&def| Self::def_inter_to_repdef_layer(def)) + .collect(), })), } } - pub fn simple_all_null_layout() -> PageLayout { + pub fn fixed_full_zip_layout( + bits_rep: u8, + bits_def: u8, + bits_per_value: u32, + value_encoding: ArrayEncoding, + def_meaning: &[DefinitionInterpretation], + num_items: u32, + num_visible_items: u32, + ) -> PageLayout { + Self::full_zip_layout( + bits_rep, + bits_def, + full_zip_layout::Details::BitsPerValue(bits_per_value), + value_encoding, + def_meaning, + num_items, + num_visible_items, + ) + } + + pub fn variable_full_zip_layout( + bits_rep: u8, + bits_def: u8, + bits_per_offset: u32, + value_encoding: ArrayEncoding, + def_meaning: &[DefinitionInterpretation], + num_items: u32, + num_visible_items: u32, + ) -> PageLayout { + Self::full_zip_layout( + bits_rep, + bits_def, + full_zip_layout::Details::BitsPerOffset(bits_per_offset), + value_encoding, + def_meaning, + num_items, + num_visible_items, + ) + } + + pub fn all_null_layout(def_meaning: &[DefinitionInterpretation]) -> PageLayout { PageLayout { - layout: Some(Layout::AllNullLayout(AllNullLayout {})), + layout: Some(Layout::AllNullLayout(AllNullLayout { + layers: def_meaning + .iter() + .map(|&def| Self::def_inter_to_repdef_layer(def)) + .collect(), + })), } } + + pub fn simple_all_null_layout() -> PageLayout { + Self::all_null_layout(&[DefinitionInterpretation::NullableItem]) + } } diff --git a/rust/lance-encoding/src/lib.rs b/rust/lance-encoding/src/lib.rs index 19bcd721ed3..d6fc7bd627c 100644 --- a/rust/lance-encoding/src/lib.rs +++ b/rust/lance-encoding/src/lib.rs @@ -19,6 +19,7 @@ pub mod repdef; pub mod statistics; #[cfg(test)] pub mod testing; +pub mod utils; pub mod version; // We can definitely add support for big-endian machines someday. However, it's not a priority and diff --git a/rust/lance-encoding/src/repdef.rs b/rust/lance-encoding/src/repdef.rs index 82bbe680257..5d1922bc1a4 100644 --- a/rust/lance-encoding/src/repdef.rs +++ b/rust/lance-encoding/src/repdef.rs @@ -80,58 +80,424 @@ //! However, in Lance we don't always take advantage of that compression because we want to be able //! to zip rep-def levels together with our values. This gives us fewer IOPS when accessing row values. -// TODO: Right now, if a layer has no nulls, but other layers do, then we still use -// up a repetition layer for the no-null spot. For example, if we have four -// levels of rep: [has nulls, has nulls, no nulls, has nulls] then we will say: -// 0 = valid -// 1 = layer 4 null -// 2 = layer 3 null -// 3 = layer 2 null (useless) -// 4 = layer 1 null -// -// This means we end up with 3 bits per level instead of 2. We could instead record -// the layers that are all null somewhere else and not require wider rep levels. - -use std::{iter::Zip, sync::Arc}; +use std::{ + iter::{Copied, Zip}, + sync::Arc, +}; use arrow_array::OffsetSizeTrait; use arrow_buffer::{ ArrowNativeType, BooleanBuffer, BooleanBufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, }; use lance_core::{utils::bit::log_2_ceil, Error, Result}; -use snafu::{location, Location}; +use snafu::location; + +use crate::buffer::LanceBuffer; // We assume 16 bits is good enough for rep-def levels. This gives us // 65536 levels of struct nesting and list nesting. pub type LevelBuffer = Vec; +/// Represents information that we extract from a list array as we are +/// encoding +#[derive(Clone, Debug)] +struct OffsetDesc { + offsets: Arc<[i64]>, + specials: Arc<[SpecialOffset]>, + validity: Option, + has_empty_lists: bool, + num_values: usize, +} + +/// Represents validity information that we extract from non-list arrays (that +/// have nulls) as we are encoding +#[derive(Clone, Debug)] +struct ValidityDesc { + validity: Option, + num_values: usize, +} + +/// Represents validity information that we extract from FSL arrays. This is +/// just validity (no offsets) but we also record the dimension of the FSL array +/// as that will impact the next layer +#[derive(Clone, Debug)] +struct FslDesc { + validity: Option, + dimension: usize, + num_values: usize, +} + // As we build up rep/def from arrow arrays we record a -// series of RawRepDef objects +// series of RawRepDef objects. Each one corresponds to layer +// in the array structure #[derive(Clone, Debug)] enum RawRepDef { - Offsets(Arc<[i64]>), - Validity(BooleanBuffer), - NoNull(usize), + Offsets(OffsetDesc), + Validity(ValidityDesc), + Fsl(FslDesc), +} + +impl RawRepDef { + // Are there any nulls in this layer + fn has_nulls(&self) -> bool { + match self { + Self::Offsets(OffsetDesc { validity, .. }) => validity.is_some(), + Self::Validity(ValidityDesc { validity, .. }) => validity.is_some(), + Self::Fsl(FslDesc { validity, .. }) => validity.is_some(), + } + } + + // How many values are in this layer + fn num_values(&self) -> usize { + match self { + Self::Offsets(OffsetDesc { num_values, .. }) => *num_values, + Self::Validity(ValidityDesc { num_values, .. }) => *num_values, + Self::Fsl(FslDesc { num_values, .. }) => *num_values, + } + } } /// Represents repetition and definition levels that have been /// serialized into a pair of (optional) level buffers #[derive(Debug)] pub struct SerializedRepDefs { - // If None, there are no lists - pub repetition_levels: Option, - // If None, there are no nulls - pub definition_levels: Option, + /// The repetition levels, one per item + /// + /// If None, there are no lists + pub repetition_levels: Option>, + /// The definition levels, one per item + /// + /// If None, there are no nulls + pub definition_levels: Option>, + /// Special records indicate empty / null lists + /// + /// These do not have any mapping to items. There may be empty or there may + /// be more special records than items or anywhere in between. + pub special_records: Vec, + /// The meaning of each definition level + pub def_meaning: Vec, + /// The maximum level that is "visible" from the lowest level + /// + /// This is the last level before we encounter a list level of some kind. Once we've + /// hit a list level then nulls in any level beyond do not map to actual items. + /// + /// This is None if there are no lists + pub max_visible_level: Option, } impl SerializedRepDefs { + pub fn new( + repetition_levels: Option, + definition_levels: Option, + special_records: Vec, + def_meaning: Vec, + ) -> Self { + let first_list = def_meaning.iter().position(|level| level.is_list()); + let max_visible_level = first_list.map(|first_list| { + def_meaning + .iter() + .map(|level| level.num_def_levels()) + .take(first_list) + .sum::() + }); + Self { + repetition_levels: repetition_levels.map(Arc::from), + definition_levels: definition_levels.map(Arc::from), + special_records, + def_meaning, + max_visible_level, + } + } + /// Creates an empty SerializedRepDefs (no repetition, all valid) - pub fn empty() -> Self { + pub fn empty(def_meaning: Vec) -> Self { Self { repetition_levels: None, definition_levels: None, + special_records: Vec::new(), + def_meaning, + max_visible_level: None, + } + } + + pub fn rep_slicer(&self) -> Option { + self.repetition_levels + .as_ref() + .map(|rep| RepDefSlicer::new(self, rep.clone())) + } + + pub fn def_slicer(&self) -> Option { + self.definition_levels + .as_ref() + .map(|def| RepDefSlicer::new(self, def.clone())) + } + + /// Creates a version of the SerializedRepDefs with the specials collapsed into + /// the repetition and definition levels + pub fn collapse_specials(self) -> Self { + if self.special_records.is_empty() { + return self; + } + + // If we have specials then we must have repetition + let rep = self.repetition_levels.unwrap(); + + let new_len = rep.len() + self.special_records.len(); + + let mut new_rep = Vec::with_capacity(new_len); + let mut new_def = Vec::with_capacity(new_len); + + // Now we just merge the rep/def levels and the specials into one list. There is just + // one tricky part. If a non-special is added after a special item then it swaps its + // repetition level with the special item. + if let Some(def) = self.definition_levels { + let mut def_itr = def.iter(); + let mut rep_itr = rep.iter(); + let mut special_itr = self.special_records.into_iter().peekable(); + let mut last_special = None; + + for idx in 0..new_len { + if let Some(special) = special_itr.peek() { + if special.pos == idx { + new_rep.push(special.rep_level); + new_def.push(special.def_level); + special_itr.next(); + last_special = Some(new_rep.last_mut().unwrap()); + } else { + let rep = if let Some(last_special) = last_special { + let rep = *last_special; + *last_special = *rep_itr.next().unwrap(); + rep + } else { + *rep_itr.next().unwrap() + }; + new_rep.push(rep); + new_def.push(*def_itr.next().unwrap()); + last_special = None; + } + } else { + let rep = if let Some(last_special) = last_special { + let rep = *last_special; + *last_special = *rep_itr.next().unwrap(); + rep + } else { + *rep_itr.next().unwrap() + }; + new_rep.push(rep); + new_def.push(*def_itr.next().unwrap()); + last_special = None; + } + } + } else { + let mut rep_itr = rep.iter(); + let mut special_itr = self.special_records.into_iter().peekable(); + let mut last_special = None; + + for idx in 0..new_len { + if let Some(special) = special_itr.peek() { + if special.pos == idx { + new_rep.push(special.rep_level); + new_def.push(special.def_level); + special_itr.next(); + last_special = Some(new_rep.last_mut().unwrap()); + } else { + let rep = if let Some(last_special) = last_special { + let rep = *last_special; + *last_special = *rep_itr.next().unwrap(); + rep + } else { + *rep_itr.next().unwrap() + }; + new_rep.push(rep); + new_def.push(0); + last_special = None; + } + } else { + let rep = if let Some(last_special) = last_special { + let rep = *last_special; + *last_special = *rep_itr.next().unwrap(); + rep + } else { + *rep_itr.next().unwrap() + }; + new_rep.push(rep); + new_def.push(0); + last_special = None; + } + } + } + + Self { + repetition_levels: Some(new_rep.into()), + definition_levels: Some(new_def.into()), + special_records: Vec::new(), + def_meaning: self.def_meaning, + max_visible_level: self.max_visible_level, + } + } +} + +/// Slices a level buffer into pieces +/// +/// This is needed to handle the fact that a level buffer may have more +/// levels than values due to special (empty/null) lists. +/// +/// As a result, a call to `slice_next(10)` may return 10 levels or it may +/// return more than 10 levels if any special values are encountered. +#[derive(Debug)] +pub struct RepDefSlicer<'a> { + repdef: &'a SerializedRepDefs, + to_slice: LanceBuffer, + current: usize, +} + +// TODO: All of this logic will need some changing when we compress rep/def levels. +impl<'a> RepDefSlicer<'a> { + fn new(repdef: &'a SerializedRepDefs, levels: Arc<[u16]>) -> Self { + Self { + repdef, + to_slice: LanceBuffer::reinterpret_slice(levels), + current: 0, + } + } + + pub fn num_levels(&self) -> usize { + self.to_slice.len() / 2 + } + + pub fn num_levels_remaining(&self) -> usize { + self.num_levels() - self.current + } + + pub fn all_levels(&self) -> &LanceBuffer { + &self.to_slice + } + + /// Returns the rest of the levels not yet sliced + /// + /// This must be called instead of `slice_next` on the final iteration. + /// This is because anytime we slice there may be empty/null lists on the + /// boundary that are "free" and the current behavior in `slice_next` is to + /// leave them for the next call. + /// + /// `slice_rest` will slice all remaining levels and return them. + pub fn slice_rest(&mut self) -> LanceBuffer { + let start = self.current; + let remaining = self.num_levels_remaining(); + self.current = self.num_levels(); + self.to_slice.slice_with_length(start * 2, remaining * 2) + } + + /// Returns enough levels to satisfy the next `num_values` values + pub fn slice_next(&mut self, num_values: usize) -> LanceBuffer { + let start = self.current; + let Some(max_visible_level) = self.repdef.max_visible_level else { + // No lists, should be 1:1 mapping from levels to values + self.current = start + num_values; + return self.to_slice.slice_with_length(start * 2, num_values * 2); + }; + if let Some(def) = self.repdef.definition_levels.as_ref() { + // There are lists and there are def levels. That means there may be + // more rep/def levels than values. We need to scan the def levels to figure + // out which items are "invisible" and skip over them + let mut def_itr = def[start..].iter(); + let mut num_taken = 0; + let mut num_passed = 0; + while num_taken < num_values { + let def_level = *def_itr.next().unwrap(); + if def_level <= max_visible_level { + num_taken += 1; + } + num_passed += 1; + } + self.current = start + num_passed; + self.to_slice.slice_with_length(start * 2, num_passed * 2) + } else { + // No def levels, should be 1:1 mapping from levels to values + self.current = start + num_values; + self.to_slice.slice_with_length(start * 2, num_values * 2) + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct SpecialRecord { + /// The position of the special record in the items array + /// + /// Note that this is the position in the "expanded" items array (including the specials) + /// + /// For example, if we have five items [I0, I1, ..., I4] and two specials [S0(pos=3), S1(pos=6)] then + /// the combined array is [I0, I1, I2, S0, I3, I4, S1]. + /// + /// Another tricky fact is that a special "swaps" the repetition level of the matching item when it is + /// being inserted into the combined list. So, if items are [I0(rep=2), I1(rep=1), I2(rep=2), I3(rep=0)] + /// and a special is S0(pos=2, rep=1) then the combined list is + /// [I0(rep=2), I1(rep=1), S0(rep=2), I2(rep=1), I3(rep=0)]. + /// + /// Or, to put it in practice we start with [[I0], [I1]], [[I2, I3]] and after inserting our special + /// we have [[I0], [I1]], [S0, [I2, I3]] + pos: usize, + /// The definition level of the special record. This is never 0 and is used to distinguish between an + /// empty list and a null list. + def_level: u16, + /// The repetition level of the special record. This is never 0 and is used to indicate which level of + /// nesting the special record is at. + rep_level: u16, +} + +/// This tells us how an array handles definition. Given a stack of +/// these and a nested array and a set of definition levels we can calculate +/// how we should interpret the definition levels. +/// +/// For example, if the interpretation is [AllValidItem, NullableItem] then +/// a 0 means "valid item" and a 1 means "null struct". If the interpretation +/// is [NullableItem, NullableItem] then a 0 means "valid item" and a 1 means +/// "null item" and a 2 means "null struct". +/// +/// Lists are tricky because we might use up to two definition levels for a +/// single layer of list nesting because we need one value to indicate "empty list" +/// and another value to indicate "null list". +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum DefinitionInterpretation { + AllValidItem, + AllValidList, + NullableItem, + NullableList, + EmptyableList, + NullableAndEmptyableList, +} + +impl DefinitionInterpretation { + /// How many definition levels do we need for this layer + pub fn num_def_levels(&self) -> u16 { + match self { + Self::AllValidItem => 0, + Self::AllValidList => 0, + Self::NullableItem => 1, + Self::NullableList => 1, + Self::EmptyableList => 1, + Self::NullableAndEmptyableList => 2, } } + + /// Does this layer have nulls? + pub fn is_all_valid(&self) -> bool { + matches!( + self, + Self::AllValidItem | Self::AllValidList | Self::EmptyableList + ) + } + + /// Does this layer represent a list? + pub fn is_list(&self) -> bool { + matches!( + self, + Self::AllValidList + | Self::NullableList + | Self::EmptyableList + | Self::NullableAndEmptyableList + ) + } } /// The RepDefBuilder is used to collect offsets & validity buffers @@ -139,7 +505,7 @@ impl SerializedRepDefs { /// to build the actual repetition and definition levels by walking through /// the arrow constructs in reverse order. /// -/// The algorithm for definition levels is pretty simple +/// The algorithm for definition levels is as follows: /// /// Given: /// - a validity buffer of [T, F, F, T, T] @@ -167,97 +533,269 @@ impl SerializedRepDefs { /// we would have 5 definition levels. We can use our current offsets /// ([0, 3, 5]) to expand [T, F] into [T, T, T, F, F]. struct SerializerContext { - last_offsets: Option>, + last_offsets: Option>, + last_offsets_full: Option>, + specials: Vec, + def_meaning: Vec, rep_levels: LevelBuffer, def_levels: LevelBuffer, current_rep: u16, current_def: u16, + // FSL layers multiply the preceding def / rep levels by the dimension + current_multiplier: usize, has_nulls: bool, } impl SerializerContext { - fn new(len: usize, has_nulls: bool) -> Self { + fn new(len: usize, has_nulls: bool, has_offsets: bool, num_layers: usize) -> Self { + let def_meaning = Vec::with_capacity(num_layers); Self { last_offsets: None, - rep_levels: LevelBuffer::with_capacity(len), + last_offsets_full: None, + rep_levels: if has_offsets { + vec![0; len] + } else { + LevelBuffer::default() + }, def_levels: if has_nulls { - LevelBuffer::with_capacity(len) + vec![0; len] } else { LevelBuffer::default() }, + def_meaning, current_rep: 1, current_def: 1, + current_multiplier: 1, has_nulls: false, + specials: Vec::default(), } } - fn record_all_valid(&mut self, len: usize) { - self.current_def += 1; - if self.def_levels.is_empty() { - self.def_levels.resize(len, 0); - } + fn checkout_def(&mut self, meaning: DefinitionInterpretation) -> u16 { + let def = self.current_def; + self.current_def += meaning.num_def_levels(); + self.def_meaning.push(meaning); + def } - fn record_offsets(&mut self, offsets: &Arc<[i64]>) { + fn record_offsets(&mut self, offset_desc: &OffsetDesc) { + if self.current_multiplier != 1 { + // If we need this it isn't too terrible. We just need to multiply all of the offsets in offset_desc by + // the current multiplier before we do anything with them. Not adding at the moment simply to avoid the + // burden of testing + todo!("List<...FSL<...>> not yet supported"); + } let rep_level = self.current_rep; + let (null_list_level, empty_list_level) = + match (offset_desc.validity.is_some(), offset_desc.has_empty_lists) { + (true, true) => { + let level = + self.checkout_def(DefinitionInterpretation::NullableAndEmptyableList); + (level, level + 1) + } + (true, false) => (self.checkout_def(DefinitionInterpretation::NullableList), 0), + (false, true) => ( + 0, + self.checkout_def(DefinitionInterpretation::EmptyableList), + ), + (false, false) => { + self.checkout_def(DefinitionInterpretation::AllValidList); + (0, 0) + } + }; self.current_rep += 1; if let Some(last_offsets) = &self.last_offsets { - let mut new_last_off = Vec::with_capacity(offsets.len()); - for off in offsets[..offsets.len() - 1].iter() { - let offset_ctx = last_offsets[*off as usize]; + let last_offsets_full = self.last_offsets_full.as_ref().unwrap(); + let mut new_last_off = Vec::with_capacity(offset_desc.offsets.len()); + let mut new_last_off_full = Vec::with_capacity(offset_desc.offsets.len()); + let mut empties_seen = 0; + for off in offset_desc.offsets.windows(2) { + let offset_ctx = last_offsets[off[0] as usize]; + let offset_ctx_end = last_offsets[off[1] as usize]; new_last_off.push(offset_ctx); - self.rep_levels[offset_ctx as usize] = rep_level; + new_last_off_full.push(last_offsets_full[off[0] as usize] + empties_seen); + if off[0] == off[1] { + // This list has an empty/null + empties_seen += 1; + } else if offset_ctx == offset_ctx_end { + // Inner list is empty/null + // We previously added a special record but now we need to upgrade its repetition + // level to the current level + let matching_special_idx = self + .specials + .binary_search_by_key(&offset_ctx, |spec| spec.pos) + .unwrap(); + self.specials[matching_special_idx].rep_level = rep_level; + } else { + self.rep_levels[offset_ctx] = rep_level; + } } - self.last_offsets = Some(new_last_off.into()); + new_last_off.push(last_offsets[*offset_desc.offsets.last().unwrap() as usize]); + new_last_off_full.push( + last_offsets_full[*offset_desc.offsets.last().unwrap() as usize] + empties_seen, + ); + self.last_offsets = Some(new_last_off); + self.last_offsets_full = Some(new_last_off_full); } else { - self.rep_levels.resize(*offsets.last().unwrap() as usize, 0); - for off in offsets[..offsets.len() - 1].iter() { - self.rep_levels[*off as usize] = rep_level; + let mut new_last_off = Vec::with_capacity(offset_desc.offsets.len()); + let mut new_last_off_full = Vec::with_capacity(offset_desc.offsets.len()); + let mut empties_seen = 0; + for off in offset_desc.offsets.windows(2) { + new_last_off.push(off[0] as usize); + new_last_off_full.push(off[0] as usize + empties_seen); + if off[0] == off[1] { + empties_seen += 1; + } else { + self.rep_levels[off[0] as usize] = rep_level; + } } - self.last_offsets = Some(offsets.clone()); + new_last_off.push(*offset_desc.offsets.last().unwrap() as usize); + new_last_off_full.push(*offset_desc.offsets.last().unwrap() as usize + empties_seen); + self.last_offsets = Some(new_last_off); + self.last_offsets_full = Some(new_last_off_full); } + + // Must update specials _after_ setting last_offsets_full + let last_offsets_full = self.last_offsets_full.as_ref().unwrap(); + let num_combined_specials = self.specials.len() + offset_desc.specials.len(); + let mut new_specials = Vec::with_capacity(num_combined_specials); + let mut new_inserted = 0; + let mut old_specials_itr = self.specials.iter().peekable(); + let mut specials_itr = offset_desc.specials.iter().peekable(); + for _ in 0..num_combined_specials { + if let Some(old_special) = old_specials_itr.peek() { + let old_special_pos = old_special.pos + new_inserted; + if let Some(new_special) = specials_itr.peek() { + let new_special_pos = last_offsets_full[new_special.pos()]; + if old_special_pos < new_special_pos { + let mut old_special = *old_specials_itr.next().unwrap(); + old_special.pos = old_special_pos; + new_specials.push(old_special); + } else { + let new_special = specials_itr.next().unwrap(); + new_specials.push(SpecialRecord { + pos: new_special_pos, + def_level: if matches!(new_special, SpecialOffset::EmptyList(_)) { + empty_list_level + } else { + null_list_level + }, + rep_level, + }); + new_inserted += 1; + } + } else { + let mut old_special = *old_specials_itr.next().unwrap(); + old_special.pos = old_special_pos; + new_specials.push(old_special); + } + } else { + let new_special = specials_itr.next().unwrap(); + new_specials.push(SpecialRecord { + pos: last_offsets_full[new_special.pos()], + def_level: if matches!(new_special, SpecialOffset::EmptyList(_)) { + empty_list_level + } else { + null_list_level + }, + rep_level, + }); + new_inserted += 1; + } + } + self.specials = new_specials; } - fn record_validity(&mut self, validity: &BooleanBuffer) { + fn do_record_validity(&mut self, validity: &BooleanBuffer, null_level: u16) { self.has_nulls = true; - let def_level = self.current_def; - self.current_def += 1; - if self.def_levels.is_empty() { - self.def_levels.resize(validity.len(), 0); - } + assert!(!self.def_levels.is_empty()); if let Some(last_offsets) = &self.last_offsets { last_offsets .windows(2) .zip(validity.iter()) .for_each(|(w, valid)| { + let start = w[0] * self.current_multiplier; + let end = w[1] * self.current_multiplier; if !valid { - self.def_levels[w[0] as usize..w[1] as usize].fill(def_level); + self.def_levels[start..end].fill(null_level); } }); - } else { + } else if self.current_multiplier == 1 { self.def_levels .iter_mut() .zip(validity.iter()) .for_each(|(def, valid)| { if !valid { - *def = def_level; + *def = null_level; } }); + } else { + self.def_levels + .iter_mut() + .zip( + validity + .iter() + .flat_map(|v| std::iter::repeat_n(v, self.current_multiplier)), + ) + .for_each(|(def, valid)| { + if !valid { + *def = null_level; + } + }); + } + } + + fn record_validity_buf(&mut self, validity: &Option) { + if let Some(validity) = validity { + let def_level = self.checkout_def(DefinitionInterpretation::NullableItem); + self.do_record_validity(validity, def_level); + } else { + self.checkout_def(DefinitionInterpretation::AllValidItem); } } + fn record_validity(&mut self, validity_desc: &ValidityDesc) { + self.record_validity_buf(&validity_desc.validity) + } + + fn record_fsl(&mut self, fsl_desc: &FslDesc) { + self.current_multiplier *= fsl_desc.dimension; + self.record_validity_buf(&fsl_desc.validity); + } + fn build(self) -> SerializedRepDefs { - SerializedRepDefs { - definition_levels: if self.has_nulls { - Some(self.def_levels) - } else { - None - }, - repetition_levels: if self.current_rep > 1 { - Some(self.rep_levels) - } else { - None - }, + let definition_levels = if self.has_nulls { + Some(self.def_levels) + } else { + None + }; + let repetition_levels = if self.current_rep > 1 { + Some(self.rep_levels) + } else { + None + }; + SerializedRepDefs::new( + repetition_levels, + definition_levels, + self.specials, + self.def_meaning, + ) + } +} + +/// As we are encoding we record information about "specials" which are +/// empty lists or null lists. +#[derive(Debug, Copy, Clone)] +enum SpecialOffset { + NullList(usize), + EmptyList(usize), +} + +impl SpecialOffset { + fn pos(&self) -> usize { + match self { + Self::NullList(pos) => *pos, + Self::EmptyList(pos) => *pos, } } } @@ -268,7 +806,7 @@ impl SerializerContext { /// As we are encoding the structural encoders are given this struct and /// will record the arrow information into it. Once we hit a leaf node we /// serialize the data into rep/def levels and write these into the page. -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct RepDefBuilder { // The rep/def info we have collected so far repdefs: Vec, @@ -280,21 +818,23 @@ pub struct RepDefBuilder { } impl RepDefBuilder { - fn check_validity_len(&mut self, validity: &NullBuffer) { + fn check_validity_len(&mut self, incoming_len: usize) { if let Some(len) = self.len { - assert!(validity.len() == len); + assert_eq!(incoming_len, len); } - self.len = Some(validity.len()); + self.len = Some(incoming_len); } fn num_layers(&self) -> usize { self.repdefs.len() } + /// The builder is "empty" if there is no repetition and no nulls. In this case we don't need + /// to store anything to disk (except the description) fn is_empty(&self) -> bool { self.repdefs .iter() - .all(|r| matches!(r, RawRepDef::NoNull(_))) + .all(|r| matches!(r, RawRepDef::Validity(ValidityDesc { validity: None, .. }))) } /// Returns true if there is only a single layer of definition @@ -307,21 +847,55 @@ impl RepDefBuilder { /// Return False if all layers are non-null (the def levels can /// be skipped in this case) pub fn has_nulls(&self) -> bool { + self.repdefs.iter().any(|rd| { + matches!( + rd, + RawRepDef::Validity(ValidityDesc { + validity: Some(_), + .. + }) | RawRepDef::Fsl(FslDesc { + validity: Some(_), + .. + }) + ) + }) + } + + pub fn has_offsets(&self) -> bool { self.repdefs .iter() - .any(|rd| matches!(rd, RawRepDef::Validity(_))) + .any(|rd| matches!(rd, RawRepDef::Offsets(OffsetDesc { .. }))) } /// Registers a nullable validity bitmap pub fn add_validity_bitmap(&mut self, validity: NullBuffer) { - self.check_validity_len(&validity); - self.repdefs - .push(RawRepDef::Validity(validity.into_inner())); + self.check_validity_len(validity.len()); + self.repdefs.push(RawRepDef::Validity(ValidityDesc { + num_values: validity.len(), + validity: Some(validity.into_inner()), + })); } /// Registers an all-valid validity layer pub fn add_no_null(&mut self, len: usize) { - self.repdefs.push(RawRepDef::NoNull(len)); + self.check_validity_len(len); + self.repdefs.push(RawRepDef::Validity(ValidityDesc { + validity: None, + num_values: len, + })); + } + + pub fn add_fsl(&mut self, validity: Option, dimension: usize, num_values: usize) { + if let Some(len) = self.len { + assert_eq!(num_values, len); + } + self.len = Some(num_values * dimension); + debug_assert!(validity.is_none() || validity.as_ref().unwrap().len() == num_values); + self.repdefs.push(RawRepDef::Fsl(FslDesc { + num_values, + validity: validity.map(|v| v.into_inner()), + dimension, + })) } fn check_offset_len(&mut self, offsets: &[i64]) { @@ -333,93 +907,331 @@ impl RepDefBuilder { /// Adds a layer of offsets /// - /// Note: a List/LargeList/etc. array has both offsets and validity. The - /// caller should register the validity before registering the offsets - pub fn add_offsets(&mut self, repetition: OffsetBuffer) { - // We should be able to zero-copy + /// Offsets are casted to a common type (i64) and also normalized. Null lists are + /// always represented by a zero-length (identical) pair of offsets and so the caller + /// should filter out any garbage items before encoding them. To assist with this the + /// method will return true if any non-empty null lists were found. + pub fn add_offsets( + &mut self, + offsets: OffsetBuffer, + validity: Option, + ) -> bool { + let mut has_garbage_values = false; if O::IS_LARGE { - let inner = repetition.into_inner(); + let inner = offsets.into_inner(); let len = inner.len(); - let i64_buff = ScalarBuffer::new(inner.into_inner(), 0, len); - let offsets = Vec::from(i64_buff); - self.check_offset_len(&offsets); - self.repdefs.push(RawRepDef::Offsets(offsets.into())); + let i64_buff = ScalarBuffer::::new(inner.into_inner(), 0, len); + let mut normalized = Vec::with_capacity(len); + normalized.push(0_i64); + let mut specials = Vec::new(); + let mut has_empty_lists = false; + let mut last_off = 0; + if let Some(validity) = validity.as_ref() { + for (idx, (off, valid)) in i64_buff.windows(2).zip(validity.iter()).enumerate() { + let len: i64 = off[1] - off[0]; + match (valid, len == 0) { + (false, is_empty) => { + specials.push(SpecialOffset::NullList(idx)); + has_garbage_values |= !is_empty; + } + (true, true) => { + has_empty_lists = true; + specials.push(SpecialOffset::EmptyList(idx)); + } + _ => { + last_off += len; + } + } + normalized.push(last_off); + } + } else { + for (idx, off) in i64_buff.windows(2).enumerate() { + let len: i64 = off[1] - off[0]; + if len == 0 { + has_empty_lists = true; + specials.push(SpecialOffset::EmptyList(idx)); + } + last_off += len; + normalized.push(last_off); + } + }; + self.check_offset_len(&normalized); + self.repdefs.push(RawRepDef::Offsets(OffsetDesc { + num_values: normalized.len() - 1, + offsets: normalized.into(), + validity: validity.map(|v| v.into_inner()), + has_empty_lists, + specials: specials.into(), + })); + has_garbage_values } else { - let inner = repetition.into_inner(); + let inner = offsets.into_inner(); let len = inner.len(); - let casted = ScalarBuffer::::new(inner.into_inner(), 0, len) - .iter() - .copied() - .map(|o| o as i64) - .collect::>(); + let scalar_off = ScalarBuffer::::new(inner.into_inner(), 0, len); + let mut casted = Vec::with_capacity(len); + casted.push(0); + let mut has_empty_lists = false; + let mut specials = Vec::new(); + let mut last_off: i64 = 0; + if let Some(validity) = validity.as_ref() { + for (idx, (off, valid)) in scalar_off.windows(2).zip(validity.iter()).enumerate() { + let len = (off[1] - off[0]) as i64; + match (valid, len == 0) { + (false, is_empty) => { + specials.push(SpecialOffset::NullList(idx)); + has_garbage_values |= !is_empty; + } + (true, true) => { + has_empty_lists = true; + specials.push(SpecialOffset::EmptyList(idx)); + } + _ => { + last_off += len; + } + } + casted.push(last_off); + } + } else { + for (idx, off) in scalar_off.windows(2).enumerate() { + let len = (off[1] - off[0]) as i64; + if len == 0 { + has_empty_lists = true; + specials.push(SpecialOffset::EmptyList(idx)); + } + last_off += len; + casted.push(last_off); + } + }; self.check_offset_len(&casted); - self.repdefs.push(RawRepDef::Offsets(casted.into())); + self.repdefs.push(RawRepDef::Offsets(OffsetDesc { + num_values: casted.len() - 1, + offsets: casted.into(), + validity: validity.map(|v| v.into_inner()), + has_empty_lists, + specials: specials.into(), + })); + has_garbage_values } } - // TODO: This is lazy. We shouldn't need this concatenation pass. We should be able - // to concatenate as we build up the rep/def levels but I'm saving that for a - // future optimization. - fn concat_layers<'a>(mut layers: impl Iterator, len: usize) -> RawRepDef { - let first = layers.next().unwrap(); - match &first { - RawRepDef::NoNull(_) | RawRepDef::Validity(_) => { - // Also lazy, building up a validity buffer just to throw it away - // if there are no nulls - let mut has_nulls = false; - let mut builder = BooleanBufferBuilder::new(len); - for layer in std::iter::once(first).chain(layers) { - match layer { - RawRepDef::NoNull(num_valid) => { - builder.append_n(*num_valid, true); - } - RawRepDef::Validity(validity) => { - has_nulls = true; - builder.append_buffer(validity); - } - _ => unreachable!(), - } + // When we are encoding data it arrives in batches. For each batch we create a RepDefBuilder and collect the + // various validity buffers and offset buffers from that batch. Once we have enough batches to write a page we + // need to take this collection of RepDefBuilders and concatenate them and then serialize them into rep/def levels. + // + // TODO: In the future, we may concatenate and serialize at the same time? + // + // This method takes care of the concatenation part. First we collect all of layer 0 from each builder, then we + // call this method. Then we collect all of layer 1 from each builder and call this method. And so on. + // + // That means this method should get a collection of `RawRepDef` where each item is the same kind (all validity or + // all offsets) though the nullability / lengths may be different in each layer. + fn concat_layers<'a>( + layers: impl Iterator, + num_layers: usize, + ) -> RawRepDef { + enum LayerKind { + Validity, + Fsl, + Offsets, + } + + // We make two passes through the layers. The first determines if we need to pay the cost of allocating + // buffers. The second pass actually adds the values. + let mut collected = Vec::with_capacity(num_layers); + let mut has_nulls = false; + let mut layer_kind = LayerKind::Validity; + let mut num_specials = 0; + let mut all_dimension = 0; + let mut all_has_empty_lists = false; + let mut all_num_values = 0; + for layer in layers { + has_nulls |= layer.has_nulls(); + match layer { + RawRepDef::Validity(_) => { + layer_kind = LayerKind::Validity; } - if has_nulls { - RawRepDef::Validity(builder.finish()) - } else { - RawRepDef::NoNull(builder.len()) + RawRepDef::Offsets(OffsetDesc { + specials, + has_empty_lists, + .. + }) => { + all_has_empty_lists |= *has_empty_lists; + layer_kind = LayerKind::Offsets; + num_specials += specials.len(); + } + RawRepDef::Fsl(FslDesc { dimension, .. }) => { + layer_kind = LayerKind::Fsl; + all_dimension = *dimension; + } + } + collected.push(layer); + all_num_values += layer.num_values(); + } + + // Shortcut if there are no nulls + if !has_nulls { + match layer_kind { + LayerKind::Validity => { + return RawRepDef::Validity(ValidityDesc { + validity: None, + num_values: all_num_values, + }); } + LayerKind::Fsl => { + return RawRepDef::Fsl(FslDesc { + validity: None, + num_values: all_num_values, + dimension: all_dimension, + }) + } + LayerKind::Offsets => {} } - RawRepDef::Offsets(offsets) => { - let mut all_offsets = Vec::with_capacity(len); - all_offsets.extend(offsets.iter().copied()); - for layer in layers { + } + + // Only allocate if needed + let mut validity_builder = if has_nulls { + BooleanBufferBuilder::new(all_num_values) + } else { + BooleanBufferBuilder::new(0) + }; + let mut all_offsets = if matches!(layer_kind, LayerKind::Offsets) { + let mut all_offsets = Vec::with_capacity(all_num_values); + all_offsets.push(0); + all_offsets + } else { + Vec::new() + }; + let mut all_specials = Vec::with_capacity(num_specials); + + for layer in collected { + match layer { + RawRepDef::Validity(ValidityDesc { + validity: Some(validity), + .. + }) => { + validity_builder.append_buffer(validity); + } + RawRepDef::Validity(ValidityDesc { + validity: None, + num_values, + }) => { + validity_builder.append_n(*num_values, true); + } + RawRepDef::Fsl(FslDesc { + validity, + num_values, + .. + }) => { + if let Some(validity) = validity { + validity_builder.append_buffer(validity); + } else { + validity_builder.append_n(*num_values, true); + } + } + RawRepDef::Offsets(OffsetDesc { + offsets, + validity: Some(validity), + has_empty_lists, + specials, + .. + }) => { + all_has_empty_lists |= has_empty_lists; + validity_builder.append_buffer(validity); + let existing_lists = all_offsets.len() - 1; let last = *all_offsets.last().unwrap(); - let RawRepDef::Offsets(offsets) = layer else { - unreachable!() - }; all_offsets.extend(offsets.iter().skip(1).map(|off| *off + last)); + all_specials.extend(specials.iter().map(|s| match s { + SpecialOffset::NullList(pos) => { + SpecialOffset::NullList(*pos + existing_lists) + } + SpecialOffset::EmptyList(pos) => { + SpecialOffset::EmptyList(*pos + existing_lists) + } + })); + } + RawRepDef::Offsets(OffsetDesc { + offsets, + validity: None, + has_empty_lists, + num_values, + specials, + }) => { + all_has_empty_lists |= has_empty_lists; + if has_nulls { + validity_builder.append_n(*num_values, true); + } + let last = *all_offsets.last().unwrap(); + let existing_lists = all_offsets.len() - 1; + all_offsets.extend(offsets.iter().skip(1).map(|off| *off + last)); + all_specials.extend(specials.iter().map(|s| match s { + SpecialOffset::NullList(pos) => { + SpecialOffset::NullList(*pos + existing_lists) + } + SpecialOffset::EmptyList(pos) => { + SpecialOffset::EmptyList(*pos + existing_lists) + } + })); } - RawRepDef::Offsets(all_offsets.into()) } } + let validity = if has_nulls { + Some(validity_builder.finish()) + } else { + None + }; + match layer_kind { + LayerKind::Fsl => RawRepDef::Fsl(FslDesc { + validity, + num_values: all_num_values, + dimension: all_dimension, + }), + LayerKind::Validity => RawRepDef::Validity(ValidityDesc { + validity, + num_values: all_num_values, + }), + LayerKind::Offsets => RawRepDef::Offsets(OffsetDesc { + offsets: all_offsets.into(), + validity, + has_empty_lists: all_has_empty_lists, + num_values: all_num_values, + specials: all_specials.into(), + }), + } } /// Converts the validity / offsets buffers that have been gathered so far /// into repetition and definition levels pub fn serialize(builders: Vec) -> SerializedRepDefs { - if builders.is_empty() { - return SerializedRepDefs::empty(); - } + assert!(!builders.is_empty()); if builders.iter().all(|b| b.is_empty()) { // No repetition, all-valid - return SerializedRepDefs::empty(); + return SerializedRepDefs::empty( + builders + .first() + .unwrap() + .repdefs + .iter() + .map(|_| DefinitionInterpretation::AllValidItem) + .collect::>(), + ); } let has_nulls = builders.iter().any(|b| b.has_nulls()); + let has_offsets = builders.iter().any(|b| b.has_offsets()); let total_len = builders.iter().map(|b| b.len.unwrap()).sum(); - let mut context = SerializerContext::new(total_len, has_nulls); + let num_layers = builders[0].num_layers(); + let mut context = SerializerContext::new(total_len, has_nulls, has_offsets, num_layers); + let combined_layers = (0..num_layers) + .map(|layer_index| { + Self::concat_layers( + builders.iter().map(|b| &b.repdefs[layer_index]), + builders.len(), + ) + }) + .collect::>(); debug_assert!(builders .iter() .all(|b| b.num_layers() == builders[0].num_layers())); - for layer_index in (0..builders[0].num_layers()).rev() { - let layer = - Self::concat_layers(builders.iter().map(|b| &b.repdefs[layer_index]), total_len); + for layer in combined_layers.into_iter().rev() { match layer { RawRepDef::Validity(def) => { context.record_validity(&def); @@ -427,12 +1239,12 @@ impl RepDefBuilder { RawRepDef::Offsets(rep) => { context.record_offsets(&rep); } - RawRepDef::NoNull(len) => { - context.record_all_valid(len); + RawRepDef::Fsl(fsl) => { + context.record_fsl(&fsl); } } } - context.build() + context.build().collapse_specials() } } @@ -444,37 +1256,142 @@ impl RepDefBuilder { pub struct RepDefUnraveler { rep_levels: Option, def_levels: Option, + // Maps from definition level to the rep level at which that definition level is visible + levels_to_rep: Vec, + def_meaning: Arc<[DefinitionInterpretation]>, // Current definition level to compare to. current_def_cmp: u16, + // Current rep level, determines which specials we can see + current_rep_cmp: u16, + // Current layer index, 0 means inner-most layer and it counts up from there. Used to index + // into special_defs + current_layer: usize, } impl RepDefUnraveler { /// Creates a new unraveler from serialized repetition and definition information - pub fn new(rep_levels: Option, def_levels: Option) -> Self { + pub fn new( + rep_levels: Option, + def_levels: Option, + def_meaning: Arc<[DefinitionInterpretation]>, + ) -> Self { + let mut levels_to_rep = Vec::with_capacity(def_meaning.len()); + let mut rep_counter = 0; + // Level=0 is always visible and means valid item + levels_to_rep.push(0); + for meaning in def_meaning.as_ref() { + match meaning { + DefinitionInterpretation::AllValidItem | DefinitionInterpretation::AllValidList => { + // There is no corresponding level, so nothing to put in levels_to_rep + } + DefinitionInterpretation::NullableItem => { + // Some null structs are not visible at inner rep levels in cases like LIST>> + levels_to_rep.push(rep_counter); + } + DefinitionInterpretation::NullableList => { + rep_counter += 1; + levels_to_rep.push(rep_counter); + } + DefinitionInterpretation::EmptyableList => { + rep_counter += 1; + levels_to_rep.push(rep_counter); + } + DefinitionInterpretation::NullableAndEmptyableList => { + rep_counter += 1; + levels_to_rep.push(rep_counter); + levels_to_rep.push(rep_counter); + } + } + } Self { rep_levels, def_levels, current_def_cmp: 0, + current_rep_cmp: 0, + levels_to_rep, + current_layer: 0, + def_meaning, } } + pub fn is_all_valid(&self) -> bool { + self.def_meaning[self.current_layer].is_all_valid() + } + + /// If the current level is a repetition layer then this returns the number of lists + /// at this level. + /// + /// This is not valid to call when the current level is a struct/primitive layer because + /// in some cases there may be no rep or def information to know this. + pub fn max_lists(&self) -> usize { + debug_assert!( + self.def_meaning[self.current_layer] != DefinitionInterpretation::NullableItem + ); + self.rep_levels + .as_ref() + // Worst case every rep item is max_rep and a new list + .map(|levels| levels.len()) + .unwrap_or(0) + } + /// Unravels a layer of offsets from the unraveler into the given offset width /// /// When decoding a list the caller should first unravel the offsets and then /// unravel the validity (this is the opposite order used during encoding) - pub fn unravel_offsets(&mut self) -> Result> { + pub fn unravel_offsets( + &mut self, + offsets: &mut Vec, + validity: Option<&mut BooleanBufferBuilder>, + ) -> Result<()> { let rep_levels = self .rep_levels .as_mut() .expect("Expected repetition level but data didn't contain repetition"); - let mut offsets: Vec = Vec::with_capacity(rep_levels.len() + 1); - let mut curlen: usize = 0; + let valid_level = self.current_def_cmp; + let (null_level, empty_level) = match self.def_meaning[self.current_layer] { + DefinitionInterpretation::NullableList => { + self.current_def_cmp += 1; + (valid_level + 1, 0) + } + DefinitionInterpretation::EmptyableList => { + self.current_def_cmp += 1; + (0, valid_level + 1) + } + DefinitionInterpretation::NullableAndEmptyableList => { + self.current_def_cmp += 2; + (valid_level + 1, valid_level + 2) + } + DefinitionInterpretation::AllValidList => (0, 0), + _ => unreachable!(), + }; + let max_level = null_level.max(empty_level); + self.current_layer += 1; + + let mut curlen: usize = offsets.last().map(|o| o.as_usize()).unwrap_or(0); + + // If offsets is empty this is a no-op. If offsets is not empty that means we already + // added a set of offsets. For example, we might have added [0, 3, 5] (2 lists). Now + // say we want to add [0, 1, 4] (2 lists). We should get [0, 3, 5, 6, 9] (4 lists). If + // we don't pop here we get [0, 3, 5, 5, 6, 9] which is wrong. + // + // Or, to think about it another way, if every unraveler adds the starting 0 and the trailing + // length then we have N + unravelers.len() values instead of N + 1. + offsets.pop(); + let to_offset = |val: usize| { T::from_usize(val) .ok_or_else(|| Error::invalid_input("A single batch had more than i32::MAX values and so a large container type is required", location!())) }; + self.current_rep_cmp += 1; if let Some(def_levels) = &mut self.def_levels { assert!(rep_levels.len() == def_levels.len()); + // It's possible validity is None even if we have def levels. For example, we might have + // empty lists (which require def levels) but no nulls. + let mut push_validity: Box = if let Some(validity) = validity { + Box::new(|is_valid| validity.append(is_valid)) + } else { + Box::new(|_| {}) + }; // This is a strange access pattern. We are iterating over the rep/def levels and // at the same time writing the rep/def levels. This means we need both a mutable // and immutable reference to the rep/def levels. @@ -486,25 +1403,48 @@ impl RepDefUnraveler { unsafe { let rep_val = *rep_levels.get_unchecked(read_idx); if rep_val != 0 { - // Finish the current list - offsets.push(to_offset(curlen)?); + let def_val = *def_levels.get_unchecked(read_idx); + // Copy over *rep_levels.get_unchecked_mut(write_idx) = rep_val - 1; - *def_levels.get_unchecked_mut(write_idx) = - *def_levels.get_unchecked(read_idx); + *def_levels.get_unchecked_mut(write_idx) = def_val; write_idx += 1; + + if def_val == 0 { + // This is a valid list + offsets.push(to_offset(curlen)?); + curlen += 1; + push_validity(true); + } else if def_val > max_level { + // This is not visible at this rep level, do not add to offsets, but keep in repdef + } else if def_val == null_level { + // This is a null list + offsets.push(to_offset(curlen)?); + push_validity(false); + } else if def_val == empty_level { + // This is an empty list + offsets.push(to_offset(curlen)?); + push_validity(true); + } else { + // New valid list starting with null item + offsets.push(to_offset(curlen)?); + curlen += 1; + push_validity(true); + } + } else { + curlen += 1; } - curlen += 1; read_idx += 1; } } offsets.push(to_offset(curlen)?); - rep_levels.truncate(offsets.len() - 1); - def_levels.truncate(offsets.len() - 1); - Ok(OffsetBuffer::new(ScalarBuffer::from(offsets))) + rep_levels.truncate(write_idx); + def_levels.truncate(write_idx); + Ok(()) } else { // SAFETY: See above loop let mut read_idx = 0; let mut write_idx = 0; + let old_offsets_len = offsets.len(); while read_idx < rep_levels.len() { // SAFETY: read_idx / write_idx cannot go past rep_levels.len() unsafe { @@ -519,25 +1459,165 @@ impl RepDefUnraveler { read_idx += 1; } } + let num_new_lists = offsets.len() - old_offsets_len; offsets.push(to_offset(curlen)?); rep_levels.truncate(offsets.len() - 1); - Ok(OffsetBuffer::new(ScalarBuffer::from(offsets))) + if let Some(validity) = validity { + // Even though we don't have validity it is possible another unraveler did and so we need + // to push all valids + validity.append_n(num_new_lists, true); + } + Ok(()) } } + pub fn skip_validity(&mut self) { + debug_assert!( + self.def_meaning[self.current_layer] == DefinitionInterpretation::AllValidItem + ); + self.current_layer += 1; + } + /// Unravels a layer of validity from the definition levels - pub fn unravel_validity(&mut self) -> Option { - let Some(def_levels) = &self.def_levels else { - return None; - }; + pub fn unravel_validity(&mut self, validity: &mut BooleanBufferBuilder) { + debug_assert!( + self.def_meaning[self.current_layer] != DefinitionInterpretation::AllValidItem + ); + self.current_layer += 1; + + let def_levels = &self.def_levels.as_ref().unwrap(); + let current_def_cmp = self.current_def_cmp; self.current_def_cmp += 1; - let validity = BooleanBuffer::from_iter(def_levels.iter().map(|&r| r <= current_def_cmp)); - if validity.count_set_bits() == validity.len() { + + for is_valid in def_levels.iter().filter_map(|&level| { + if self.levels_to_rep[level as usize] <= self.current_rep_cmp { + Some(level <= current_def_cmp) + } else { + None + } + }) { + validity.append(is_valid); + } + } + + pub fn decimate(&mut self, dimension: usize) { + if self.rep_levels.is_some() { + // If we need to support this then I think we need to walk through the rep def levels to find + // the spots at which we keep. E.g. if we have: + // rep: 1 0 0 1 0 1 0 0 0 1 0 0 + // def: 1 1 1 0 1 0 1 1 0 1 1 0 + // dimension: 2 + // + // The output should be: + // rep: 1 0 0 1 0 0 0 + // def: 1 1 1 0 1 1 0 + // + // Maybe there's some special logic for empty/null lists? I'll save the headache for future me. + todo!("Not yet supported FSL<...List<...>>"); + } + let Some(def_levels) = self.def_levels.as_mut() else { + return; + }; + let mut read_idx = 0; + let mut write_idx = 0; + while read_idx < def_levels.len() { + unsafe { + *def_levels.get_unchecked_mut(write_idx) = *def_levels.get_unchecked(read_idx); + } + write_idx += 1; + read_idx += dimension; + } + def_levels.truncate(write_idx); + } +} + +/// As we decode we may extract rep/def information from multiple pages (or multiple +/// chunks within a page). +/// +/// For each chunk we create an unraveler. Each unraveler can have a completely different +/// interpretation (e.g. one page might contain null items but no null structs and the next +/// page might have null structs but no null items). +/// +/// Concatenating these unravelers would be tricky and expensive so instead we have a +/// composite unraveler which unravels across multiple unravelers. +/// +/// Note: this class should be used even if there is only one page / unraveler. This is +/// because the `RepDefUnraveler`'s API is more complex (it's meant to be called by this +/// class) +#[derive(Debug)] +pub struct CompositeRepDefUnraveler { + unravelers: Vec, +} + +impl CompositeRepDefUnraveler { + pub fn new(unravelers: Vec) -> Self { + Self { unravelers } + } + + /// Unravels a layer of validity + /// + /// Returns None if there are no null items in this layer + pub fn unravel_validity(&mut self, num_values: usize) -> Option { + let is_all_valid = self + .unravelers + .iter() + .all(|unraveler| unraveler.is_all_valid()); + + if is_all_valid { + for unraveler in self.unravelers.iter_mut() { + unraveler.skip_validity(); + } + None + } else { + let mut validity = BooleanBufferBuilder::new(num_values); + for unraveler in self.unravelers.iter_mut() { + unraveler.unravel_validity(&mut validity); + } + Some(NullBuffer::new(validity.finish())) + } + } + + pub fn unravel_fsl_validity( + &mut self, + num_values: usize, + dimension: usize, + ) -> Option { + for unraveler in self.unravelers.iter_mut() { + unraveler.decimate(dimension); + } + self.unravel_validity(num_values) + } + + /// Unravels a layer of offsets (and the validity for that layer) + pub fn unravel_offsets( + &mut self, + ) -> Result<(OffsetBuffer, Option)> { + let mut is_all_valid = true; + let mut max_num_lists = 0; + for unraveler in self.unravelers.iter() { + is_all_valid &= unraveler.is_all_valid(); + max_num_lists += unraveler.max_lists(); + } + + let mut validity = if is_all_valid { None } else { - Some(NullBuffer::new(validity)) + // Note: This is probably an over-estimate and potentially even an under-estimate. We only know + // right now how many items we have and not how many rows. (TODO: Shouldn't we know the # of rows?) + Some(BooleanBufferBuilder::new(max_num_lists)) + }; + + let mut offsets = Vec::with_capacity(max_num_lists + 1); + + for unraveler in self.unravelers.iter_mut() { + unraveler.unravel_offsets(&mut offsets, validity.as_mut())?; } + + Ok(( + OffsetBuffer::new(ScalarBuffer::from(offsets)), + validity.map(|mut v| NullBuffer::new(v.finish())), + )) } } @@ -550,6 +1630,8 @@ impl RepDefUnraveler { pub struct BinaryControlWordIterator, W> { repdef: I, def_width: usize, + max_rep: u16, + max_visible_def: u16, rep_mask: u16, def_mask: u16, bits_rep: u8, @@ -558,28 +1640,44 @@ pub struct BinaryControlWordIterator, W> { } impl> BinaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { - let next = self.repdef.next().unwrap(); + fn append_next(&mut self, buf: &mut Vec) -> Option { + let next = self.repdef.next()?; let control_word: u8 = (((next.0 & self.rep_mask) as u8) << self.def_width) + ((next.1 & self.def_mask) as u8); buf.push(control_word); + let is_new_row = next.0 == self.max_rep; + let is_visible = next.1 <= self.max_visible_def; + let is_valid_item = next.1 == 0; + Some(ControlWordDesc { + is_new_row, + is_visible, + is_valid_item, + }) } } impl> BinaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { - let next = self.repdef.next().unwrap(); + fn append_next(&mut self, buf: &mut Vec) -> Option { + let next = self.repdef.next()?; let control_word: u16 = ((next.0 & self.rep_mask) << self.def_width) + (next.1 & self.def_mask); let control_word = control_word.to_le_bytes(); buf.push(control_word[0]); buf.push(control_word[1]); + let is_new_row = next.0 == self.max_rep; + let is_visible = next.1 <= self.max_visible_def; + let is_valid_item = next.1 == 0; + Some(ControlWordDesc { + is_new_row, + is_visible, + is_valid_item, + }) } } impl> BinaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { - let next = self.repdef.next().unwrap(); + fn append_next(&mut self, buf: &mut Vec) -> Option { + let next = self.repdef.next()?; let control_word: u32 = (((next.0 & self.rep_mask) as u32) << self.def_width) + ((next.1 & self.def_mask) as u32); let control_word = control_word.to_le_bytes(); @@ -587,6 +1685,14 @@ impl> BinaryControlWordIterator { buf.push(control_word[1]); buf.push(control_word[2]); buf.push(control_word[3]); + let is_new_row = next.0 == self.max_rep; + let is_visible = next.1 <= self.max_visible_def; + let is_valid_item = next.1 == 0; + Some(ControlWordDesc { + is_new_row, + is_visible, + is_valid_item, + }) } } @@ -597,45 +1703,95 @@ pub struct UnaryControlWordIterator, W> { level_mask: u16, bits_rep: u8, bits_def: u8, + max_rep: u16, phantom: std::marker::PhantomData, } impl> UnaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { - let next = self.repdef.next().unwrap(); + fn append_next(&mut self, buf: &mut Vec) -> Option { + let next = self.repdef.next()?; buf.push((next & self.level_mask) as u8); + let is_new_row = self.max_rep == 0 || next == self.max_rep; + let is_valid_item = next == 0 || self.bits_def == 0; + Some(ControlWordDesc { + is_new_row, + // Either there is no rep, in which case there are no invisible items + // or there is no def, in which case there are no invisible items + is_visible: true, + is_valid_item, + }) } } impl> UnaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { + fn append_next(&mut self, buf: &mut Vec) -> Option { let next = self.repdef.next().unwrap() & self.level_mask; let control_word = next.to_le_bytes(); buf.push(control_word[0]); buf.push(control_word[1]); + let is_new_row = self.max_rep == 0 || next == self.max_rep; + let is_valid_item = next == 0 || self.bits_def == 0; + Some(ControlWordDesc { + is_new_row, + is_visible: true, + is_valid_item, + }) } } impl> UnaryControlWordIterator { - fn append_next(&mut self, buf: &mut Vec) { - let next = (self.repdef.next().unwrap() & self.level_mask) as u32; + fn append_next(&mut self, buf: &mut Vec) -> Option { + let next = self.repdef.next()?; + let next = (next & self.level_mask) as u32; let control_word = next.to_le_bytes(); buf.push(control_word[0]); buf.push(control_word[1]); buf.push(control_word[2]); buf.push(control_word[3]); + let is_new_row = self.max_rep == 0 || next as u16 == self.max_rep; + let is_valid_item = next == 0 || self.bits_def == 0; + Some(ControlWordDesc { + is_new_row, + is_visible: true, + is_valid_item, + }) } } /// A [`ControlWordIterator`] when there are no repetition or definition levels #[derive(Debug)] -pub struct NilaryControlWordIterator; +pub struct NilaryControlWordIterator { + len: usize, + idx: usize, +} + +impl NilaryControlWordIterator { + fn append_next(&mut self) -> Option { + if self.idx == self.len { + None + } else { + self.idx += 1; + Some(ControlWordDesc { + is_new_row: true, + is_visible: true, + is_valid_item: true, + }) + } + } +} /// Helper function to get a bit mask of the given width fn get_mask(width: u16) -> u16 { (1 << width) - 1 } +// We're really going out of our way to avoid boxing here but this will be called on a per-value basis +// so it is in the critical path. +type SpecificBinaryControlWordIterator<'a, T> = BinaryControlWordIterator< + Zip>, Copied>>, + T, +>; + /// An iterator that generates control words from repetition and definition levels /// /// "Control word" is just a fancy term for a single u8/u16/u32 that contains both @@ -646,19 +1802,29 @@ fn get_mask(width: u16) -> u16 { /// need two bytes. In the worst case we need 4 bytes though this suggests hundreds of /// levels of nesting which seems unlikely to encounter in practice. #[derive(Debug)] -pub enum ControlWordIterator { - Binary8(BinaryControlWordIterator, std::vec::IntoIter>, u8>), - Binary16(BinaryControlWordIterator, std::vec::IntoIter>, u16>), - Binary32(BinaryControlWordIterator, std::vec::IntoIter>, u32>), - Unary8(UnaryControlWordIterator, u8>), - Unary16(UnaryControlWordIterator, u16>), - Unary32(UnaryControlWordIterator, u32>), +pub enum ControlWordIterator<'a> { + Binary8(SpecificBinaryControlWordIterator<'a, u8>), + Binary16(SpecificBinaryControlWordIterator<'a, u16>), + Binary32(SpecificBinaryControlWordIterator<'a, u32>), + Unary8(UnaryControlWordIterator>, u8>), + Unary16(UnaryControlWordIterator>, u16>), + Unary32(UnaryControlWordIterator>, u32>), Nilary(NilaryControlWordIterator), } -impl ControlWordIterator { +/// Describes the properties of a control word +#[derive(Debug)] +pub struct ControlWordDesc { + pub is_new_row: bool, + pub is_visible: bool, + pub is_valid_item: bool, +} + +impl ControlWordIterator<'_> { /// Appends the next control word to the buffer - pub fn append_next(&mut self, buf: &mut Vec) { + /// + /// Returns true if this is the start of a new item (i.e. the repetition level is maxed out) + pub fn append_next(&mut self, buf: &mut Vec) -> Option { match self { Self::Binary8(iter) => iter.append_next(buf), Self::Binary16(iter) => iter.append_next(buf), @@ -666,7 +1832,18 @@ impl ControlWordIterator { Self::Unary8(iter) => iter.append_next(buf), Self::Unary16(iter) => iter.append_next(buf), Self::Unary32(iter) => iter.append_next(buf), - Self::Nilary(_) => {} + Self::Nilary(iter) => iter.append_next(), + } + } + + /// Return true if the control word iterator has repetition levels + pub fn has_repetition(&self) -> bool { + match self { + Self::Binary8(_) | Self::Binary16(_) | Self::Binary32(_) => true, + Self::Unary8(iter) => iter.bits_rep > 0, + Self::Unary16(iter) => iter.bits_rep > 0, + Self::Unary32(iter) => iter.bits_rep > 0, + Self::Nilary(_) => false, } } @@ -713,12 +1890,14 @@ impl ControlWordIterator { /// Builds a [`ControlWordIterator`] from repetition and definition levels /// by first calculating the width needed and then creating the iterator /// with the appropriate width -pub fn build_control_word_iterator( - rep: Option>, +pub fn build_control_word_iterator<'a>( + rep: Option<&'a [u16]>, max_rep: u16, - def: Option>, + def: Option<&'a [u16]>, max_def: u16, -) -> ControlWordIterator { + max_visible_def: u16, + len: usize, +) -> ControlWordIterator<'a> { let rep_width = if max_rep == 0 { 0 } else { @@ -734,7 +1913,7 @@ pub fn build_control_word_iterator( let total_width = rep_width + def_width; match (rep, def) { (Some(rep), Some(def)) => { - let iter = rep.into_iter().zip(def); + let iter = rep.iter().copied().zip(def.iter().copied()); let def_width = def_width as usize; if total_width <= 8 { ControlWordIterator::Binary8(BinaryControlWordIterator { @@ -742,6 +1921,8 @@ pub fn build_control_word_iterator( rep_mask, def_mask, def_width, + max_rep, + max_visible_def, bits_rep: rep_width as u8, bits_def: def_width as u8, phantom: std::marker::PhantomData, @@ -752,6 +1933,8 @@ pub fn build_control_word_iterator( rep_mask, def_mask, def_width, + max_rep, + max_visible_def, bits_rep: rep_width as u8, bits_def: def_width as u8, phantom: std::marker::PhantomData, @@ -762,6 +1945,8 @@ pub fn build_control_word_iterator( rep_mask, def_mask, def_width, + max_rep, + max_visible_def, bits_rep: rep_width as u8, bits_def: def_width as u8, phantom: std::marker::PhantomData, @@ -769,13 +1954,14 @@ pub fn build_control_word_iterator( } } (Some(lev), None) => { - let iter = lev.into_iter(); + let iter = lev.iter().copied(); if total_width <= 8 { ControlWordIterator::Unary8(UnaryControlWordIterator { repdef: iter, level_mask: rep_mask, bits_rep: total_width as u8, bits_def: 0, + max_rep, phantom: std::marker::PhantomData, }) } else if total_width <= 16 { @@ -784,6 +1970,7 @@ pub fn build_control_word_iterator( level_mask: rep_mask, bits_rep: total_width as u8, bits_def: 0, + max_rep, phantom: std::marker::PhantomData, }) } else { @@ -792,18 +1979,20 @@ pub fn build_control_word_iterator( level_mask: rep_mask, bits_rep: total_width as u8, bits_def: 0, + max_rep, phantom: std::marker::PhantomData, }) } } (None, Some(lev)) => { - let iter = lev.into_iter(); + let iter = lev.iter().copied(); if total_width <= 8 { ControlWordIterator::Unary8(UnaryControlWordIterator { repdef: iter, level_mask: def_mask, bits_rep: 0, bits_def: total_width as u8, + max_rep: 0, phantom: std::marker::PhantomData, }) } else if total_width <= 16 { @@ -812,6 +2001,7 @@ pub fn build_control_word_iterator( level_mask: def_mask, bits_rep: 0, bits_def: total_width as u8, + max_rep: 0, phantom: std::marker::PhantomData, }) } else { @@ -820,11 +2010,12 @@ pub fn build_control_word_iterator( level_mask: def_mask, bits_rep: 0, bits_def: total_width as u8, + max_rep: 0, phantom: std::marker::PhantomData, }) } } - (None, None) => ControlWordIterator::Nilary(NilaryControlWordIterator {}), + (None, None) => ControlWordIterator::Nilary(NilaryControlWordIterator { len, idx: 0 }), } } @@ -881,6 +2072,57 @@ impl ControlWordParser { } } + fn parse_desc_both( + src: &[u8], + bits_to_shift: u8, + mask_to_apply: u32, + max_rep: u16, + max_visible_def: u16, + ) -> ControlWordDesc { + match WORD_SIZE { + 1 => { + let word = src[0]; + let rep = word >> bits_to_shift; + let def = word & (mask_to_apply as u8); + let is_visible = def as u16 <= max_visible_def; + let is_new_row = rep as u16 == max_rep; + let is_valid_item = def == 0; + ControlWordDesc { + is_visible, + is_new_row, + is_valid_item, + } + } + 2 => { + let word = u16::from_le_bytes([src[0], src[1]]); + let rep = word >> bits_to_shift; + let def = word & mask_to_apply as u16; + let is_visible = def <= max_visible_def; + let is_new_row = rep == max_rep; + let is_valid_item = def == 0; + ControlWordDesc { + is_visible, + is_new_row, + is_valid_item, + } + } + 4 => { + let word = u32::from_le_bytes([src[0], src[1], src[2], src[3]]); + let rep = word >> bits_to_shift; + let def = word & mask_to_apply; + let is_visible = def as u16 <= max_visible_def; + let is_new_row = rep as u16 == max_rep; + let is_valid_item = def == 0; + ControlWordDesc { + is_visible, + is_new_row, + is_valid_item, + } + } + _ => unreachable!(), + } + } + fn parse_one(src: &[u8], dst: &mut Vec) { match WORD_SIZE { 1 => { @@ -899,6 +2141,48 @@ impl ControlWordParser { } } + fn parse_rep_desc_one(src: &[u8], max_rep: u16) -> ControlWordDesc { + match WORD_SIZE { + 1 => ControlWordDesc { + is_new_row: src[0] as u16 == max_rep, + is_visible: true, + is_valid_item: true, + }, + 2 => ControlWordDesc { + is_new_row: u16::from_le_bytes([src[0], src[1]]) == max_rep, + is_visible: true, + is_valid_item: true, + }, + 4 => ControlWordDesc { + is_new_row: u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as u16 == max_rep, + is_visible: true, + is_valid_item: true, + }, + _ => unreachable!(), + } + } + + fn parse_def_desc_one(src: &[u8]) -> ControlWordDesc { + match WORD_SIZE { + 1 => ControlWordDesc { + is_new_row: true, + is_visible: true, + is_valid_item: src[0] == 0, + }, + 2 => ControlWordDesc { + is_new_row: true, + is_visible: true, + is_valid_item: u16::from_le_bytes([src[0], src[1]]) == 0, + }, + 4 => ControlWordDesc { + is_new_row: true, + is_visible: true, + is_valid_item: u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as u16 == 0, + }, + _ => unreachable!(), + } + } + /// Returns the number of bytes per control word pub fn bytes_per_word(&self) -> usize { match self { @@ -942,6 +2226,57 @@ impl ControlWordParser { } } + /// Return true if the control words contain repetition information + pub fn has_rep(&self) -> bool { + match self { + Self::BOTH8(..) + | Self::BOTH16(..) + | Self::BOTH32(..) + | Self::REP8 + | Self::REP16 + | Self::REP32 => true, + Self::DEF8 | Self::DEF16 | Self::DEF32 | Self::NIL => false, + } + } + + /// Temporarily parses the control word to inspect its properties but does not append to any buffers + pub fn parse_desc(&self, src: &[u8], max_rep: u16, max_visible_def: u16) -> ControlWordDesc { + match self { + Self::BOTH8(bits_to_shift, mask_to_apply) => Self::parse_desc_both::<1>( + src, + *bits_to_shift, + *mask_to_apply, + max_rep, + max_visible_def, + ), + Self::BOTH16(bits_to_shift, mask_to_apply) => Self::parse_desc_both::<2>( + src, + *bits_to_shift, + *mask_to_apply, + max_rep, + max_visible_def, + ), + Self::BOTH32(bits_to_shift, mask_to_apply) => Self::parse_desc_both::<4>( + src, + *bits_to_shift, + *mask_to_apply, + max_rep, + max_visible_def, + ), + Self::REP8 => Self::parse_rep_desc_one::<1>(src, max_rep), + Self::REP16 => Self::parse_rep_desc_one::<2>(src, max_rep), + Self::REP32 => Self::parse_rep_desc_one::<4>(src, max_rep), + Self::DEF8 => Self::parse_def_desc_one::<1>(src), + Self::DEF16 => Self::parse_def_desc_one::<2>(src), + Self::DEF32 => Self::parse_def_desc_one::<4>(src), + Self::NIL => ControlWordDesc { + is_new_row: true, + is_valid_item: true, + is_visible: true, + }, + } + } + /// Creates a new parser from the number of bits used for the repetition and definition levels pub fn new(bits_rep: u8, bits_def: u8) -> Self { let total_bits = bits_rep + bits_def; @@ -981,7 +2316,9 @@ impl ControlWordParser { mod tests { use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; - use crate::repdef::RepDefUnraveler; + use crate::repdef::{ + CompositeRepDefUnraveler, DefinitionInterpretation, RepDefUnraveler, SerializedRepDefs, + }; use super::RepDefBuilder; @@ -998,13 +2335,17 @@ mod tests { } #[test] - fn test_repdef() { + fn test_repdef_basic() { // Basic case, rep & def let mut builder = RepDefBuilder::default(); - builder.add_validity_bitmap(validity(&[true, false, true])); - builder.add_offsets(offsets_64(&[0, 2, 3, 5])); - builder.add_validity_bitmap(validity(&[true, true, true, false, true])); - builder.add_offsets(offsets_64(&[0, 1, 3, 5, 7, 9])); + builder.add_offsets( + offsets_64(&[0, 2, 2, 5]), + Some(validity(&[true, false, true])), + ); + builder.add_offsets( + offsets_64(&[0, 1, 3, 5, 5, 9]), + Some(validity(&[true, true, true, false, true])), + ); builder.add_validity_bitmap(validity(&[ true, true, true, false, false, false, true, true, false, ])); @@ -1013,71 +2354,317 @@ mod tests { let rep = repdefs.repetition_levels.unwrap(); let def = repdefs.definition_levels.unwrap(); - assert_eq!(vec![0, 0, 0, 3, 3, 2, 2, 0, 1], def); - assert_eq!(vec![2, 1, 0, 2, 0, 2, 0, 1, 0], rep); + assert_eq!(vec![0, 0, 0, 3, 1, 1, 2, 1, 0, 0, 1], *def); + assert_eq!(vec![2, 1, 0, 2, 2, 0, 1, 1, 0, 0, 0], *rep); - let mut unraveler = RepDefUnraveler::new(Some(rep), Some(def)); + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + Some(rep.as_ref().to_vec()), + Some(def.as_ref().to_vec()), + repdefs.def_meaning.into(), + )]); // Note: validity doesn't exactly round-trip because repdef normalizes some of the // redundant validity values assert_eq!( - unraveler.unravel_validity(), + unraveler.unravel_validity(9), Some(validity(&[ - true, true, true, false, false, false, false, true, false + true, true, true, false, false, false, true, true, false ])) ); + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!(off.inner(), offsets_32(&[0, 1, 3, 5, 5, 9]).inner()); + assert_eq!(val, Some(validity(&[true, true, true, false, true]))); + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!(off.inner(), offsets_32(&[0, 2, 2, 5]).inner()); + assert_eq!(val, Some(validity(&[true, false, true]))); + } + + #[test] + fn test_repdef_simple_null_empty_list() { + let check = |repdefs: SerializedRepDefs, last_def: DefinitionInterpretation| { + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([1, 0, 1, 1, 0, 0], *rep); + assert_eq!([0, 0, 2, 0, 1, 0], *def); + assert!(repdefs.special_records.is_empty()); + assert_eq!( + vec![DefinitionInterpretation::NullableItem, last_def,], + repdefs.def_meaning + ); + }; + + // Null list and empty list should be serialized mostly the same + + // Null case + let mut builder = RepDefBuilder::default(); + builder.add_offsets( + offsets_32(&[0, 2, 2, 5]), + Some(validity(&[true, false, true])), + ); + builder.add_validity_bitmap(validity(&[true, true, true, false, true])); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + check(repdefs, DefinitionInterpretation::NullableList); + + // Empty case + let mut builder = RepDefBuilder::default(); + builder.add_offsets(offsets_32(&[0, 2, 2, 5]), None); + builder.add_validity_bitmap(validity(&[true, true, true, false, true])); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + check(repdefs, DefinitionInterpretation::EmptyableList); + } + + #[test] + fn test_repdef_empty_list_at_end() { + // Regresses a failure we encountered when the last item was an empty list + let mut builder = RepDefBuilder::default(); + builder.add_offsets(offsets_32(&[0, 2, 5, 5]), None); + builder.add_validity_bitmap(validity(&[true, true, true, false, true])); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([1, 0, 1, 0, 0, 1], *rep); + assert_eq!([0, 0, 0, 1, 0, 2], *def); + assert!(repdefs.special_records.is_empty()); + assert_eq!( + vec![ + DefinitionInterpretation::NullableItem, + DefinitionInterpretation::EmptyableList, + ], + repdefs.def_meaning + ); + } + + #[test] + fn test_repdef_abnormal_nulls() { + // List nulls are allowed to have non-empty offsets and garbage values + // and the add_offsets call should normalize this + let mut builder = RepDefBuilder::default(); + builder.add_offsets( + offsets_32(&[0, 2, 5, 8]), + Some(validity(&[true, false, true])), + ); + // Note: we pass 5 here and not 8. If add_offsets tells us there is garbage nulls they + // should be removed before continuing + builder.add_no_null(5); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([1, 0, 1, 1, 0, 0], *rep); + assert_eq!([0, 0, 1, 0, 0, 0], *def); + + assert_eq!( + vec![ + DefinitionInterpretation::AllValidItem, + DefinitionInterpretation::NullableList, + ], + repdefs.def_meaning + ); + } + + #[test] + fn test_repdef_fsl() { + let mut builder = RepDefBuilder::default(); + builder.add_fsl(Some(validity(&[true, false])), 2, 2); + builder.add_fsl(None, 2, 4); + builder.add_validity_bitmap(validity(&[ + true, false, true, false, true, false, true, false, + ])); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + assert_eq!( + vec![ + DefinitionInterpretation::NullableItem, + DefinitionInterpretation::AllValidItem, + DefinitionInterpretation::NullableItem + ], + repdefs.def_meaning + ); + + assert!(repdefs.repetition_levels.is_none()); + + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([0, 1, 0, 1, 2, 2, 2, 2], *def); + + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + None, + Some(def.as_ref().to_vec()), + repdefs.def_meaning.into(), + )]); + assert_eq!( - unraveler.unravel_offsets::().unwrap().inner(), - offsets_32(&[0, 1, 3, 5, 7, 9]).inner() + unraveler.unravel_validity(8), + Some(validity(&[ + true, false, true, false, false, false, false, false + ])) ); + assert_eq!(unraveler.unravel_fsl_validity(4, 2), None); assert_eq!( - unraveler.unravel_validity(), - Some(validity(&[true, true, false, false, true])) + unraveler.unravel_fsl_validity(2, 2), + Some(validity(&[true, false])) ); + } + + #[test] + fn test_repdef_fsl_allvalid_item() { + let mut builder = RepDefBuilder::default(); + builder.add_fsl(Some(validity(&[true, false])), 2, 2); + builder.add_fsl(None, 2, 4); + builder.add_no_null(8); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + assert_eq!( - unraveler.unravel_offsets::().unwrap().inner(), - offsets_32(&[0, 2, 3, 5]).inner() + vec![ + DefinitionInterpretation::AllValidItem, + DefinitionInterpretation::AllValidItem, + DefinitionInterpretation::NullableItem + ], + repdefs.def_meaning ); + + assert!(repdefs.repetition_levels.is_none()); + + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([0, 0, 0, 0, 1, 1, 1, 1], *def); + + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + None, + Some(def.as_ref().to_vec()), + repdefs.def_meaning.into(), + )]); + + assert_eq!(unraveler.unravel_validity(8), None); + assert_eq!(unraveler.unravel_fsl_validity(4, 2), None); assert_eq!( - unraveler.unravel_validity(), - Some(validity(&[true, false, true])) + unraveler.unravel_fsl_validity(2, 2), + Some(validity(&[true, false])) ); } #[test] - fn test_repdef_all_valid() { + fn test_repdef_sliced_offsets() { + // Sliced lists may have offsets that don't start with zero. The + // add_offsets call needs to normalize these to operate correctly. let mut builder = RepDefBuilder::default(); - builder.add_no_null(3); - builder.add_offsets(offsets_64(&[0, 2, 3, 5])); + builder.add_offsets( + offsets_32(&[5, 7, 7, 10]), + Some(validity(&[true, false, true])), + ); builder.add_no_null(5); - builder.add_offsets(offsets_64(&[0, 1, 3, 5, 7, 9])); - builder.add_no_null(9); let repdefs = RepDefBuilder::serialize(vec![builder]); - let rep = repdefs.repetition_levels.unwrap(); - assert!(repdefs.definition_levels.is_none()); - assert_eq!(vec![2, 1, 0, 2, 0, 2, 0, 1, 0], rep); + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); - let mut unraveler = RepDefUnraveler::new(Some(rep), None); + assert_eq!([1, 0, 1, 1, 0, 0], *rep); + assert_eq!([0, 0, 1, 0, 0, 0], *def); - assert_eq!(unraveler.unravel_validity(), None); assert_eq!( - unraveler.unravel_offsets::().unwrap().inner(), - offsets_32(&[0, 1, 3, 5, 7, 9]).inner() + vec![ + DefinitionInterpretation::AllValidItem, + DefinitionInterpretation::NullableList, + ], + repdefs.def_meaning ); - assert_eq!(unraveler.unravel_validity(), None); - assert_eq!( - unraveler.unravel_offsets::().unwrap().inner(), - offsets_32(&[0, 2, 3, 5]).inner() + } + + #[test] + fn test_repdef_complex_null_empty() { + let mut builder = RepDefBuilder::default(); + builder.add_offsets( + offsets_32(&[0, 4, 4, 4, 6]), + Some(validity(&[true, false, true, true])), ); - assert_eq!(unraveler.unravel_validity(), None); + builder.add_offsets( + offsets_32(&[0, 1, 1, 2, 2, 2, 3]), + Some(validity(&[true, false, true, false, true, true])), + ); + builder.add_no_null(3); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([2, 1, 1, 1, 2, 2, 2, 1], *rep); + assert_eq!([0, 1, 0, 1, 3, 4, 2, 0], *def); + } + + #[test] + fn test_repdef_empty_list_no_null() { + // Tests when we have some empty lists but no null lists. This case + // caused some bugs because we have definition but no nulls + let mut builder = RepDefBuilder::default(); + builder.add_offsets(offsets_32(&[0, 4, 4, 4, 6]), None); + builder.add_no_null(6); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + let rep = repdefs.repetition_levels.unwrap(); + let def = repdefs.definition_levels.unwrap(); + + assert_eq!([1, 0, 0, 0, 1, 1, 1, 0], *rep); + assert_eq!([0, 0, 0, 0, 1, 1, 0, 0], *def); + + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + Some(rep.as_ref().to_vec()), + Some(def.as_ref().to_vec()), + repdefs.def_meaning.into(), + )]); + + assert_eq!(unraveler.unravel_validity(6), None); + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!(off.inner(), offsets_32(&[0, 4, 4, 4, 6]).inner()); + assert_eq!(val, None); + } + + #[test] + fn test_repdef_all_valid() { + let mut builder = RepDefBuilder::default(); + builder.add_offsets(offsets_64(&[0, 2, 3, 5]), None); + builder.add_offsets(offsets_64(&[0, 1, 3, 5, 7, 9]), None); + builder.add_no_null(9); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + let rep = repdefs.repetition_levels.unwrap(); + assert!(repdefs.definition_levels.is_none()); + + assert_eq!([2, 1, 0, 2, 0, 2, 0, 1, 0], *rep); + + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + Some(rep.as_ref().to_vec()), + None, + repdefs.def_meaning.into(), + )]); + + assert_eq!(unraveler.unravel_validity(9), None); + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!(off.inner(), offsets_32(&[0, 1, 3, 5, 7, 9]).inner()); + assert_eq!(val, None); + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!(off.inner(), offsets_32(&[0, 2, 3, 5]).inner()); + assert_eq!(val, None); } #[test] fn test_repdef_no_rep() { let mut builder = RepDefBuilder::default(); - builder.add_no_null(3); + builder.add_no_null(5); builder.add_validity_bitmap(validity(&[false, false, true, true, true])); builder.add_validity_bitmap(validity(&[false, true, true, true, false])); @@ -1085,52 +2672,123 @@ mod tests { assert!(repdefs.repetition_levels.is_none()); let def = repdefs.definition_levels.unwrap(); - assert_eq!(vec![2, 2, 0, 0, 1], def); + assert_eq!([2, 2, 0, 0, 1], *def); - let mut unraveler = RepDefUnraveler::new(None, Some(def)); + let mut unraveler = CompositeRepDefUnraveler::new(vec![RepDefUnraveler::new( + None, + Some(def.as_ref().to_vec()), + repdefs.def_meaning.into(), + )]); assert_eq!( - unraveler.unravel_validity(), + unraveler.unravel_validity(5), Some(validity(&[false, false, true, true, false])) ); assert_eq!( - unraveler.unravel_validity(), + unraveler.unravel_validity(5), Some(validity(&[false, false, true, true, true])) ); - assert_eq!(unraveler.unravel_validity(), None); + assert_eq!(unraveler.unravel_validity(5), None); + } + + #[test] + fn test_composite_unravel() { + let mut builder = RepDefBuilder::default(); + builder.add_offsets( + offsets_64(&[0, 2, 2, 5]), + Some(validity(&[true, false, true])), + ); + let repdef1 = RepDefBuilder::serialize(vec![builder]); + + let mut builder = RepDefBuilder::default(); + builder.add_offsets(offsets_64(&[0, 1, 3, 5, 7, 9]), None); + let repdef2 = RepDefBuilder::serialize(vec![builder]); + + let unravel1 = RepDefUnraveler::new( + repdef1.repetition_levels.map(|l| l.to_vec()), + repdef1.definition_levels.map(|l| l.to_vec()), + repdef1.def_meaning.into(), + ); + let unravel2 = RepDefUnraveler::new( + repdef2.repetition_levels.map(|l| l.to_vec()), + repdef2.definition_levels.map(|l| l.to_vec()), + repdef2.def_meaning.into(), + ); + + let mut unraveler = CompositeRepDefUnraveler::new(vec![unravel1, unravel2]); + + let (off, val) = unraveler.unravel_offsets::().unwrap(); + assert_eq!( + off.inner(), + offsets_32(&[0, 2, 2, 5, 6, 8, 10, 12, 14]).inner() + ); + assert_eq!( + val, + Some(validity(&[true, false, true, true, true, true, true, true])) + ); } #[test] fn test_repdef_multiple_builders() { // Basic case, rep & def let mut builder1 = RepDefBuilder::default(); - builder1.add_validity_bitmap(validity(&[true])); - builder1.add_offsets(offsets_64(&[0, 2])); - builder1.add_validity_bitmap(validity(&[true, true])); - builder1.add_offsets(offsets_64(&[0, 1, 3])); + builder1.add_offsets(offsets_64(&[0, 2]), None); + builder1.add_offsets(offsets_64(&[0, 1, 3]), None); builder1.add_validity_bitmap(validity(&[true, true, true])); let mut builder2 = RepDefBuilder::default(); - builder2.add_validity_bitmap(validity(&[false, true])); - builder2.add_offsets(offsets_64(&[0, 1, 3])); - builder2.add_validity_bitmap(validity(&[true, false, true])); - builder2.add_offsets(offsets_64(&[0, 2, 4, 6])); + builder2.add_offsets(offsets_64(&[0, 0, 3]), Some(validity(&[false, true]))); + builder2.add_offsets( + offsets_64(&[0, 2, 2, 6]), + Some(validity(&[true, false, true])), + ); builder2.add_validity_bitmap(validity(&[false, false, false, true, true, false])); let repdefs = RepDefBuilder::serialize(vec![builder1, builder2]); + let rep = repdefs.repetition_levels.unwrap(); let def = repdefs.definition_levels.unwrap(); - assert_eq!(vec![2, 1, 0, 2, 0, 2, 0, 1, 0], rep); - assert_eq!(vec![0, 0, 0, 3, 3, 2, 2, 0, 1], def); + assert_eq!([2, 1, 0, 2, 2, 0, 1, 1, 0, 0, 0], *rep); + assert_eq!([0, 0, 0, 3, 1, 1, 2, 1, 0, 0, 1], *def); + } + + #[test] + fn test_slicer() { + let mut builder = RepDefBuilder::default(); + builder.add_offsets( + offsets_64(&[0, 2, 2, 30, 30]), + Some(validity(&[true, false, true, true])), + ); + builder.add_no_null(30); + + let repdefs = RepDefBuilder::serialize(vec![builder]); + + let mut rep_slicer = repdefs.rep_slicer().unwrap(); + + // First 5 items include a null list so we get 6 levels (12 bytes) + assert_eq!(rep_slicer.slice_next(5).len(), 12); + // Next 20 are all plain + assert_eq!(rep_slicer.slice_next(20).len(), 40); + // Last 5 include an empty list so we get 6 levels (12 bytes) + assert_eq!(rep_slicer.slice_rest().len(), 12); + + let mut def_slicer = repdefs.rep_slicer().unwrap(); + + // First 5 items include a null list so we get 6 levels (12 bytes) + assert_eq!(def_slicer.slice_next(5).len(), 12); + // Next 20 are all plain + assert_eq!(def_slicer.slice_next(20).len(), 40); + // Last 5 include an empty list so we get 6 levels (12 bytes) + assert_eq!(def_slicer.slice_rest().len(), 12); } #[test] fn test_control_words() { // Convert to control words, verify expected, convert back, verify same as original fn check( - rep: Vec, - def: Vec, + rep: &[u16], + def: &[u16], expected_values: Vec, expected_bytes_per_word: usize, expected_bits_rep: u8, @@ -1140,18 +2798,17 @@ mod tests { let max_rep = rep.iter().max().copied().unwrap_or(0); let max_def = def.iter().max().copied().unwrap_or(0); - let in_rep = if rep.is_empty() { - None - } else { - Some(rep.clone()) - }; - let in_def = if def.is_empty() { - None - } else { - Some(def.clone()) - }; - - let mut iter = super::build_control_word_iterator(in_rep, max_rep, in_def, max_def); + let in_rep = if rep.is_empty() { None } else { Some(rep) }; + let in_def = if def.is_empty() { None } else { Some(def) }; + + let mut iter = super::build_control_word_iterator( + in_rep, + max_rep, + in_def, + max_def, + max_def + 1, + expected_values.len(), + ); assert_eq!(iter.bytes_per_word(), expected_bytes_per_word); assert_eq!(iter.bits_rep(), expected_bits_rep); assert_eq!(iter.bits_def(), expected_bits_def); @@ -1160,6 +2817,7 @@ mod tests { for _ in 0..num_vals { iter.append_next(&mut cw_vec); } + assert!(iter.append_next(&mut cw_vec).is_none()); assert_eq!(expected_values, cw_vec); @@ -1174,13 +2832,13 @@ mod tests { } } - assert_eq!(rep, rep_out); - assert_eq!(def, def_out); + assert_eq!(rep, rep_out.as_slice()); + assert_eq!(def, def_out.as_slice()); } // Each will need 4 bits and so we should get 1-byte control words - let rep = vec![0_u16, 7, 3, 2, 9, 8, 12, 5]; - let def = vec![5_u16, 3, 1, 2, 12, 15, 0, 2]; + let rep = &[0_u16, 7, 3, 2, 9, 8, 12, 5]; + let def = &[5_u16, 3, 1, 2, 12, 15, 0, 2]; let expected = vec![ 0b00000101, // 0, 5 0b01110011, // 7, 3 @@ -1194,8 +2852,8 @@ mod tests { check(rep, def, expected, 1, 4, 4); // Now we need 5 bits for def so we get 2-byte control words - let rep = vec![0_u16, 7, 3, 2, 9, 8, 12, 5]; - let def = vec![5_u16, 3, 1, 2, 12, 22, 0, 2]; + let rep = &[0_u16, 7, 3, 2, 9, 8, 12, 5]; + let def = &[5_u16, 3, 1, 2, 12, 22, 0, 2]; let expected = vec![ 0b00000101, 0b00000000, // 0, 5 0b11100011, 0b00000000, // 7, 3 @@ -1209,7 +2867,7 @@ mod tests { check(rep, def, expected, 2, 4, 5); // Just rep, 4 bits so 1 byte each - let levels = vec![0_u16, 7, 3, 2, 9, 8, 12, 5]; + let levels = &[0_u16, 7, 3, 2, 9, 8, 12, 5]; let expected = vec![ 0b00000000, // 0 0b00000111, // 7 @@ -1220,12 +2878,92 @@ mod tests { 0b00001100, // 12 0b00000101, // 5 ]; - check(levels.clone(), Vec::default(), expected.clone(), 1, 4, 0); + check(levels, &[], expected.clone(), 1, 4, 0); // Just def - check(Vec::default(), levels, expected, 1, 0, 4); + check(&[], levels, expected, 1, 0, 4); // No rep, no def, no bytes - check(Vec::default(), Vec::default(), Vec::default(), 0, 0, 0); + check(&[], &[], Vec::default(), 0, 0, 0); + } + + #[test] + fn test_control_words_rep_index() { + fn check( + rep: &[u16], + def: &[u16], + expected_new_rows: Vec, + expected_is_visible: Vec, + ) { + let num_vals = rep.len().max(def.len()); + let max_rep = rep.iter().max().copied().unwrap_or(0); + let max_def = def.iter().max().copied().unwrap_or(0); + + let in_rep = if rep.is_empty() { None } else { Some(rep) }; + let in_def = if def.is_empty() { None } else { Some(def) }; + + let mut iter = super::build_control_word_iterator( + in_rep, + max_rep, + in_def, + max_def, + /*max_visible_def=*/ 2, + expected_new_rows.len(), + ); + + let mut cw_vec = Vec::with_capacity(num_vals * iter.bytes_per_word()); + let mut expected_new_rows = expected_new_rows.iter().copied(); + let mut expected_is_visible = expected_is_visible.iter().copied(); + for _ in 0..expected_new_rows.len() { + let word_desc = iter.append_next(&mut cw_vec).unwrap(); + assert_eq!(word_desc.is_new_row, expected_new_rows.next().unwrap()); + assert_eq!(word_desc.is_visible, expected_is_visible.next().unwrap()); + } + assert!(iter.append_next(&mut cw_vec).is_none()); + } + + // 2 means new list + let rep = &[2_u16, 1, 0, 2, 2, 0, 1, 1, 0, 2, 0]; + // These values don't matter for this test + let def = &[0_u16, 0, 0, 3, 1, 1, 2, 1, 0, 0, 1]; + + // Rep & def + check( + rep, + def, + vec![ + true, false, false, true, true, false, false, false, false, true, false, + ], + vec![ + true, true, true, false, true, true, true, true, true, true, true, + ], + ); + // Rep only + check( + rep, + &[], + vec![ + true, false, false, true, true, false, false, false, false, true, false, + ], + vec![true; 11], + ); + // No repetition + check( + &[], + def, + vec![ + true, true, true, true, true, true, true, true, true, true, true, + ], + vec![true; 11], + ); + // No repetition, no definition + check( + &[], + &[], + vec![ + true, true, true, true, true, true, true, true, true, true, true, + ], + vec![true; 11], + ); } } diff --git a/rust/lance-encoding/src/statistics.rs b/rust/lance-encoding/src/statistics.rs index 21116a52db4..675a3cbfd95 100644 --- a/rust/lance-encoding/src/statistics.rs +++ b/rust/lance-encoding/src/statistics.rs @@ -2,18 +2,19 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::{ - fmt, + fmt::{self}, hash::{Hash, RandomState}, sync::Arc, }; -use arrow_array::{Array, UInt64Array}; +use arrow::{array::AsArray, datatypes::UInt64Type}; +use arrow_array::{Array, ArrowPrimitiveType, UInt64Array}; use hyperloglogplus::{HyperLogLog, HyperLogLogPlus}; use num_traits::PrimInt; use crate::data::{ - AllNullDataBlock, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, OpaqueBlock, - StructDataBlock, VariableWidthBlock, + AllNullDataBlock, DataBlock, DictionaryDataBlock, FixedSizeListBlock, FixedWidthDataBlock, + NullableDataBlock, OpaqueBlock, StructDataBlock, VariableWidthBlock, }; #[derive(Clone, Copy, PartialEq, Eq, Hash)] @@ -57,10 +58,10 @@ impl ComputeStat for DataBlock { Self::AllNull(_) => {} Self::Nullable(data_block) => data_block.data.compute_stat(), Self::FixedWidth(data_block) => data_block.compute_stat(), - Self::FixedSizeList(_) => {} + Self::FixedSizeList(data_block) => data_block.compute_stat(), Self::VariableWidth(data_block) => data_block.compute_stat(), Self::Opaque(data_block) => data_block.compute_stat(), - Self::Struct(_) => {} + Self::Struct(data_block) => data_block.compute_stat(), Self::Dictionary(_) => {} } } @@ -111,8 +112,19 @@ impl ComputeStat for FixedWidthDataBlock { if let Some(cardinality_array) = cardidinality_array { info.insert(Stat::Cardinality, cardinality_array); } + } +} - // TODO(broccoliSpicy): We also need to consider FixedSizeList here +impl ComputeStat for FixedSizeListBlock { + fn compute_stat(&mut self) { + // We leave the child stats unchanged. This may seem odd (e.g. should bit width be the + // bit width of the child * dimension?) but it's because we use these stats to determine + // compression and we are currently just compressing the child data. + // + // There is a potential opportunity here to do better. For example, if we have a FSL of + // 4 32-bit integers then we should probably treat them as a single 128-bit integer or maybe + // even 4 columns of 32-bit integers. This might yield better compression. + self.child.compute_stat(); } } @@ -126,8 +138,25 @@ impl ComputeStat for OpaqueBlock { } } -pub trait GetStat { +pub trait GetStat: fmt::Debug { fn get_stat(&self, stat: Stat) -> Option>; + + fn expect_stat(&self, stat: Stat) -> Arc { + self.get_stat(stat) + .unwrap_or_else(|| panic!("{:?} DataBlock does not have `{}` statistics.", self, stat)) + } + + fn expect_single_stat(&self, stat: Stat) -> T::Native { + let stat_value = self.expect_stat(stat); + let stat_value = stat_value.as_primitive::(); + if stat_value.len() != 1 { + panic!( + "{:?} DataBlock does not have exactly one value for `{} statistics.", + self, stat + ); + } + stat_value.value(0) + } } impl GetStat for DataBlock { @@ -136,9 +165,9 @@ impl GetStat for DataBlock { Self::Empty() => None, Self::Constant(_) => None, Self::AllNull(data_block) => data_block.get_stat(stat), - Self::Nullable(data_block) => data_block.data.get_stat(stat), + Self::Nullable(data_block) => data_block.get_stat(stat), Self::FixedWidth(data_block) => data_block.get_stat(stat), - Self::FixedSizeList(_) => None, + Self::FixedSizeList(data_block) => data_block.get_stat(stat), Self::VariableWidth(data_block) => data_block.get_stat(stat), Self::Opaque(data_block) => data_block.get_stat(stat), Self::Struct(data_block) => data_block.get_stat(stat), @@ -147,17 +176,37 @@ impl GetStat for DataBlock { } } +// NullableDataBlock will be deprecated in Lance 2.1. +impl GetStat for NullableDataBlock { + // This function simply returns the statistics of the inner `DataBlock` of `NullableDataBlock`, + // this is not accurate but `NullableDataBlock` is going to be deprecated in Lance 2.1 anyway. + fn get_stat(&self, stat: Stat) -> Option> { + self.data.get_stat(stat) + } +} + impl GetStat for VariableWidthBlock { fn get_stat(&self, stat: Stat) -> Option> { + let block_info = self.block_info.0.read().unwrap(); + + if block_info.is_empty() { + panic!("get_stat should be called after statistics are computed."); + } + block_info.get(&stat).cloned() + } +} + +impl GetStat for FixedSizeListBlock { + fn get_stat(&self, stat: Stat) -> Option> { + let child_stat = self.child.get_stat(stat); match stat { - Stat::BitWidth => None, - Stat::NullCount => None, - _ => { - if self.block_info.0.read().unwrap().is_empty() { - panic!("get_stat should be called after statistics are computed"); - } - self.block_info.0.read().unwrap().get(&stat).cloned() - } + Stat::MaxLength => child_stat.map(|max_length| { + // this is conservative when working with variable length data as we shouldn't assume + // that we have a list of all max-length elements but it's cheap and easy to calculate + let max_length = max_length.as_primitive::().value(0); + Arc::new(UInt64Array::from(vec![max_length * self.dimension])) as Arc + }), + _ => child_stat, } } } @@ -248,15 +297,12 @@ impl GetStat for AllNullDataBlock { impl GetStat for FixedWidthDataBlock { fn get_stat(&self, stat: Stat) -> Option> { - match stat { - Stat::NullCount => None, - _ => { - if self.block_info.0.read().unwrap().is_empty() { - panic!("get_stat should be called after statistics are computed"); - } - self.block_info.0.read().unwrap().get(&stat).cloned() - } + let block_info = self.block_info.0.read().unwrap(); + + if block_info.is_empty() { + panic!("get_stat should be called after statistics are computed."); } + block_info.get(&stat).cloned() } } @@ -335,10 +381,12 @@ impl FixedWidthDataBlock { impl GetStat for OpaqueBlock { fn get_stat(&self, stat: Stat) -> Option> { - match stat { - Stat::DataSize => self.block_info.0.read().unwrap().get(&stat).cloned(), - _ => None, + let block_info = self.block_info.0.read().unwrap(); + + if block_info.is_empty() { + panic!("get_stat should be called after statistics are computed."); } + block_info.get(&stat).cloned() } } @@ -349,8 +397,30 @@ impl GetStat for DictionaryDataBlock { } impl GetStat for StructDataBlock { - fn get_stat(&self, _stat: Stat) -> Option> { - None + fn get_stat(&self, stat: Stat) -> Option> { + let block_info = self.block_info.0.read().unwrap(); + if block_info.is_empty() { + panic!("get_stat should be called after statistics are computed.") + } + block_info.get(&stat).cloned() + } +} + +impl ComputeStat for StructDataBlock { + fn compute_stat(&mut self) { + let data_size = self.data_size(); + let data_size_array = Arc::new(UInt64Array::from(vec![data_size])); + + let max_len = self + .children + .iter() + .map(|child| child.expect_single_stat::(Stat::MaxLength)) + .sum::(); + let max_len_array = Arc::new(UInt64Array::from(vec![max_len])); + + let mut info = self.block_info.0.write().unwrap(); + info.insert(Stat::DataSize, data_size_array); + info.insert(Stat::MaxLength, max_len_array); } } @@ -371,7 +441,11 @@ mod tests { use super::DataBlock; - use arrow::{compute::concat, datatypes::Int32Type}; + use arrow::{ + array::AsArray, + compute::concat, + datatypes::{Int32Type, UInt64Type}, + }; use arrow_array::Array; #[test] fn test_data_size_stat() { @@ -389,18 +463,7 @@ mod tests { ]) .unwrap(); - let data_size_array = block.get_stat(Stat::DataSize).unwrap_or_else(|| { - panic!( - "A data block of type: {} should have valid {} statistics", - block.name(), - Stat::DataSize - ) - }); - let data_size = data_size_array - .as_any() - .downcast_ref::() - .unwrap() - .value(0); + let data_size = block.expect_single_stat::(Stat::DataSize); let total_buffer_size: usize = concatenated_array .to_data() @@ -414,19 +477,8 @@ mod tests { let mut gen = lance_datagen::array::rand_type(&DataType::Binary); let arr = gen.generate(RowCount::from(3), &mut rng).unwrap(); let block = DataBlock::from_array(arr.clone()); - let data_size_array = block.get_stat(Stat::DataSize).unwrap_or_else(|| { - panic!( - "A data block of type: {} should have valid {} statistics", - block.name(), - Stat::DataSize - ) - }); - - let data_size = data_size_array - .as_any() - .downcast_ref::() - .unwrap() - .value(0); + let data_size = block.expect_single_stat::(Stat::DataSize); + let total_buffer_size: usize = arr .to_data() .buffers() @@ -439,22 +491,25 @@ mod tests { let fields = vec![ Arc::new(Field::new("int_field", DataType::Int32, false)), Arc::new(Field::new("float_field", DataType::Float32, false)), - Arc::new(Field::new( - "fsl_field", - DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 5), - false, - )), ] .into(); let mut gen = lance_datagen::array::rand_type(&DataType::Struct(fields)); let arr = gen.generate(RowCount::from(3), &mut rng).unwrap(); let block = DataBlock::from_array(arr.clone()); - assert!( - block.get_stat(Stat::DataSize).is_none(), - "Expected Stat::DataSize to be None for data block of type: {}", - block.name() - ); + let (_, arr_parts, _) = arr.as_struct().clone().into_parts(); + let total_buffer_size: usize = arr_parts + .iter() + .map(|arr| { + arr.to_data() + .buffers() + .iter() + .map(|buffer| buffer.len()) + .sum::() + }) + .sum(); + let data_size = block.expect_single_stat::(Stat::DataSize); + assert!(data_size == total_buffer_size as u64); // test DataType::Dictionary let mut gen = array::rand_type(&DataType::Dictionary( @@ -463,635 +518,344 @@ mod tests { )); let arr = gen.generate(RowCount::from(3), &mut rng).unwrap(); let block = DataBlock::from_array(arr.clone()); - assert!( - block.get_stat(Stat::DataSize).is_none(), - "Expected Stat::DataSize to be None for data block of type: {}", - block.name() - ); + assert!(block.get_stat(Stat::DataSize).is_none()); let mut gen = array::rand::().with_nulls(&[false, true, false]); let arr = gen.generate(RowCount::from(3), &mut rng).unwrap(); let block = DataBlock::from_array(arr.clone()); - let data_size_array = block.get_stat(Stat::DataSize).unwrap_or_else(|| { - panic!( - "A data block of type: {} should have valid {} statistics", - block.name(), - Stat::DataSize - ) - }); - let data_size = data_size_array - .as_any() - .downcast_ref::() - .unwrap() - .value(0); + let data_size = block.expect_single_stat::(Stat::DataSize); let total_buffer_size: usize = arr .to_data() .buffers() .iter() .map(|buffer| buffer.len()) .sum(); + assert!(data_size == total_buffer_size as u64); } #[test] fn test_bit_width_stat_for_integers() { let int8_array = Int8Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int8_array.clone()); + let array_ref: ArrayRef = Arc::new(int8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); + let actual_bit_width = block.expect_stat(Stat::BitWidth); - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int8_array - ); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref(),); let int8_array = Int8Array::from(vec![0x1, 0x2, 0x3, 0x7F]); - let array_ref: ArrayRef = Arc::new(int8_array.clone()); + let array_ref: ArrayRef = Arc::new(int8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![7])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref(),); let int8_array = Int8Array::from(vec![0x1, 0x2, 0x3, 0xF, 0x1F]); - let array_ref: ArrayRef = Arc::new(int8_array.clone()); + let array_ref: ArrayRef = Arc::new(int8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![5])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref(),); let int8_array = Int8Array::from(vec![-1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int8_array.clone()); + let array_ref: ArrayRef = Arc::new(int8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int16_array = Int16Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int16_array = Int16Array::from(vec![0x1, 0x2, 0x3, 0x7F]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![7])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int16_array = Int16Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int16_array = Int16Array::from(vec![0x1, 0x2, 0x3, 0x1FF]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); + let int16_array = Int16Array::from(vec![0x1, 0x2, 0x3, 0xF, 0x1F]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![5])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int16_array = Int16Array::from(vec![-1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int16_array.clone()); + let array_ref: ArrayRef = Arc::new(int16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![16])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int32_array = Int32Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int32_array.clone()); + let array_ref: ArrayRef = Arc::new(int32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int32_array = Int32Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(int32_array.clone()); + let array_ref: ArrayRef = Arc::new(int32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int32_array = Int32Array::from(vec![0x1, 0x2, 0x3, 0xFF, 0x1FF]); - let array_ref: ArrayRef = Arc::new(int32_array.clone()); + let array_ref: ArrayRef = Arc::new(int32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int32_array = Int32Array::from(vec![-1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int32_array.clone()); + let array_ref: ArrayRef = Arc::new(int32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![32])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int32_array = Int32Array::from(vec![-1, 2, 3, -88]); - let array_ref: ArrayRef = Arc::new(int32_array.clone()); + let array_ref: ArrayRef = Arc::new(int32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![32])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int64_array = Int64Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int64_array.clone()); + let array_ref: ArrayRef = Arc::new(int64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int64_array = Int64Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(int64_array.clone()); + let array_ref: ArrayRef = Arc::new(int64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int64_array = Int64Array::from(vec![0x1, 0x2, 0x3, 0xFF, 0x1FF]); - let array_ref: ArrayRef = Arc::new(int64_array.clone()); + let array_ref: ArrayRef = Arc::new(int64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int64_array = Int64Array::from(vec![-1, 2, 3]); - let array_ref: ArrayRef = Arc::new(int64_array.clone()); + let array_ref: ArrayRef = Arc::new(int64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![64])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let int64_array = Int64Array::from(vec![-1, 2, 3, -88]); - let array_ref: ArrayRef = Arc::new(int64_array.clone()); + let array_ref: ArrayRef = Arc::new(int64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![64])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - int64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint8_array = UInt8Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(uint8_array.clone()); + let array_ref: ArrayRef = Arc::new(uint8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint8_array = UInt8Array::from(vec![0x1, 0x2, 0x3, 0x7F]); - let array_ref: ArrayRef = Arc::new(uint8_array.clone()); + let array_ref: ArrayRef = Arc::new(uint8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![7])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint8_array = UInt8Array::from(vec![0x1, 0x2, 0x3, 0xF, 0x1F]); - let array_ref: ArrayRef = Arc::new(uint8_array.clone()); + let array_ref: ArrayRef = Arc::new(uint8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![5])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint8_array = UInt8Array::from(vec![1, 2, 3, 0xF]); - let array_ref: ArrayRef = Arc::new(uint8_array.clone()); + let array_ref: ArrayRef = Arc::new(uint8_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![4])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint8_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint16_array = UInt16Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint16_array = UInt16Array::from(vec![0x1, 0x2, 0x3, 0x7F]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![7])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint16_array = UInt16Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint16_array = UInt16Array::from(vec![0x1, 0x2, 0x3, 0x1FF]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); + let uint16_array = UInt16Array::from(vec![0x1, 0x2, 0x3, 0xF, 0x1F]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![5])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint16_array = UInt16Array::from(vec![1, 2, 3, 0xFFFF]); - let array_ref: ArrayRef = Arc::new(uint16_array.clone()); + let array_ref: ArrayRef = Arc::new(uint16_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![16])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint16_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint32_array = UInt32Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(uint32_array.clone()); + let array_ref: ArrayRef = Arc::new(uint32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint32_array = UInt32Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(uint32_array.clone()); + let array_ref: ArrayRef = Arc::new(uint32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref(),); let uint32_array = UInt32Array::from(vec![0x1, 0x2, 0x3, 0xFF, 0x1FF]); - let array_ref: ArrayRef = Arc::new(uint32_array.clone()); + let array_ref: ArrayRef = Arc::new(uint32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint32_array = UInt32Array::from(vec![1, 2, 3, 0xF]); - let array_ref: ArrayRef = Arc::new(uint32_array.clone()); + let array_ref: ArrayRef = Arc::new(uint32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![4])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint32_array = UInt32Array::from(vec![1, 2, 3, 0x77]); - let array_ref: ArrayRef = Arc::new(uint32_array.clone()); + let array_ref: ArrayRef = Arc::new(uint32_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![7])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint32_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint64_array = UInt64Array::from(vec![1, 2, 3]); - let array_ref: ArrayRef = Arc::new(uint64_array.clone()); + let array_ref: ArrayRef = Arc::new(uint64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint64_array = UInt64Array::from(vec![0x1, 0x2, 0x3, 0xFF]); - let array_ref: ArrayRef = Arc::new(uint64_array.clone()); + let array_ref: ArrayRef = Arc::new(uint64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![8])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint64_array = UInt64Array::from(vec![0x1, 0x2, 0x3, 0xFF, 0x1FF]); - let array_ref: ArrayRef = Arc::new(uint64_array.clone()); + let array_ref: ArrayRef = Arc::new(uint64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![9])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint64_array = UInt64Array::from(vec![0, 2, 3, 0xFFFF]); - let array_ref: ArrayRef = Arc::new(uint64_array.clone()); + let array_ref: ArrayRef = Arc::new(uint64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![16])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); let uint64_array = UInt64Array::from(vec![1, 2, 3, 0xFFFF_FFFF_FFFF_FFFF]); - let array_ref: ArrayRef = Arc::new(uint64_array.clone()); + let array_ref: ArrayRef = Arc::new(uint64_array); let block = DataBlock::from_array(array_ref); let expected_bit_width = Arc::new(UInt64Array::from(vec![64])) as ArrayRef; - let actual_bit_width = block.get_stat(Stat::BitWidth); - - assert_eq!( - actual_bit_width, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - uint64_array - ); + let actual_bit_width = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_width.as_ref(), expected_bit_width.as_ref()); } #[test] @@ -1119,14 +883,8 @@ mod tests { 4, (data_type.byte_width() * 8) as u64, ])) as ArrayRef; - let actual_bit_widths = block.get_stat(Stat::BitWidth); - assert_eq!( - actual_bit_widths, - Some(expected_bit_width.clone()), - "Expected Stat::BitWidth to be {:?} for data block generated from array: {:?}", - expected_bit_width, - concatenated - ); + let actual_bit_widths = block.expect_stat(Stat::BitWidth); + assert_eq!(actual_bit_widths.as_ref(), expected_bit_width.as_ref(),); } } @@ -1136,121 +894,72 @@ mod tests { let mut gen = lance_datagen::array::rand_type(&DataType::Binary); let arr = gen.generate(RowCount::from(3), &mut rng).unwrap(); let block = DataBlock::from_array(arr.clone()); - assert_eq!( - block.get_stat(Stat::BitWidth), - None, - "Expected Stat::BitWidth to be None for data block: {:?}", - block.name() - ); + assert!(block.get_stat(Stat::BitWidth).is_none(),); } #[test] fn test_cardinality_variable_width_datablock() { let string_array = StringArray::from(vec![Some("hello"), Some("world")]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); - - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 2; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); + assert_eq!(actual_cardinality, expected_cardinality,); let string_array = StringArray::from(vec![ Some("to be named by variables"), Some("to be passed as arguments to procedures"), Some("to be returned as values of procedures"), ]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![3])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 3; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + assert_eq!(actual_cardinality, expected_cardinality,); let string_array = StringArray::from(vec![ Some("Samuel Eilenberg"), Some("Saunders Mac Lane"), Some("Samuel Eilenberg"), ]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); - - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 2; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); + assert_eq!(actual_cardinality, expected_cardinality,); let string_array = LargeStringArray::from(vec![Some("hello"), Some("world")]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); - - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 2; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); + assert_eq!(actual_cardinality, expected_cardinality,); let string_array = LargeStringArray::from(vec![ Some("to be named by variables"), Some("to be passed as arguments to procedures"), Some("to be returned as values of procedures"), ]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![3])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); - - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 3; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); + assert_eq!(actual_cardinality, expected_cardinality,); let string_array = LargeStringArray::from(vec![ Some("Samuel Eilenberg"), Some("Saunders Mac Lane"), Some("Samuel Eilenberg"), ]); - let block = DataBlock::from_array(string_array.clone()); - let expected_cardinality = Arc::new(UInt64Array::from(vec![2])) as ArrayRef; - let actual_cardinality = block.get_stat(Stat::Cardinality); - - assert_eq!( - actual_cardinality, - Some(expected_cardinality.clone()), - "Expected Stat::Cardinality to be {:?} for data block generated from array: {:?}", - expected_cardinality, - string_array, - ); + let block = DataBlock::from_array(string_array); + let expected_cardinality = 2; + let actual_cardinality = block.expect_single_stat::(Stat::Cardinality); + assert_eq!(actual_cardinality, expected_cardinality,); } #[test] fn test_max_length_variable_width_datablock() { let string_array = StringArray::from(vec![Some("hello"), Some("world")]); let block = DataBlock::from_array(string_array.clone()); - - let expected_max_length = - Arc::new(UInt64Array::from(vec![string_array.value_length(0) as u64])) as ArrayRef; - let actual_max_length = block.get_stat(Stat::MaxLength); - - assert_eq!(actual_max_length, Some(expected_max_length.clone()),); + let expected_max_length = string_array.value_length(0) as u64; + let actual_max_length = block.expect_single_stat::(Stat::MaxLength); + assert_eq!(actual_max_length, expected_max_length); let string_array = StringArray::from(vec![ Some("to be named by variables"), @@ -1258,12 +967,9 @@ mod tests { Some("to be returned as values of procedures"), ]); let block = DataBlock::from_array(string_array.clone()); - - let expected_max_length = - Arc::new(UInt64Array::from(vec![string_array.value_length(1) as u64])) as ArrayRef; - let actual_max_length = block.get_stat(Stat::MaxLength); - - assert_eq!(actual_max_length, Some(expected_max_length)); + let expected_max_length = string_array.value_length(1) as u64; + let actual_max_length = block.expect_single_stat::(Stat::MaxLength); + assert_eq!(actual_max_length, expected_max_length); let string_array = StringArray::from(vec![ Some("Samuel Eilenberg"), @@ -1271,21 +977,15 @@ mod tests { Some("Samuel Eilenberg"), ]); let block = DataBlock::from_array(string_array.clone()); - - let expected_max_length = - Arc::new(UInt64Array::from(vec![string_array.value_length(1) as u64])) as ArrayRef; - let actual_max_length = block.get_stat(Stat::MaxLength); - - assert_eq!(actual_max_length, Some(expected_max_length),); + let expected_max_length = string_array.value_length(1) as u64; + let actual_max_length = block.expect_single_stat::(Stat::MaxLength); + assert_eq!(actual_max_length, expected_max_length); let string_array = LargeStringArray::from(vec![Some("hello"), Some("world")]); let block = DataBlock::from_array(string_array.clone()); - - let expected_max_length = - Arc::new(UInt64Array::from(vec![string_array.value(0).len() as u64])) as ArrayRef; - let actual_max_length = block.get_stat(Stat::MaxLength); - - assert_eq!(actual_max_length, Some(expected_max_length),); + let expected_max_length = string_array.value_length(1) as u64; + let actual_max_length = block.expect_single_stat::(Stat::MaxLength); + assert_eq!(actual_max_length, expected_max_length); let string_array = LargeStringArray::from(vec![ Some("to be named by variables"), @@ -1293,11 +993,9 @@ mod tests { Some("to be returned as values of procedures"), ]); let block = DataBlock::from_array(string_array.clone()); + let expected_max_length = string_array.value(1).len() as u64; + let actual_max_length = block.expect_single_stat::(Stat::MaxLength); - let expected_max_length = - Arc::new(UInt64Array::from(vec![string_array.value_length(1) as u64])) as ArrayRef; - let actual_max_length = block.get_stat(Stat::MaxLength); - - assert_eq!(actual_max_length, Some(expected_max_length)); + assert_eq!(actual_max_length, expected_max_length); } } diff --git a/rust/lance-encoding/src/testing.rs b/rust/lance-encoding/src/testing.rs index 804cab93e7e..af6258edc96 100644 --- a/rust/lance-encoding/src/testing.rs +++ b/rust/lance-encoding/src/testing.rs @@ -4,7 +4,7 @@ use std::{cmp::Ordering, collections::HashMap, ops::Range, sync::Arc}; use arrow::array::make_comparator; -use arrow_array::{Array, UInt64Array}; +use arrow_array::{Array, StructArray, UInt64Array}; use arrow_schema::{DataType, Field, FieldRef, Schema, SortOptions}; use arrow_select::concat::concat; use bytes::{Bytes, BytesMut}; @@ -208,6 +208,14 @@ async fn test_decode( let expected_size = (batch_size as usize).min(expected.len() - offset); let expected = expected.slice(offset, expected_size); assert_eq!(expected.data_type(), actual.data_type()); + if expected.len() != actual.len() { + panic!( + "Mismatch in length (at offset={}) expected {} but got {}", + offset, + expected.len(), + actual.len() + ); + } if &expected != actual { if let Ok(comparator) = make_comparator(&expected, &actual, SortOptions::default()) { @@ -216,8 +224,9 @@ async fn test_decode( for i in 0..expected.len() { if !matches!(comparator(i, i), Ordering::Equal) { panic!( - "Mismatch at index {} expected {:?} but got {:?} first mismatch is expected {:?} but got {:?}", + "Mismatch at index {} (offset={}) expected {:?} but got {:?} first mismatch is expected {:?} but got {:?}", i, + offset, expected, actual, expected.slice(i, 1), @@ -446,16 +455,14 @@ impl SimulatedWriter { self.encoded_data.extend_from_slice(&buffer); let size = self.encoded_data.len() as u64 - offset; let pad_bytes = pad_bytes::(self.encoded_data.len()); - self.encoded_data - .extend(std::iter::repeat(0).take(pad_bytes)); + self.encoded_data.extend(std::iter::repeat_n(0, pad_bytes)); (offset, size) } fn write_lance_buffer(&mut self, buffer: LanceBuffer) { self.encoded_data.extend_from_slice(&buffer); let pad_bytes = pad_bytes::(self.encoded_data.len()); - self.encoded_data - .extend(std::iter::repeat(0).take(pad_bytes)); + self.encoded_data.extend(std::iter::repeat_n(0, pad_bytes)); } fn write_page(&mut self, encoded_page: EncodedPage) { @@ -464,7 +471,11 @@ impl SimulatedWriter { let page_encoding = encoded_page.description; let buffer_offsets_and_sizes = page_buffers .into_iter() - .map(|b| self.write_buffer(b)) + .map(|b| { + let (offset, size) = self.write_buffer(b); + trace!("Encoded buffer offset={} size={}", offset, size); + (offset, size) + }) .collect::>(); let page_info = PageInfo { @@ -496,8 +507,15 @@ async fn check_round_trip_encoding_inner( for arr in &data { let mut external_buffers = writer.new_external_buffers(); let repdef = RepDefBuilder::default(); + let num_rows = arr.len() as u64; let encode_tasks = encoder - .maybe_encode(arr.clone(), &mut external_buffers, repdef, row_number) + .maybe_encode( + arr.clone(), + &mut external_buffers, + repdef, + row_number, + num_rows, + ) .unwrap(); for buffer in external_buffers.take_buffers() { writer.write_lance_buffer(buffer); @@ -631,9 +649,23 @@ async fn check_round_trip_encoding_inner( } let num_rows = indices.len() as u64; let indices_arr = UInt64Array::from(indices.clone()); - let expected = concat_data - .as_ref() - .map(|concat_data| arrow_select::take::take(&concat_data, &indices_arr, None).unwrap()); + + // There is a bug in arrow_select::take::take that causes it to return empty arrays + // if the data type is an empty struct. This is a workaround for that. + let is_empty_struct = if let DataType::Struct(fields) = field.data_type() { + fields.is_empty() + } else { + false + }; + + let expected = if is_empty_struct { + Some(Arc::new(StructArray::new_empty_fields(indices_arr.len(), None)) as Arc) + } else { + concat_data.as_ref().map(|concat_data| { + arrow_select::take::take(&concat_data, &indices_arr, None).unwrap() + }) + }; + let scheduler = scheduler.clone(); let indices = indices.clone(); test_decode( diff --git a/rust/lance-encoding/src/utils.rs b/rust/lance-encoding/src/utils.rs new file mode 100644 index 00000000000..31ced9a21f5 --- /dev/null +++ b/rust/lance-encoding/src/utils.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Miscellaneous utility functions that don't have a home elsewhere. + +pub mod bytepack; diff --git a/rust/lance-encoding/src/utils/bytepack.rs b/rust/lance-encoding/src/utils/bytepack.rs new file mode 100644 index 00000000000..1fbf17277c1 --- /dev/null +++ b/rust/lance-encoding/src/utils/bytepack.rs @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Utilities for byte (not bit) packing for situations where saving a few +//! bits is less important than simplicity and speed. + +pub struct U8BytePacker { + data: Vec, +} + +impl U8BytePacker { + fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity), + } + } + + fn append(&mut self, value: u64) { + self.data.push(value as u8); + } +} + +pub struct U16BytePacker { + data: Vec, +} + +impl U16BytePacker { + fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity * 2), + } + } + + fn append(&mut self, value: u64) { + self.data.extend_from_slice(&(value as u16).to_le_bytes()); + } +} + +pub struct U32BytePacker { + data: Vec, +} + +impl U32BytePacker { + fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity * 4), + } + } + + fn append(&mut self, value: u64) { + self.data.extend_from_slice(&(value as u32).to_le_bytes()); + } +} + +pub struct U64BytePacker { + data: Vec, +} + +impl U64BytePacker { + fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity * 8), + } + } + + fn append(&mut self, value: u64) { + self.data.extend_from_slice(&value.to_le_bytes()); + } +} + +/// A bytepacked integer encoder that automatically chooses the smallest +/// possible integer type to store the given values. +/// +/// This is byte packing (not bit packing). Not even that, we only fit things into +/// sizes of 1,2,4,8 bytes. It's simple, fast, and easy but doesn't provide the +/// maximum possible compression. +/// +/// Still, it's useful for things like offsets which are often small and fit into a +/// u16 or u32 but sometimes might need the full u64 range. +/// +/// In the future we can investigate replacing this with something more sophisticated. +pub enum BytepackedIntegerEncoder { + U8(U8BytePacker), + U16(U16BytePacker), + U32(U32BytePacker), + U64(U64BytePacker), + Zero, +} + +impl BytepackedIntegerEncoder { + /// Create a new encoder with the given capacity and maximum value. + pub fn with_capacity(capacity: usize, max_value: u64) -> Self { + if max_value == 0 { + Self::Zero + } else if max_value <= u8::MAX as u64 { + Self::U8(U8BytePacker::with_capacity(capacity)) + } else if max_value <= u16::MAX as u64 { + Self::U16(U16BytePacker::with_capacity(capacity)) + } else if max_value <= u32::MAX as u64 { + Self::U32(U32BytePacker::with_capacity(capacity)) + } else { + Self::U64(U64BytePacker::with_capacity(capacity)) + } + } + + /// Append a value to the encoder. + /// + /// # Safety + /// + /// This function is unsafe because it doesn't check for overflow. If the + /// value is too large to fit in the chosen integer type, it will be silently + /// truncated. + pub unsafe fn append(&mut self, value: u64) { + match self { + Self::U8(packer) => packer.append(value), + Self::U16(packer) => packer.append(value), + Self::U32(packer) => packer.append(value), + Self::U64(packer) => packer.append(value), + Self::Zero => {} + } + } + + /// Convert the encoder into a vector of bytes. + pub fn into_data(self) -> Vec { + match self { + Self::U8(packer) => packer.data, + Self::U16(packer) => packer.data, + Self::U32(packer) => packer.data, + Self::U64(packer) => packer.data, + Self::Zero => Vec::new(), + } + } +} + +/// An iterator that unpacks bytes into integers (currently only u64) +pub enum ByteUnpacker> { + U8(I), + U16(I), + U32(I), + U64(I), +} + +impl> ByteUnpacker { + #[allow(clippy::new_ret_no_self)] + pub fn new>(data: I, size: usize) -> impl Iterator { + match size { + 1 => Self::U8(data.into_iter()), + 2 => Self::U16(data.into_iter()), + 4 => Self::U32(data.into_iter()), + 8 => Self::U64(data.into_iter()), + _ => panic!("Invalid size"), + } + } +} + +impl> Iterator for ByteUnpacker { + type Item = u64; + + fn next(&mut self) -> Option { + match self { + Self::U8(iter) => iter.next().map(|v| v as u64), + Self::U16(iter) => { + let first_byte = iter.next()?; + Some(u16::from_le_bytes([first_byte, iter.next().unwrap()]) as u64) + } + Self::U32(iter) => { + let first_byte = iter.next()?; + Some(u32::from_le_bytes([ + first_byte, + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + ]) as u64) + } + Self::U64(iter) => { + let first_byte = iter.next()?; + Some(u64::from_le_bytes([ + first_byte, + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + iter.next().unwrap(), + ])) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bytepacked_integer_encoder() { + // Fits in u8 + let mut encoder = BytepackedIntegerEncoder::with_capacity(10, 100); + unsafe { + encoder.append(50); + encoder.append(20); + encoder.append(30); + } + let data = encoder.into_data(); + assert_eq!(data, vec![50, 20, 30]); + + assert_eq!( + ByteUnpacker::new(data, 1).collect::>(), + vec![50, 20, 30] + ); + + // Requires u16 + let mut encoder = BytepackedIntegerEncoder::with_capacity(10, 1000); + unsafe { + encoder.append(500); + encoder.append(200); + encoder.append(300); + } + let data = encoder.into_data(); + assert_eq!(data, vec![244, 1, 200, 0, 44, 1]); + + assert_eq!( + ByteUnpacker::new(data, 2).collect::>(), + vec![500, 200, 300] + ); + + // Requires u32 + let mut encoder = BytepackedIntegerEncoder::with_capacity(10, 1000000); + unsafe { + encoder.append(500000); + encoder.append(200000); + encoder.append(300000); + } + let data = encoder.into_data(); + assert_eq!(data, vec![32, 161, 7, 0, 64, 13, 3, 0, 224, 147, 4, 0]); + + assert_eq!( + ByteUnpacker::new(data, 4).collect::>(), + vec![500000, 200000, 300000] + ); + + // Requires u64 + let mut encoder = BytepackedIntegerEncoder::with_capacity(10, 0x10000000000); + unsafe { + encoder.append(0x5000000000); + encoder.append(0x2000000000); + encoder.append(0x3000000000); + } + let data = encoder.into_data(); + assert_eq!( + data, + vec![0, 0, 0, 0, 80, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 48, 0, 0, 0] + ); + + assert_eq!( + ByteUnpacker::new(data, 8).collect::>(), + vec![0x5000000000, 0x2000000000, 0x3000000000] + ); + } +} diff --git a/rust/lance-encoding/src/version.rs b/rust/lance-encoding/src/version.rs index a27cd4e15f5..289c80c4f8e 100644 --- a/rust/lance-encoding/src/version.rs +++ b/rust/lance-encoding/src/version.rs @@ -4,7 +4,7 @@ use std::str::FromStr; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; pub const LEGACY_FORMAT_VERSION: &str = "0.1"; pub const V2_FORMAT_2_0: &str = "2.0"; @@ -88,6 +88,7 @@ impl FromStr for LanceFileVersion { V2_FORMAT_2_1 => Ok(Self::V2_1), "stable" => Ok(Self::Stable), "legacy" => Ok(Self::Legacy), + "next" => Ok(Self::Next), // Version 0.3 is an alias of 2.0 "0.3" => Ok(Self::V2_0), _ => Err(Error::InvalidInput { diff --git a/rust/lance-file/Cargo.toml b/rust/lance-file/Cargo.toml index eabb950e08e..17fd79801d7 100644 --- a/rust/lance-file/Cargo.toml +++ b/rust/lance-file/Cargo.toml @@ -51,10 +51,18 @@ test-log.workspace = true [build-dependencies] prost-build.workspace = true +protobuf-src = { version = "2.1", optional = true } [target.'cfg(target_os = "linux")'.dev-dependencies] pprof = { workspace = true } +[features] +protoc = ["dep:protobuf-src"] + +[package.metadata.docs.rs] +# docs.rs uses an older version of Ubuntu that does not have the necessary protoc version +features = ["protoc"] + [[bench]] name = "reader" harness = false diff --git a/rust/lance-file/benches/reader.rs b/rust/lance-file/benches/reader.rs index 511bbd10711..cc773425c22 100644 --- a/rust/lance-file/benches/reader.rs +++ b/rust/lance-file/benches/reader.rs @@ -2,10 +2,11 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::sync::{Arc, Mutex}; -use arrow_array::{cast::AsArray, types::Int32Type}; +use arrow_array::{cast::AsArray, types::Int32Type, UInt32Array}; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use futures::{FutureExt, StreamExt}; +use lance_datagen::ArrayGeneratorExt; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::{ v2::{ @@ -19,6 +20,7 @@ use lance_io::{ object_store::ObjectStore, scheduler::{ScanScheduler, SchedulerConfig}, }; +use rand::seq::SliceRandom; fn bench_reader(c: &mut Criterion) { for version in [LanceFileVersion::V2_0, LanceFileVersion::V2_1] { @@ -31,8 +33,12 @@ fn bench_reader(c: &mut Criterion) { let tempdir = tempfile::tempdir().unwrap(); let test_path = tempdir.path(); - let (object_store, base_path) = - ObjectStore::from_path(test_path.as_os_str().to_str().unwrap()).unwrap(); + let (object_store, base_path) = rt + .block_on(ObjectStore::from_uri( + test_path.as_os_str().to_str().unwrap(), + )) + .unwrap(); + let file_path = base_path.child("foo.lance"); let object_writer = rt.block_on(object_store.create(&file_path)).unwrap(); @@ -57,7 +63,7 @@ fn bench_reader(c: &mut Criterion) { let data = &data; rt.block_on(async move { let store_scheduler = ScanScheduler::new( - Arc::new(object_store.clone()), + object_store.clone(), SchedulerConfig::default_for_testing(), ); let scheduler = store_scheduler.open_file(file_path).await.unwrap(); @@ -111,17 +117,122 @@ fn bench_reader(c: &mut Criterion) { } } +fn bench_random_access(c: &mut Criterion) { + const TAKE_SIZE: usize = 100; + for version in [LanceFileVersion::V2_0, LanceFileVersion::V2_1] { + let mut group = c.benchmark_group(format!("reader_{}", version)); + let data = lance_datagen::gen() + .anon_col(lance_datagen::array::rand_type(&DataType::Int32).with_random_nulls(0.1)) + .into_batch_rows(lance_datagen::RowCount::from(2 * 1024 * 1024)) + .unwrap(); + let rt = tokio::runtime::Runtime::new().unwrap(); + + let tempdir = tempfile::tempdir().unwrap(); + let test_path = tempdir.path(); + let (object_store, base_path) = rt + .block_on(ObjectStore::from_uri( + test_path.as_os_str().to_str().unwrap(), + )) + .unwrap(); + let file_path = base_path.child("foo.lance"); + let object_writer = rt.block_on(object_store.create(&file_path)).unwrap(); + + let mut writer = FileWriter::try_new( + object_writer, + data.schema().as_ref().try_into().unwrap(), + FileWriterOptions { + format_version: Some(version), + ..Default::default() + }, + ) + .unwrap(); + rt.block_on(writer.write_batch(&data)).unwrap(); + rt.block_on(writer.finish()).unwrap(); + + let mut indices = (0..data.num_rows() as u32).collect::>(); + indices.partial_shuffle(&mut rand::thread_rng(), TAKE_SIZE); + indices.truncate(TAKE_SIZE); + let indices: UInt32Array = indices.into(); + + let object_store = &object_store; + let file_path = &file_path; + let reader = rt.block_on(async move { + let store_scheduler = + ScanScheduler::new(object_store.clone(), SchedulerConfig::default_for_testing()); + let scheduler = store_scheduler.open_file(file_path).await.unwrap(); + Arc::new( + FileReader::try_open( + scheduler.clone(), + None, + Arc::::default(), + &test_cache(), + FileReaderOptions::default(), + ) + .await + .unwrap(), + ) + }); + + group.throughput(criterion::Throughput::Elements(TAKE_SIZE as u64)); + group.bench_function("take", |b| { + let reader = reader.clone(); + let indices = indices.clone(); + b.iter(|| { + let reader = reader.clone(); + let indices = indices.clone(); + rt.block_on(async move { + let stream = reader + .read_tasks( + lance_io::ReadBatchParams::Indices(indices), + TAKE_SIZE as u32, + None, + FilterExpression::no_filter(), + ) + .unwrap(); + let stats = Arc::new(Mutex::new((0, 0))); + let mut stream = stream + .map(|batch_task| { + let stats = stats.clone(); + async move { + let batch = batch_task.task.await.unwrap(); + let row_count = batch.num_rows(); + let sum = batch + .column(0) + .as_primitive::() + .values() + .iter() + .map(|v| *v as i64) + .sum::(); + let mut stats = stats.lock().unwrap(); + stats.0 += row_count; + stats.1 += sum; + } + .boxed() + }) + .buffer_unordered(16); + while (stream.next().await).is_some() {} + let stats = stats.lock().unwrap(); + let row_count = stats.0; + let sum = stats.1; + assert_eq!(TAKE_SIZE, row_count); + black_box(sum); + }); + }) + }); + } +} + #[cfg(target_os = "linux")] criterion_group!( name=benches; config = Criterion::default().significance_level(0.1).sample_size(10) .with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); - targets = bench_reader); + targets = bench_reader, bench_random_access); // Non-linux version does not support pprof. #[cfg(not(target_os = "linux"))] criterion_group!( name=benches; config = Criterion::default().significance_level(0.1).sample_size(10); - targets = bench_reader); + targets = bench_reader, bench_random_access); criterion_main!(benches); diff --git a/rust/lance-file/build.rs b/rust/lance-file/build.rs index dd004147ecd..05b791fac38 100644 --- a/rust/lance-file/build.rs +++ b/rust/lance-file/build.rs @@ -6,6 +6,10 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + #[cfg(feature = "protoc")] + // Use vendored protobuf compiler if requested. + std::env::set_var("PROTOC", protobuf_src::protoc()); + let mut prost_build = prost_build::Config::new(); prost_build.protoc_arg("--experimental_allow_proto3_optional"); prost_build.extern_path(".lance.encodings", "::lance_encoding::format::pb"); diff --git a/rust/lance-file/src/datatypes.rs b/rust/lance-file/src/datatypes.rs index 6560c73f7ca..b0705292d9e 100644 --- a/rust/lance-file/src/datatypes.rs +++ b/rust/lance-file/src/datatypes.rs @@ -11,7 +11,7 @@ use lance_core::datatypes::{Dictionary, Encoding, Field, LogicalType, Schema}; use lance_core::{Error, Result}; use lance_io::traits::Reader; use lance_io::utils::{read_binary_array, read_fixed_stride_array}; -use snafu::{location, Location}; +use snafu::location; use crate::format::pb; @@ -250,10 +250,7 @@ async fn load_field_dictionary<'a>(field: &mut Field, reader: &dyn Reader) -> Re /// Load dictionary value array from manifest files. // TODO: pub(crate) -pub async fn populate_schema_dictionary<'a>( - schema: &mut Schema, - reader: &dyn Reader, -) -> Result<()> { +pub async fn populate_schema_dictionary(schema: &mut Schema, reader: &dyn Reader) -> Result<()> { for field in schema.fields.as_mut_slice() { load_field_dictionary(field, reader).await?; } diff --git a/rust/lance-file/src/format/metadata.rs b/rust/lance-file/src/format/metadata.rs index d6f5a036a64..32108702392 100644 --- a/rust/lance-file/src/format/metadata.rs +++ b/rust/lance-file/src/format/metadata.rs @@ -10,7 +10,7 @@ use deepsize::DeepSizeOf; use lance_core::datatypes::Schema; use lance_core::{Error, Result}; use lance_io::traits::ProtoStruct; -use snafu::{location, Location}; +use snafu::location; /// Data File Metadata #[derive(Debug, Default, DeepSizeOf, PartialEq)] pub struct Metadata { diff --git a/rust/lance-file/src/lib.rs b/rust/lance-file/src/lib.rs index 174cbef1a18..482d78d8cf8 100644 --- a/rust/lance-file/src/lib.rs +++ b/rust/lance-file/src/lib.rs @@ -15,7 +15,7 @@ use lance_core::{Error, Result}; use lance_encoding::version::LanceFileVersion; use lance_io::object_store::ObjectStore; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; pub async fn determine_file_version( store: &ObjectStore, diff --git a/rust/lance-file/src/page_table.rs b/rust/lance-file/src/page_table.rs index 43ea2631684..65b41b62bcc 100644 --- a/rust/lance-file/src/page_table.rs +++ b/rust/lance-file/src/page_table.rs @@ -7,7 +7,7 @@ use arrow_schema::DataType; use deepsize::DeepSizeOf; use lance_io::encodings::plain::PlainDecoder; use lance_io::encodings::Decoder; -use snafu::{location, Location}; +use snafu::location; use std::collections::BTreeMap; use tokio::io::AsyncWriteExt; @@ -51,7 +51,7 @@ impl PageTable { /// Non-existent pages will be represented as (0, 0) in the page table. Pages /// can be non-existent because they are not present in the file, or because /// they are struct fields which have no data pages. - pub async fn load<'a>( + pub async fn load( reader: &dyn Reader, position: usize, min_field_id: i32, diff --git a/rust/lance-file/src/reader.rs b/rust/lance-file/src/reader.rs index 1b2e7b1a25d..aff92bd0cae 100644 --- a/rust/lance-file/src/reader.rs +++ b/rust/lance-file/src/reader.rs @@ -35,7 +35,7 @@ use lance_io::utils::{ use lance_io::{object_store::ObjectStore, ReadBatchParams}; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use crate::format::metadata::Metadata; @@ -391,7 +391,7 @@ impl FileReader { /// - **reader**: An opened file reader. /// - **projection**: The schema of the returning [RecordBatch]. /// - **predicate**: A function that takes a batch ID and returns true if the batch should be -/// returned. +/// returned. /// /// Returns: /// - A stream of [RecordBatch]s, each one corresponding to one full batch in the file. @@ -532,7 +532,7 @@ fn read_null_array( 0 } else { let idx_max = *indices.values().iter().max().unwrap() as u64; - if idx_max >= page_info.length.try_into().unwrap() { + if idx_max >= page_info.length as u64 { return Err(Error::io( format!( "NullArray Reader: request([{}]) out of range: [0..{}]", @@ -766,6 +766,7 @@ mod tests { }; use arrow_array::{BooleanArray, Int32Array}; use arrow_schema::{Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema}; + use lance_io::object_store::ObjectStoreParams; #[tokio::test] async fn test_take() { @@ -1364,8 +1365,17 @@ mod tests { #[tokio::test] async fn test_take_boolean_beyond_chunk() { - let mut store = ObjectStore::memory(); - store.set_block_size(256); + let store = ObjectStore::from_uri_and_params( + Arc::new(Default::default()), + "memory://", + &ObjectStoreParams { + block_size: Some(256), + ..Default::default() + }, + ) + .await + .unwrap() + .0; let path = Path::from("/take_bools"); let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( diff --git a/rust/lance-file/src/v2.rs b/rust/lance-file/src/v2.rs index 4223a06956d..72f93c21826 100644 --- a/rust/lance-file/src/v2.rs +++ b/rust/lance-file/src/v2.rs @@ -5,3 +5,5 @@ pub(crate) mod io; pub mod reader; pub mod testing; pub mod writer; + +pub use io::LanceEncodingsIo; diff --git a/rust/lance-file/src/v2/reader.rs b/rust/lance-file/src/v2/reader.rs index 172bc6f41cc..90964644fc2 100644 --- a/rust/lance-file/src/v2/reader.rs +++ b/rust/lance-file/src/v2/reader.rs @@ -9,6 +9,7 @@ use std::{ sync::Arc, }; +use arrow_array::RecordBatchReader; use arrow_schema::Schema as ArrowSchema; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; use bytes::{Bytes, BytesMut}; @@ -16,16 +17,18 @@ use deepsize::{Context, DeepSizeOf}; use futures::{stream::BoxStream, Stream, StreamExt}; use lance_encoding::{ decoder::{ - schedule_and_decode, ColumnInfo, DecoderPlugins, FilterExpression, PageEncoding, PageInfo, - ReadBatchTask, RequestedRows, SchedulerDecoderConfig, + schedule_and_decode, schedule_and_decode_blocking, ColumnInfo, DecoderPlugins, + FilterExpression, PageEncoding, PageInfo, ReadBatchTask, RequestedRows, + SchedulerDecoderConfig, }, encoder::EncodedBatch, version::LanceFileVersion, EncodingsIo, }; use log::debug; +use object_store::path::Path; use prost::{Message, Name}; -use snafu::{location, Location}; +use snafu::location; use lance_core::{ cache::FileMetadataCache, @@ -57,6 +60,24 @@ pub struct BufferDescriptor { pub size: u64, } +/// Statistics summarize some of the file metadata for quick summary info +#[derive(Debug)] +pub struct FileStatistics { + /// Statistics about each of the columns in the file + pub columns: Vec, +} + +/// Summary information describing a column +#[derive(Debug)] +pub struct ColumnStatistics { + /// The number of pages in the column + pub num_pages: usize, + /// The total number of data & metadata bytes in the column + /// + /// This is the compressed on-disk size + pub size_bytes: u64, +} + // TODO: Caching #[derive(Debug)] pub struct CachedFileMetadata { @@ -138,13 +159,13 @@ pub struct ReaderProjection { /// /// - Primitive: the index of the column in the schema /// - List: the index of the list column in the schema - /// followed by the column indices of the children + /// followed by the column indices of the children /// - FixedSizeList (of primitive): the index of the column in the schema - /// (this case is not nested) + /// (this case is not nested) /// - FixedSizeList (of non-primitive): not yet implemented /// - Dictionary: same as primitive /// - Struct: the index of the struct column in the schema - /// followed by the column indices of the children + /// followed by the column indices of the children /// /// In other words, this should be a DFS listing of the desired schema. /// @@ -218,23 +239,29 @@ impl ReaderProjection { /// /// If the schema provided is not the schema of the entire file then /// the projection will be invalid and the read will fail. + /// If the field is a `struct datatype` with `packed` set to true in the field metadata, + /// the whole struct has one column index. + /// To support nested `packed-struct encoding`, this method need to be further adjusted. pub fn from_whole_schema(schema: &Schema, version: LanceFileVersion) -> Self { let schema = Arc::new(schema.clone()); let is_structural = version >= LanceFileVersion::V2_1; - let mut counter = 0; - let counter = &mut counter; - let column_indices = schema - .fields_pre_order() - .filter_map(|field| { - if field.children.is_empty() || !is_structural { - let col_idx = *counter; - *counter += 1; - Some(col_idx) - } else { - None - } - }) - .collect::>(); + let mut column_indices = vec![]; + let mut curr_column_idx = 0; + let mut packed_struct_fields_num = 0; + for field in schema.fields_pre_order() { + if packed_struct_fields_num > 0 { + packed_struct_fields_num -= 1; + continue; + } + if field.is_packed_struct() { + column_indices.push(curr_column_idx); + curr_column_idx += 1; + packed_struct_fields_num = field.children.len(); + } else if field.children.is_empty() || !is_structural { + column_indices.push(curr_column_idx); + curr_column_idx += 1; + } + } Self { schema, column_indices, @@ -265,14 +292,14 @@ impl ReaderProjection { } } -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct FileReaderOptions { validate_on_decode: bool, } #[derive(Debug)] pub struct FileReader { - scheduler: Arc, + scheduler: Arc, // The default projection to be applied to all reads base_projection: ReaderProjection, num_rows: u64, @@ -299,6 +326,18 @@ struct Footer { const FOOTER_LEN: usize = 40; impl FileReader { + pub fn with_scheduler(&self, scheduler: Arc) -> Self { + Self { + scheduler, + base_projection: self.base_projection.clone(), + cache: self.cache.clone(), + decoder_plugins: self.decoder_plugins.clone(), + metadata: self.metadata.clone(), + options: self.options.clone(), + num_rows: self.num_rows, + } + } + pub fn num_rows(&self) -> u64 { self.num_rows } @@ -307,6 +346,30 @@ impl FileReader { &self.metadata } + pub fn file_statistics(&self) -> FileStatistics { + let column_metadatas = &self.metadata().column_metadatas; + + let column_stats = column_metadatas + .iter() + .map(|col_metadata| { + let num_pages = col_metadata.pages.len(); + let size_bytes = col_metadata + .pages + .iter() + .map(|page| page.buffer_sizes.iter().sum::()) + .sum::(); + ColumnStatistics { + num_pages, + size_bytes, + } + }) + .collect(); + + FileStatistics { + columns: column_stats, + } + } + pub async fn read_global_buffer(&self, index: u32) -> Result { let buffer_desc = self.metadata.file_buffers.get(index as usize).ok_or_else(||Error::invalid_input(format!("request for global buffer at index {} but there were only {} global buffers in the file", index, self.metadata.file_buffers.len()), location!()))?; self.scheduler @@ -670,15 +733,17 @@ impl FileReader { pub async fn try_open( scheduler: FileScheduler, base_projection: Option, - decoder_strategy: Arc, + decoder_plugins: Arc, cache: &FileMetadataCache, options: FileReaderOptions, ) -> Result { let file_metadata = Arc::new(Self::read_all_metadata(&scheduler).await?); + let path = scheduler.reader().path().clone(); Self::try_open_with_file_metadata( - scheduler, + Arc::new(LanceEncodingsIo(scheduler)), + path, base_projection, - decoder_strategy, + decoder_plugins, file_metadata, cache, options, @@ -687,22 +752,27 @@ impl FileReader { } /// Same as `try_open` but with the file metadata already loaded. + /// + /// This method also can accept any kind of `EncodingsIo` implementation allowing + /// for custom strategies to be used for I/O scheduling (e.g. for takes on fast + /// disks it may be better to avoid asynchronous overhead). pub async fn try_open_with_file_metadata( - scheduler: FileScheduler, + scheduler: Arc, + path: Path, base_projection: Option, decoder_plugins: Arc, file_metadata: Arc, cache: &FileMetadataCache, options: FileReaderOptions, ) -> Result { - let cache = Arc::new(cache.with_base_path(scheduler.reader().path().clone())); + let cache = Arc::new(cache.with_base_path(path)); if let Some(base_projection) = base_projection.as_ref() { Self::validate_projection(base_projection, &file_metadata)?; } let num_rows = file_metadata.num_rows; Ok(Self { - scheduler: Arc::new(LanceEncodingsIo(scheduler)), + scheduler, base_projection: base_projection.unwrap_or(ReaderProjection::from_whole_schema( file_metadata.file_schema.as_ref(), file_metadata.version(), @@ -979,6 +1049,161 @@ impl FileReader { ))) } + fn take_rows_blocking( + &self, + indices: Vec, + batch_size: u32, + projection: ReaderProjection, + filter: FilterExpression, + ) -> Result> { + let column_infos = self.collect_columns_from_projection(&projection)?; + debug!( + "Taking {} rows spread across range {}..{} with batch_size {} from columns {:?}", + indices.len(), + indices[0], + indices[indices.len() - 1], + batch_size, + column_infos.iter().map(|ci| ci.index).collect::>() + ); + + let config = SchedulerDecoderConfig { + batch_size, + cache: self.cache.clone(), + decoder_plugins: self.decoder_plugins.clone(), + io: self.scheduler.clone(), + should_validate: self.options.validate_on_decode, + }; + + let requested_rows = RequestedRows::Indices(indices); + + schedule_and_decode_blocking( + column_infos, + requested_rows, + filter, + projection.column_indices, + projection.schema, + config, + ) + } + + fn read_range_blocking( + &self, + range: Range, + batch_size: u32, + projection: ReaderProjection, + filter: FilterExpression, + ) -> Result> { + let column_infos = self.collect_columns_from_projection(&projection)?; + let num_rows = self.num_rows; + + debug!( + "Reading range {:?} with batch_size {} from file with {} rows and {} columns into schema with {} columns", + range, + batch_size, + num_rows, + column_infos.len(), + projection.schema.fields.len(), + ); + + let config = SchedulerDecoderConfig { + batch_size, + cache: self.cache.clone(), + decoder_plugins: self.decoder_plugins.clone(), + io: self.scheduler.clone(), + should_validate: self.options.validate_on_decode, + }; + + let requested_rows = RequestedRows::Ranges(vec![range]); + + schedule_and_decode_blocking( + column_infos, + requested_rows, + filter, + projection.column_indices, + projection.schema, + config, + ) + } + + /// Read data from the file as an iterator of record batches + /// + /// This is a blocking variant of [`Self::read_stream_projected`] that runs entirely in the + /// calling thread. It will block on I/O if the decode is faster than the I/O. It is useful + /// for benchmarking and potentially from "take"ing small batches from fast disks. + /// + /// Large scans of in-memory data will still benefit from threading (and should therefore not + /// use this method) because we can parallelize the decode. + /// + /// Note: calling this from within a tokio runtime will panic. It is acceptable to call this + /// from a spawn_blocking context. + pub fn read_stream_projected_blocking( + &self, + params: ReadBatchParams, + batch_size: u32, + projection: Option, + filter: FilterExpression, + ) -> Result> { + let projection = projection.unwrap_or_else(|| self.base_projection.clone()); + Self::validate_projection(&projection, &self.metadata)?; + let verify_bound = |params: &ReadBatchParams, bound: u64, inclusive: bool| { + if bound > self.num_rows || bound == self.num_rows && inclusive { + Err(Error::invalid_input( + format!( + "cannot read {:?} from file with {} rows", + params, self.num_rows + ), + location!(), + )) + } else { + Ok(()) + } + }; + match ¶ms { + ReadBatchParams::Indices(indices) => { + for idx in indices { + match idx { + None => { + return Err(Error::invalid_input( + "Null value in indices array", + location!(), + )); + } + Some(idx) => { + verify_bound(¶ms, idx as u64, true)?; + } + } + } + let indices = indices.iter().map(|idx| idx.unwrap() as u64).collect(); + self.take_rows_blocking(indices, batch_size, projection, filter) + } + ReadBatchParams::Range(range) => { + verify_bound(¶ms, range.end as u64, false)?; + self.read_range_blocking( + range.start as u64..range.end as u64, + batch_size, + projection, + filter, + ) + } + ReadBatchParams::RangeFrom(range) => { + verify_bound(¶ms, range.start as u64, true)?; + self.read_range_blocking( + range.start as u64..self.num_rows, + batch_size, + projection, + filter, + ) + } + ReadBatchParams::RangeTo(range) => { + verify_bound(¶ms, range.end as u64, false)?; + self.read_range_blocking(0..range.end as u64, batch_size, projection, filter) + } + ReadBatchParams::RangeFull => { + self.read_range_blocking(0..self.num_rows, batch_size, projection, filter) + } + } + } + /// Reads data from the file as a stream of record batches /// /// This is similar to [`Self::read_stream_projected`] but uses the base projection @@ -1030,6 +1255,16 @@ pub fn describe_encoding(page: &pbfile::column_metadata::Page) -> String { format!("Unsupported(decode_err={})", err) } } + } else if encoding_any.type_url == "/lance.encodings.PageLayout" { + let encoding = encoding_any.to_msg::(); + match encoding { + Ok(encoding) => { + format!("{:#?}", encoding) + } + Err(err) => { + format!("Unsupported(decode_err={})", err) + } + } } else { format!("Unrecognized(type_url={})", encoding_any.type_url) } @@ -1153,13 +1388,13 @@ pub mod tests { use arrow_array::{ types::{Float64Type, Int32Type}, - RecordBatch, + RecordBatch, UInt32Array, }; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; use bytes::Bytes; use futures::{prelude::stream::TryStreamExt, StreamExt}; use lance_arrow::RecordBatchExt; - use lance_core::datatypes::Schema; + use lance_core::{datatypes::Schema, ArrowResult}; use lance_datagen::{array, gen, BatchCount, ByteCount, RowCount}; use lance_encoding::{ decoder::{decode_batch, DecodeBatchScheduler, DecoderPlugins, FilterExpression}, @@ -1176,22 +1411,32 @@ pub mod tests { writer::{EncodedBatchWriteExt, FileWriter, FileWriterOptions}, }; - async fn create_some_file(fs: &FsFixture) -> WrittenFile { + async fn create_some_file(fs: &FsFixture, version: LanceFileVersion) -> WrittenFile { let location_type = DataType::Struct(Fields::from(vec![ Field::new("x", DataType::Float64, true), Field::new("y", DataType::Float64, true), ])); let categories_type = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); - let reader = gen() + let mut reader = gen() .col("score", array::rand::()) .col("location", array::rand_type(&location_type)) .col("categories", array::rand_type(&categories_type)) - .col("binary", array::rand_type(&DataType::Binary)) - .col("large_bin", array::rand_type(&DataType::LargeBinary)) - .into_reader_rows(RowCount::from(1000), BatchCount::from(100)); + .col("binary", array::rand_type(&DataType::Binary)); + if version <= LanceFileVersion::V2_0 { + reader = reader.col("large_bin", array::rand_type(&DataType::LargeBinary)); + } + let reader = reader.into_reader_rows(RowCount::from(1000), BatchCount::from(100)); - write_lance_file(reader, fs, FileWriterOptions::default()).await + write_lance_file( + reader, + fs, + FileWriterOptions { + format_version: Some(version), + ..Default::default() + }, + ) + .await } type Transformer = Box RecordBatch>; @@ -1250,7 +1495,7 @@ pub mod tests { async fn test_round_trip() { let fs = FsFixture::default(); - let WrittenFile { data, .. } = create_some_file(&fs).await; + let WrittenFile { data, .. } = create_some_file(&fs, LanceFileVersion::V2_0).await; for read_size in [32, 1024, 1024 * 1024] { let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); @@ -1346,7 +1591,7 @@ pub mod tests { async fn test_projection() { let fs = FsFixture::default(); - let written_file = create_some_file(&fs).await; + let written_file = create_some_file(&fs, LanceFileVersion::V2_0).await; let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); let field_id_mapping = written_file @@ -1355,6 +1600,11 @@ pub mod tests { .copied() .collect::>(); + let empty_projection = ReaderProjection { + column_indices: Vec::default(), + schema: Arc::new(Schema::default()), + }; + for columns in [ vec!["score"], vec!["location"], @@ -1436,12 +1686,17 @@ pub mod tests { })), ) .await; - } - let empty_projection = ReaderProjection { - column_indices: Vec::default(), - schema: Arc::new(Schema::default()), - }; + assert!(file_reader + .read_stream_projected( + lance_io::ReadBatchParams::RangeFull, + 1024, + 16, + empty_projection.clone(), + FilterExpression::no_filter(), + ) + .is_err()); + } assert!(FileReader::try_open( file_scheduler.clone(), @@ -1479,7 +1734,7 @@ pub mod tests { async fn test_compressing_buffer() { let fs = FsFixture::default(); - let written_file = create_some_file(&fs).await; + let written_file = create_some_file(&fs, LanceFileVersion::V2_0).await; let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); // We can specify the projection as part of the read operation via read_stream_projected @@ -1529,7 +1784,7 @@ pub mod tests { #[tokio::test] async fn test_read_all() { let fs = FsFixture::default(); - let WrittenFile { data, .. } = create_some_file(&fs).await; + let WrittenFile { data, .. } = create_some_file(&fs, LanceFileVersion::V2_0).await; let total_rows = data.iter().map(|batch| batch.num_rows()).sum::(); let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); @@ -1558,10 +1813,47 @@ pub mod tests { assert_eq!(batches[0].num_rows(), total_rows); } + #[tokio::test] + async fn test_blocking_take() { + let fs = FsFixture::default(); + let WrittenFile { data, schema, .. } = create_some_file(&fs, LanceFileVersion::V2_1).await; + let total_rows = data.iter().map(|batch| batch.num_rows()).sum::(); + + let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); + let file_reader = FileReader::try_open( + file_scheduler.clone(), + Some(ReaderProjection::from_column_names(&schema, &["score"]).unwrap()), + Arc::::default(), + &test_cache(), + FileReaderOptions::default(), + ) + .await + .unwrap(); + + let batches = tokio::task::spawn_blocking(move || { + file_reader + .read_stream_projected_blocking( + lance_io::ReadBatchParams::Indices(UInt32Array::from(vec![0, 1, 2, 3, 4])), + total_rows as u32, + None, + FilterExpression::no_filter(), + ) + .unwrap() + .collect::>>() + .unwrap() + }) + .await + .unwrap(); + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 5); + assert_eq!(batches[0].num_columns(), 1); + } + #[tokio::test(flavor = "multi_thread")] async fn test_drop_in_progress() { let fs = FsFixture::default(); - let WrittenFile { data, .. } = create_some_file(&fs).await; + let WrittenFile { data, .. } = create_some_file(&fs, LanceFileVersion::V2_0).await; let total_rows = data.iter().map(|batch| batch.num_rows()).sum::(); let file_scheduler = fs.scheduler.open_file(&fs.tmp_path).await.unwrap(); @@ -1605,7 +1897,7 @@ pub mod tests { // if the stream was dropped before it finished. let fs = FsFixture::default(); - let written_file = create_some_file(&fs).await; + let written_file = create_some_file(&fs, LanceFileVersion::V2_0).await; let total_rows = written_file .data .iter() diff --git a/rust/lance-file/src/v2/writer.rs b/rust/lance-file/src/v2/writer.rs index a3eb7d99b56..809810bd37d 100644 --- a/rust/lance-file/src/v2/writer.rs +++ b/rust/lance-file/src/v2/writer.rs @@ -3,6 +3,7 @@ use core::panic; use std::collections::HashMap; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use arrow_array::RecordBatch; @@ -21,12 +22,14 @@ use lance_encoding::encoder::{ }; use lance_encoding::repdef::RepDefBuilder; use lance_encoding::version::LanceFileVersion; +use lance_io::object_store::ObjectStore; use lance_io::object_writer::ObjectWriter; use lance_io::traits::Writer; -use log::debug; +use log::{debug, warn}; +use object_store::path::Path; use prost::Message; use prost_types::Any; -use snafu::{location, Location}; +use snafu::location; use tokio::io::AsyncWriteExt; use tracing::instrument; @@ -112,6 +115,8 @@ fn initial_column_metadata() -> pbfile::ColumnMetadata { } } +static WARNED_ON_UNSTABLE_API: AtomicBool = AtomicBool::new(false); + impl FileWriter { /// Create a new FileWriter with a desired output schema pub fn try_new( @@ -129,6 +134,20 @@ impl FileWriter { /// The output schema will be set based on the first batch of data to arrive. /// If no data arrives and the writer is finished then the write will fail. pub fn new_lazy(object_writer: ObjectWriter, options: FileWriterOptions) -> Self { + if let Some(format_version) = options.format_version { + if format_version > LanceFileVersion::Stable + && WARNED_ON_UNSTABLE_API + .compare_exchange( + false, + true, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) + .is_ok() + { + warn!("You have requested an unstable format version. Files written with this format version may not be readable in the future! This is a development feature and should only be used for experimentation and never for production data."); + } + } Self { writer: object_writer, schema: None, @@ -143,6 +162,24 @@ impl FileWriter { } } + /// Write a series of record batches to a new file + /// + /// Returns the number of rows written + pub async fn create_file_with_batches( + store: &ObjectStore, + path: &Path, + schema: lance_core::datatypes::Schema, + batches: impl Iterator + Send, + options: FileWriterOptions, + ) -> Result { + let writer = store.create(path).await?; + let mut writer = Self::try_new(writer, schema, options)?; + for batch in batches { + writer.write_batch(&batch).await?; + } + Ok(writer.finish().await? as usize) + } + async fn do_write_buffer(writer: &mut ObjectWriter, buf: &[u8]) -> Result<()> { writer.write_all(buf).await?; let pad_bytes = pad_bytes::(buf.len()); @@ -309,11 +346,13 @@ impl FileWriter { location: location!(), })?; let repdef = RepDefBuilder::default(); + let num_rows = array.len() as u64; column_writer.maybe_encode( array.clone(), external_buffers, repdef, self.rows_written, + num_rows, ) }) .collect::>>() diff --git a/rust/lance-file/src/writer.rs b/rust/lance-file/src/writer.rs index bb51ec93800..ed132863551 100644 --- a/rust/lance-file/src/writer.rs +++ b/rust/lance-file/src/writer.rs @@ -25,7 +25,7 @@ use lance_io::object_store::ObjectStore; use lance_io::object_writer::ObjectWriter; use lance_io::traits::{WriteExt, Writer}; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tokio::io::AsyncWriteExt; use crate::format::metadata::{Metadata, StatisticsMetadata}; diff --git a/rust/lance-file/src/writer/statistics.rs b/rust/lance-file/src/writer/statistics.rs index 2ef4d081d48..7ba7c4a909e 100644 --- a/rust/lance-file/src/writer/statistics.rs +++ b/rust/lance-file/src/writer/statistics.rs @@ -486,7 +486,7 @@ fn get_boolean_statistics(arrays: &[&ArrayRef]) -> StatisticsRow { fn cast_dictionary_arrays<'a, T: ArrowDictionaryKeyType + 'static>( arrays: &'a [&'a ArrayRef], -) -> Vec<&Arc> { +) -> Vec<&'a Arc> { arrays .iter() .map(|x| x.as_dictionary::().values()) diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 12d38e56781..b09c20570b9 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -26,9 +26,12 @@ datafusion-physical-expr.workspace = true datafusion-sql.workspace = true datafusion.workspace = true deepsize.workspace = true +dirs.workspace = true +fst.workspace = true futures.workspace = true half.workspace = true itertools.workspace = true +jieba-rs = { workspace = true, optional = true } lance-arrow.workspace = true lance-core.workspace = true lance-datafusion.workspace = true @@ -50,6 +53,8 @@ serde_json.workspace = true serde.workspace = true snafu.workspace = true tantivy.workspace = true +lindera = { workspace = true, optional = true } +lindera-tantivy = { workspace = true, optional = true } tokio.workspace = true tracing.workspace = true tempfile.workspace = true @@ -61,6 +66,7 @@ uuid.workspace = true approx.workspace = true clap = { workspace = true, features = ["derive"] } criterion.workspace = true +env_logger = "0.11.6" lance-datagen.workspace = true lance-testing.workspace = true tempfile.workspace = true @@ -68,12 +74,23 @@ test-log.workspace = true datafusion-sql.workspace = true random_word = { version = "0.4.3", features = ["en"] } +[features] +protoc = ["dep:protobuf-src"] +tokenizer-lindera = ["lindera", "lindera-tantivy", "tokenizer-common"] +tokenizer-jieba = ["jieba-rs", "tokenizer-common"] +tokenizer-common = [] + [build-dependencies] prost-build.workspace = true +protobuf-src = { version = "2.1", optional = true } [target.'cfg(target_os = "linux")'.dev-dependencies] pprof.workspace = true +[package.metadata.docs.rs] +# docs.rs uses an older version of Ubuntu that does not have the necessary protoc version +features = ["protoc"] + [[bench]] name = "find_partitions" harness = false @@ -98,6 +115,10 @@ harness = false name = "sq" harness = false +[[bench]] +name = "ngram" +harness = false + [[bench]] name = "inverted" harness = false diff --git a/rust/lance-index/benches/4bitpq_dist_table.rs b/rust/lance-index/benches/4bitpq_dist_table.rs index e37fa4455c8..57f8e8ce2b0 100644 --- a/rust/lance-index/benches/4bitpq_dist_table.rs +++ b/rust/lance-index/benches/4bitpq_dist_table.rs @@ -3,7 +3,7 @@ //! Benchmark of building PQ distance table. -use std::iter::repeat; +use std::iter::repeat_n; use arrow_array::types::Float32Type; use arrow_array::{FixedSizeListArray, UInt8Array}; @@ -74,7 +74,7 @@ fn compute_distances(c: &mut Criterion) { let query = generate_random_array_with_seed::(DIM, [32; 32]); let mut rnd = StdRng::from_seed([32; 32]); - let code = UInt8Array::from_iter_values(repeat(rnd.gen::()).take(TOTAL * PQ)); + let code = UInt8Array::from_iter_values(repeat_n(rnd.gen::(), TOTAL * PQ)); for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() { let pq = ProductQuantizer::new( diff --git a/rust/lance-index/benches/hnsw.rs b/rust/lance-index/benches/hnsw.rs index b51d75d7469..e250dfffd83 100644 --- a/rust/lance-index/benches/hnsw.rs +++ b/rust/lance-index/benches/hnsw.rs @@ -15,7 +15,7 @@ use lance_index::vector::v3::subindex::IvfSubIndex; use pprof::criterion::{Output, PProfProfiler}; use lance_index::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, hnsw::builder::{HnswBuildParams, HNSW}, }; use lance_linalg::distance::DistanceType; @@ -31,7 +31,7 @@ fn bench_hnsw(c: &mut Criterion) { let data = generate_random_array_with_seed::(TOTAL * DIMENSION, SEED); let fsl = FixedSizeListArray::try_new_from_values(data, DIMENSION as i32).unwrap(); - let vectors = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let vectors = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let query = fsl.value(0); c.bench_function( diff --git a/rust/lance-index/benches/inverted.rs b/rust/lance-index/benches/inverted.rs index 7f1f7b16ae3..6cba5089f7b 100644 --- a/rust/lance-index/benches/inverted.rs +++ b/rust/lance-index/benches/inverted.rs @@ -14,10 +14,12 @@ use futures::stream; use itertools::Itertools; use lance_core::cache::FileMetadataCache; use lance_core::ROW_ID; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::prefilter::NoFilter; +use lance_index::scalar::inverted::query::{FtsSearchParams, Operator}; use lance_index::scalar::inverted::{InvertedIndex, InvertedIndexBuilder}; use lance_index::scalar::lance_format::LanceIndexStore; -use lance_index::scalar::{FullTextSearchQuery, ScalarIndex}; +use lance_index::scalar::ScalarIndex; use lance_io::object_store::ObjectStore; use object_store::path::Path; #[cfg(target_os = "linux")] @@ -32,7 +34,7 @@ fn bench_inverted(c: &mut Criterion) { let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); let store = rt.block_on(async { Arc::new(LanceIndexStore::new( - ObjectStore::local(), + Arc::new(ObjectStore::local()), index_dir, FileMetadataCache::no_cache(), )) @@ -69,17 +71,19 @@ fn bench_inverted(c: &mut Criterion) { rt.block_on(async { builder.update(stream, store.as_ref()).await.unwrap() }); let invert_index = rt.block_on(InvertedIndex::load(store)).unwrap(); + let params = FtsSearchParams::new().with_limit(Some(10)); let no_filter = Arc::new(NoFilter); c.bench_function(format!("invert({TOTAL})").as_str(), |b| { b.to_async(&rt).iter(|| async { black_box( invert_index - .full_text_search( - &FullTextSearchQuery::new( - tokens[rand::random::() % tokens.len()].to_owned(), - ) - .limit(Some(10)), + .bm25_search( + &[tokens[rand::random::() % tokens.len()].to_owned()], + ¶ms, + Operator::Or, + false, no_filter.clone(), + &NoOpMetricsCollector, ) .await .unwrap(), diff --git a/rust/lance-index/benches/ngram.rs b/rust/lance-index/benches/ngram.rs new file mode 100644 index 00000000000..87055038c2f --- /dev/null +++ b/rust/lance-index/benches/ngram.rs @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{sync::Arc, time::Duration}; + +use arrow::array::AsArray; +use arrow_array::{RecordBatch, StringArray, UInt64Array}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use futures::stream; +use itertools::Itertools; +use lance_core::cache::FileMetadataCache; +use lance_core::ROW_ID; +use lance_index::metrics::NoOpMetricsCollector; +use lance_index::scalar::lance_format::LanceIndexStore; +use lance_index::scalar::ngram::{NGramIndex, NGramIndexBuilder, NGramIndexBuilderOptions}; +use lance_index::scalar::{ScalarIndex, TextQuery}; +use lance_io::object_store::ObjectStore; +use object_store::path::Path; +#[cfg(target_os = "linux")] +use pprof::criterion::{Output, PProfProfiler}; + +fn bench_ngram(c: &mut Criterion) { + const TOTAL: usize = 1_000_000; + + env_logger::init(); + + let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); + + let tempdir = tempfile::tempdir().unwrap(); + let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); + let store = rt.block_on(async { + Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + index_dir, + FileMetadataCache::no_cache(), + )) + }); + + // generate 2000 different tokens + let tokens = random_word::all(random_word::Lang::En); + let row_id_col = Arc::new(UInt64Array::from( + (0..TOTAL).map(|i| i as u64).collect_vec(), + )); + let docs = (0..TOTAL) + .map(|_| { + let num_words = rand::random::() % 30 + 1; + let doc = (0..num_words) + .map(|_| tokens[rand::random::() % tokens.len()]) + .collect::>(); + doc.join(" ") + }) + .collect_vec(); + let doc_col = Arc::new(StringArray::from(docs)); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("doc", arrow_schema::DataType::Utf8, false), + arrow_schema::Field::new(ROW_ID, arrow_schema::DataType::UInt64, false), + ]) + .into(), + vec![doc_col, row_id_col], + ) + .unwrap(); + + let batches = (0..1000).map(|i| batch.slice(i * 1000, 1000)).collect_vec(); + + let mut group = c.benchmark_group("train"); + + group.sample_size(10); + group.bench_function(format!("ngram_train({TOTAL})").as_str(), |b| { + b.to_async(&rt).iter(|| async { + let stream = RecordBatchStreamAdapter::new( + batch.schema(), + stream::iter(batches.clone().into_iter().map(Ok)), + ); + let stream = Box::pin(stream); + let mut builder = + NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + let num_spill_files = builder.train(stream).await.unwrap(); + + builder + .write_index(store.as_ref(), num_spill_files, None) + .await + .unwrap(); + }) + }); + + drop(group); + + let mut group = c.benchmark_group("search"); + + group + .sample_size(10) + .measurement_time(Duration::from_secs(10)); + let index = rt.block_on(NGramIndex::load(store)).unwrap(); + group.bench_function(format!("ngram_search({TOTAL})").as_str(), |b| { + b.to_async(&rt).iter(|| async { + let sample_idx = rand::random::() % batch.num_rows(); + let sample = batch + .column(0) + .as_string::() + .value(sample_idx) + .to_string(); + black_box( + index + .search(&TextQuery::StringContains(sample), &NoOpMetricsCollector) + .await + .unwrap(), + ); + }) + }); +} + +#[cfg(target_os = "linux")] +criterion_group!( + name=benches; + config = Criterion::default() + .measurement_time(Duration::from_secs(10)) + .sample_size(10) + .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench_ngram); + +// Non-linux version does not support pprof. +#[cfg(not(target_os = "linux"))] +criterion_group!( + name=benches; + config = Criterion::default() + .measurement_time(Duration::from_secs(10)) + .sample_size(10); + targets = bench_ngram); + +criterion_main!(benches); diff --git a/rust/lance-index/benches/pq_dist_table.rs b/rust/lance-index/benches/pq_dist_table.rs index 515a309a0ff..a20b026a324 100644 --- a/rust/lance-index/benches/pq_dist_table.rs +++ b/rust/lance-index/benches/pq_dist_table.rs @@ -3,7 +3,7 @@ //! Benchmark of building PQ distance table. -use std::iter::repeat; +use std::iter::repeat_n; use arrow_array::types::Float32Type; use arrow_array::{FixedSizeListArray, UInt8Array}; @@ -72,7 +72,7 @@ fn compute_distances(c: &mut Criterion) { let query = generate_random_array_with_seed::(DIM, [32; 32]); let mut rnd = StdRng::from_seed([32; 32]); - let code = UInt8Array::from_iter_values(repeat(rnd.gen::()).take(TOTAL * PQ)); + let code = UInt8Array::from_iter_values(repeat_n(rnd.gen::(), TOTAL * PQ)); for dt in [DistanceType::L2, DistanceType::Cosine, DistanceType::Dot].iter() { let pq = ProductQuantizer::new( diff --git a/rust/lance-index/benches/sq.rs b/rust/lance-index/benches/sq.rs index f19aad60949..5bd8474b480 100644 --- a/rust/lance-index/benches/sq.rs +++ b/rust/lance-index/benches/sq.rs @@ -10,6 +10,7 @@ use arrow_schema::{DataType, Field, Schema}; use criterion::{criterion_group, criterion_main, Criterion}; use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::ROW_ID; +use lance_index::vector::storage::DistCalculator; use lance_index::vector::{ sq::storage::ScalarQuantizationStorage, storage::VectorStore, SQ_CODE_COLUMN, }; @@ -85,7 +86,7 @@ pub fn bench_storage(c: &mut Criterion) { b.iter(|| { let a = rng.gen_range(0..total as u32); let b = rng.gen_range(0..total as u32); - storage.distance_between(a, b) + storage.dist_calculator_from_id(a).distance(b); }); }, ); diff --git a/rust/lance-index/build.rs b/rust/lance-index/build.rs index 8a31fbf600c..402ef5012ca 100644 --- a/rust/lance-index/build.rs +++ b/rust/lance-index/build.rs @@ -7,6 +7,10 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + #[cfg(feature = "protoc")] + // Use vendored protobuf compiler if requested. + std::env::set_var("PROTOC", protobuf_src::protoc()); + let mut prost_build = prost_build::Config::new(); prost_build.protoc_arg("--experimental_allow_proto3_optional"); prost_build.compile_protos(&["./protos/index.proto"], &["./protos"])?; diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 0e6d2603a58..d8725848ede 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -16,9 +16,10 @@ use deepsize::DeepSizeOf; use lance_core::{Error, Result}; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use std::convert::TryFrom; +pub mod metrics; pub mod optimize; pub mod prefilter; pub mod scalar; @@ -55,6 +56,11 @@ pub trait Index: Send + Sync + DeepSizeOf { /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result; + /// Prewarm the index. + /// + /// This will load the index into memory and cache it. + async fn prewarm(&self) -> Result<()>; + /// Get the type of the index fn index_type(&self) -> IndexType; @@ -79,6 +85,8 @@ pub enum IndexType { Inverted = 4, // Inverted + NGram = 5, // NGram + // 100+ and up for vector index. /// Flat vector index. Vector = 100, // Legacy vector index, alias to IvfPq @@ -96,6 +104,7 @@ impl std::fmt::Display for IndexType { Self::Bitmap => write!(f, "Bitmap"), Self::LabelList => write!(f, "LabelList"), Self::Inverted => write!(f, "Inverted"), + Self::NGram => write!(f, "NGram"), Self::Vector | Self::IvfPq => write!(f, "IVF_PQ"), Self::IvfFlat => write!(f, "IVF_FLAT"), Self::IvfSq => write!(f, "IVF_SQ"), @@ -114,6 +123,7 @@ impl TryFrom for IndexType { v if v == Self::BTree as i32 => Ok(Self::BTree), v if v == Self::Bitmap as i32 => Ok(Self::Bitmap), v if v == Self::LabelList as i32 => Ok(Self::LabelList), + v if v == Self::NGram as i32 => Ok(Self::NGram), v if v == Self::Inverted as i32 => Ok(Self::Inverted), v if v == Self::Vector as i32 => Ok(Self::Vector), v if v == Self::IvfFlat as i32 => Ok(Self::IvfFlat), @@ -133,14 +143,24 @@ impl IndexType { pub fn is_scalar(&self) -> bool { matches!( self, - Self::Scalar | Self::BTree | Self::Bitmap | Self::LabelList | Self::Inverted + Self::Scalar + | Self::BTree + | Self::Bitmap + | Self::LabelList + | Self::Inverted + | Self::NGram ) } pub fn is_vector(&self) -> bool { matches!( self, - Self::Vector | Self::IvfPq | Self::IvfHnswSq | Self::IvfHnswPq + Self::Vector + | Self::IvfPq + | Self::IvfHnswSq + | Self::IvfHnswPq + | Self::IvfFlat + | Self::IvfSq ) } } diff --git a/rust/lance-index/src/metrics.rs b/rust/lance-index/src/metrics.rs new file mode 100644 index 00000000000..ec74bc76795 --- /dev/null +++ b/rust/lance-index/src/metrics.rs @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// A trait used by the index to report metrics +/// +/// Callers can implement this trait to collect metrics +pub trait MetricsCollector: Send + Sync { + /// Record partition loads + /// + /// Many indices consist of partitions that may need to be loaded + /// into cache. For example, an inverted index or ngram index has a + /// posting list for each token. + /// + /// In the ideal case, these shards are in the cache and will not need + /// to be loaded from disk. This method should not be called if the + /// shard is in the cache. + fn record_parts_loaded(&self, num_parts: usize); + + /// Record a shard load + fn record_part_load(&self) { + self.record_parts_loaded(1); + } + + /// Record an index load + /// + /// This should be called when a scalar index is loaded from storage. + /// It should not be called if the index is already in memory. + fn record_index_loads(&self, num_indexes: usize); + + /// Record an index load + fn record_index_load(&self) { + self.record_index_loads(1); + } + + /// Record the number of "comparisons" made by the index + /// + /// What exactly constitutes a comparison depends on the index type. + /// For example, a B-tree index may make comparisons while searching for a value. + /// On the other hand, a bitmap index makes comparisons when computing the intersection + /// of two bitmaps. + /// + /// The goal is to provide some visibility into the compute cost of the search + fn record_comparisons(&self, num_comparisons: usize); +} + +/// A no-op metrics collector that does nothing +pub struct NoOpMetricsCollector; + +impl MetricsCollector for NoOpMetricsCollector { + fn record_parts_loaded(&self, _num_parts: usize) {} + fn record_index_loads(&self, _num_indexes: usize) {} + fn record_comparisons(&self, _num_comparisons: usize) {} +} + +#[derive(Default)] +pub struct LocalMetricsCollector { + parts_loaded: AtomicUsize, + index_loads: AtomicUsize, + comparisons: AtomicUsize, +} + +impl LocalMetricsCollector { + pub fn dump_into(self, other: &dyn MetricsCollector) { + other.record_parts_loaded(self.parts_loaded.load(Ordering::Relaxed)); + other.record_index_loads(self.index_loads.load(Ordering::Relaxed)); + other.record_comparisons(self.comparisons.load(Ordering::Relaxed)); + } +} + +impl MetricsCollector for LocalMetricsCollector { + fn record_parts_loaded(&self, num_parts: usize) { + self.parts_loaded.fetch_add(num_parts, Ordering::Relaxed); + } + + fn record_index_loads(&self, num_indexes: usize) { + self.index_loads.fetch_add(num_indexes, Ordering::Relaxed); + } + + fn record_comparisons(&self, num_comparisons: usize) { + self.comparisons + .fetch_add(num_comparisons, Ordering::Relaxed); + } +} diff --git a/rust/lance-index/src/optimize.rs b/rust/lance-index/src/optimize.rs index 5f9b6a78edc..558640f2d5a 100644 --- a/rust/lance-index/src/optimize.rs +++ b/rust/lance-index/src/optimize.rs @@ -20,6 +20,19 @@ pub struct OptimizeOptions { /// the index names to optimize. If None, all indices will be optimized. pub index_names: Option>, + + /// whether to retrain the whole index. Default: false. + /// + /// If true, the index will be retrained based on the current data, + /// `num_indices_to_merge` will be ignored, and all indices will be merged into one. + /// If false, the index will be optimized by merging `num_indices_to_merge` indices. + /// + /// This is useful when the data distribution has changed significantly, + /// and we want to retrain the index to improve the search quality. + /// This would be faster than re-create the index from scratch. + /// + /// NOTE: this option is only supported for v3 vector indices. + pub retrain: bool, } impl Default for OptimizeOptions { @@ -27,6 +40,43 @@ impl Default for OptimizeOptions { Self { num_indices_to_merge: 1, index_names: None, + retrain: false, } } } + +impl OptimizeOptions { + pub fn new() -> Self { + Self { + num_indices_to_merge: 1, + index_names: None, + retrain: false, + } + } + + pub fn append() -> Self { + Self { + num_indices_to_merge: 0, + index_names: None, + retrain: false, + } + } + + pub fn retrain() -> Self { + Self { + num_indices_to_merge: 0, + index_names: None, + retrain: true, + } + } + + pub fn num_indices_to_merge(mut self, num: usize) -> Self { + self.num_indices_to_merge = num; + self + } + + pub fn index_names(mut self, names: Vec) -> Self { + self.index_names = Some(names); + self + } +} diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index effabea0be2..ab6bb6589c8 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -3,7 +3,7 @@ //! Scalar indices for metadata search & filtering -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, ops::Bound, sync::Arc}; @@ -11,6 +11,7 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow_array::{ListArray, RecordBatch}; use arrow_schema::{Field, Schema}; use async_trait::async_trait; +use datafusion::functions::string::contains::ContainsFunc; use datafusion::functions_array::array_has; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::{scalar::ScalarValue, Column}; @@ -18,11 +19,14 @@ use datafusion_common::{scalar::ScalarValue, Column}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use deepsize::DeepSizeOf; +use inverted::query::{fill_fts_query_column, FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery}; use inverted::TokenizerConfig; use lance_core::utils::mask::RowIdTreeMap; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use serde::{Deserialize, Serialize}; +use snafu::location; +use crate::metrics::MetricsCollector; use crate::{Index, IndexParams, IndexType}; pub mod bitmap; @@ -32,6 +36,7 @@ pub mod flat; pub mod inverted; pub mod label_list; pub mod lance_format; +pub mod ngram; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; @@ -40,6 +45,7 @@ pub enum ScalarIndexType { BTree, Bitmap, LabelList, + NGram, Inverted, } @@ -51,6 +57,7 @@ impl TryFrom for ScalarIndexType { IndexType::BTree | IndexType::Scalar => Ok(Self::BTree), IndexType::Bitmap => Ok(Self::Bitmap), IndexType::LabelList => Ok(Self::LabelList), + IndexType::NGram => Ok(Self::NGram), IndexType::Inverted => Ok(Self::Inverted), _ => Err(Error::InvalidInput { source: format!("Index type {:?} is not a scalar index", value).into(), @@ -85,6 +92,7 @@ impl IndexParams for ScalarIndexParams { Some(ScalarIndexType::Bitmap) => IndexType::Bitmap, Some(ScalarIndexType::LabelList) => IndexType::LabelList, Some(ScalarIndexType::Inverted) => IndexType::Inverted, + Some(ScalarIndexType::NGram) => IndexType::NGram, } } @@ -93,7 +101,7 @@ impl IndexParams for ScalarIndexParams { } } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct InvertedIndexParams { /// If true, store the position of the term in the document /// This can significantly increase the size of the index @@ -101,6 +109,7 @@ pub struct InvertedIndexParams { /// Default is true pub with_position: bool, + #[serde(flatten)] pub tokenizer_config: TokenizerConfig, } @@ -165,7 +174,7 @@ pub trait IndexWriter: Send { #[async_trait] pub trait IndexReader: Send + Sync { /// Read the n-th record batch from the file - async fn read_record_batch(&self, n: u32) -> Result; + async fn read_record_batch(&self, n: u64, batch_size: u64) -> Result; /// Read the range of rows from the file. /// If projection is Some, only return the columns in the projection, /// nested columns like Some(&["x.y"]) are not supported. @@ -176,7 +185,7 @@ pub trait IndexReader: Send + Sync { projection: Option<&[&str]>, ) -> Result; /// Return the number of batches in the file - async fn num_batches(&self) -> u32; + async fn num_batches(&self, batch_size: u64) -> u32; /// Return the number of rows in the file fn num_rows(&self) -> usize; /// Return the metadata of the file @@ -206,6 +215,12 @@ pub trait IndexStore: std::fmt::Debug + Send + Sync + DeepSizeOf { /// /// This is often useful when remapping or updating async fn copy_index_file(&self, name: &str, dest_store: &dyn IndexStore) -> Result<()>; + + /// Rename an index file + async fn rename_index_file(&self, name: &str, new_name: &str) -> Result<()>; + + /// Delete an index file (used in the tmp spill store to keep tmp size down) + async fn delete_index_file(&self, name: &str) -> Result<()>; } /// Different scalar indices may support different kinds of queries @@ -227,6 +242,10 @@ pub trait AnyQuery: std::fmt::Debug + Any + Send + Sync { fn to_expr(&self, col: String) -> Expr; /// Compare this query to another query fn dyn_eq(&self, other: &dyn AnyQuery) -> bool; + /// If true, the query results are inexact and will need rechecked + fn needs_recheck(&self) -> bool { + false + } } impl PartialEq for dyn AnyQuery { @@ -234,17 +253,14 @@ impl PartialEq for dyn AnyQuery { self.dyn_eq(other) } } - /// A full text search query #[derive(Debug, Clone, PartialEq)] pub struct FullTextSearchQuery { - /// The columns to search, - /// if empty, search all indexed columns - pub columns: Vec, - /// The full text search query - pub query: String, + pub query: FtsQuery, + /// The maximum number of results to return pub limit: Option, + /// The wand factor to use for ranking /// if None, use the default value of 1.0 /// Increasing this value will reduce the recall and improve the performance @@ -253,22 +269,51 @@ pub struct FullTextSearchQuery { } impl FullTextSearchQuery { + /// Create a new terms query pub fn new(query: String) -> Self { + let query = MatchQuery::new(query).into(); + Self { + query, + limit: None, + wand_factor: None, + } + } + + /// Create a new fuzzy query + pub fn new_fuzzy(term: String, max_distance: Option) -> Self { + let query = MatchQuery::new(term).with_fuzziness(max_distance).into(); Self { query, limit: None, - columns: vec![], wand_factor: None, } } - pub fn columns(mut self, columns: Option>) -> Self { - if let Some(columns) = columns { - self.columns = columns; + /// Create a new compound query + pub fn new_query(query: FtsQuery) -> Self { + Self { + query, + limit: None, + wand_factor: None, } - self } + /// Set the column to search over + /// This is available for only MatchQuery and PhraseQuery + pub fn with_column(mut self, column: String) -> Result { + self.query = fill_fts_query_column(&self.query, &[column], true)?; + Ok(self) + } + + /// Set the column to search over + /// This is available for only MatchQuery + pub fn with_columns(mut self, columns: &[String]) -> Result { + self.query = fill_fts_query_column(&self.query, columns, true)?; + Ok(self) + } + + /// limit the number of results to return + /// if None, return all results pub fn limit(mut self, limit: Option) -> Self { self.limit = limit; self @@ -278,6 +323,17 @@ impl FullTextSearchQuery { self.wand_factor = wand_factor; self } + + pub fn columns(&self) -> HashSet { + self.query.columns() + } + + pub fn params(&self) -> FtsSearchParams { + FtsSearchParams { + limit: self.limit.map(|limit| limit as usize), + wand_factor: self.wand_factor.unwrap_or(1.0), + } + } } /// A query that a basic scalar index (e.g. btree / bitmap) can satisfy @@ -390,9 +446,9 @@ impl AnyQuery for SargableQuery { .collect::>(), false, ), - Self::FullTextSearch(query) => { - col_expr.like(Expr::Literal(ScalarValue::Utf8(Some(query.query.clone())))) - } + Self::FullTextSearch(query) => col_expr.like(Expr::Literal(ScalarValue::Utf8(Some( + query.query.to_string(), + )))), Self::IsNull() => col_expr.is_null(), Self::Equals(value) => col_expr.eq(Expr::Literal(value.clone())), } @@ -477,13 +533,96 @@ impl AnyQuery for LabelListQuery { } } +/// A query that a NGramIndex can satisfy +#[derive(Debug, Clone, PartialEq)] +pub enum TextQuery { + /// Retrieve all row ids where the text contains the given string + StringContains(String), + // TODO: In the future we should be able to do string-insensitive contains + // as well as partial matches (e.g. LIKE 'foo%') and potentially even + // some regular expressions +} + +impl AnyQuery for TextQuery { + fn as_any(&self) -> &dyn Any { + self + } + + fn format(&self, col: &str) -> String { + format!("{}", self.to_expr(col.to_string())) + } + + fn to_expr(&self, col: String) -> Expr { + match self { + Self::StringContains(substr) => Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ContainsFunc::new().into()), + args: vec![ + Expr::Column(Column::new_unqualified(col)), + Expr::Literal(ScalarValue::Utf8(Some(substr.clone()))), + ], + }), + } + } + + fn dyn_eq(&self, other: &dyn AnyQuery) -> bool { + match other.as_any().downcast_ref::() { + Some(o) => self == o, + None => false, + } + } + + fn needs_recheck(&self) -> bool { + true + } +} + +/// The result of a search operation against a scalar index +#[derive(Debug, PartialEq)] +pub enum SearchResult { + /// The exact row ids that satisfy the query + Exact(RowIdTreeMap), + /// Any row id satisfying the query will be in this set but not every + /// row id in this set will satisfy the query, a further recheck step + /// is needed + AtMost(RowIdTreeMap), + /// All of the given row ids satisfy the query but there may be more + /// + /// No scalar index actually returns this today but it can arise from + /// boolean operations (e.g. NOT(AtMost(x)) == AtLeast(NOT(x))) + AtLeast(RowIdTreeMap), +} + +impl SearchResult { + pub fn row_ids(&self) -> &RowIdTreeMap { + match self { + Self::Exact(row_ids) => row_ids, + Self::AtMost(row_ids) => row_ids, + Self::AtLeast(row_ids) => row_ids, + } + } + + pub fn is_exact(&self) -> bool { + matches!(self, Self::Exact(_)) + } +} + /// A trait for a scalar index, a structure that can determine row ids that satisfy scalar queries #[async_trait] pub trait ScalarIndex: Send + Sync + std::fmt::Debug + Index + DeepSizeOf { /// Search the scalar index /// /// Returns all row ids that satisfy the query, these row ids are not necessarily ordered - async fn search(&self, query: &dyn AnyQuery) -> Result; + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result; + + /// Returns true if the query can be answered exactly + /// + /// If false is returned then the query still may be answered exactly but if true is returned + /// then the query must be answered exactly + fn can_answer_exact(&self, query: &dyn AnyQuery) -> bool; /// Load the scalar index from storage async fn load(store: Arc) -> Result> diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index fc531b3bf7a..c994c10f930 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -10,7 +10,7 @@ use std::{ }; use arrow::array::BinaryBuilder; -use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; +use arrow_array::{new_empty_array, new_null_array, Array, BinaryArray, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; use datafusion::physical_plan::SendableRecordBatchStream; @@ -20,12 +20,12 @@ use futures::TryStreamExt; use lance_core::{utils::mask::RowIdTreeMap, Error, Result}; use roaring::RoaringBitmap; use serde::Serialize; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; -use crate::{Index, IndexType}; +use crate::{metrics::MetricsCollector, Index, IndexType}; -use super::{btree::OrderableScalarValue, SargableQuery}; +use super::{btree::OrderableScalarValue, SargableQuery, SearchResult}; use super::{btree::TrainingSource, AnyQuery, IndexStore, ScalarIndex}; pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance"; @@ -37,6 +37,10 @@ pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance"; #[derive(Clone, Debug)] pub struct BitmapIndex { index_map: BTreeMap, + // We put null in its own map to avoid it matching range queries (arrow-rs considers null to come before minval) + null_map: RowIdTreeMap, + // The data type of the values in the index + value_type: DataType, // Memoized index_map size for DeepSizeOf index_map_size_bytes: usize, store: Arc, @@ -45,11 +49,15 @@ pub struct BitmapIndex { impl BitmapIndex { fn new( index_map: BTreeMap, + null_map: RowIdTreeMap, + value_type: DataType, index_map_size_bytes: usize, store: Arc, ) -> Self { Self { index_map, + null_map, + value_type, index_map_size_bytes, store, } @@ -58,13 +66,18 @@ impl BitmapIndex { // creates a new BitmapIndex from a serialized RecordBatch fn try_from_serialized(data: RecordBatch, store: Arc) -> Result { if data.num_rows() == 0 { - return Err(Error::Internal { - message: "attempt to load bitmap index from empty record batch".into(), - location: location!(), - }); + let data_type = data.schema().field(0).data_type().clone(); + return Ok(Self::new( + BTreeMap::new(), + RowIdTreeMap::default(), + data_type, + 0, + store, + )); } let dict_keys = data.column(0); + let value_type = dict_keys.data_type().clone(); let binary_bitmaps = data.column(1); let bitmap_binary_array = binary_bitmaps .as_any() @@ -74,6 +87,7 @@ impl BitmapIndex { let mut index_map: BTreeMap = BTreeMap::new(); let mut index_map_size_bytes = 0; + let mut null_map = RowIdTreeMap::default(); for idx in 0..data.num_rows() { let key = OrderableScalarValue(ScalarValue::try_from_array(dict_keys, idx)?); let bitmap_bytes = bitmap_binary_array.value(idx); @@ -82,10 +96,20 @@ impl BitmapIndex { index_map_size_bytes += key.deep_size_of(); // This should be a reasonable approximation of the RowIdTreeMap size index_map_size_bytes += bitmap_bytes.len(); - index_map.insert(key, bitmap); + if key.0.is_null() { + null_map = bitmap; + } else { + index_map.insert(key, bitmap); + } } - Ok(Self::new(index_map, index_map_size_bytes, store)) + Ok(Self::new( + index_map, + null_map, + value_type, + index_map_size_bytes, + store, + )) } } @@ -125,6 +149,13 @@ impl Index for BitmapIndex { }) } + async fn prewarm(&self) -> Result<()> { + // TODO: Bitmap index essentially pre-warms on load right now. This is a problem for some + // of the larger bitmap indices (e.g. label_list). We should probably change it to behave + // like other indices and then we will need to implement this. + Ok(()) + } + fn index_type(&self) -> IndexType { IndexType::Bitmap } @@ -147,13 +178,22 @@ impl Index for BitmapIndex { #[async_trait] impl ScalarIndex for BitmapIndex { #[instrument(name = "bitmap_search", level = "debug", skip_all)] - async fn search(&self, query: &dyn AnyQuery) -> Result { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); let row_ids = match query { SargableQuery::Equals(val) => { - let key = OrderableScalarValue(val.clone()); - self.index_map.get(&key).cloned().unwrap_or_default() + metrics.record_comparisons(1); + if val.is_null() { + self.null_map.clone() + } else { + let key = OrderableScalarValue(val.clone()); + self.index_map.get(&key).cloned().unwrap_or_default() + } } SargableQuery::Range(start, end) => { let range_start = match start { @@ -174,30 +214,28 @@ impl ScalarIndex for BitmapIndex { .map(|(_, v)| v) .collect::>(); + metrics.record_comparisons(maps.len()); RowIdTreeMap::union_all(&maps) } SargableQuery::IsIn(values) => { let mut union_bitmap = RowIdTreeMap::default(); + metrics.record_comparisons(values.len()); for val in values { - let key = OrderableScalarValue(val.clone()); - if let Some(bitmap) = self.index_map.get(&key) { - union_bitmap |= bitmap.clone(); + if val.is_null() { + union_bitmap |= self.null_map.clone(); + } else { + let key = OrderableScalarValue(val.clone()); + if let Some(bitmap) = self.index_map.get(&key) { + union_bitmap |= bitmap.clone(); + } } } union_bitmap } SargableQuery::IsNull() => { - if let Some(array) = self - .index_map - .iter() - .find(|(key, _)| key.0.is_null()) - .map(|(_, value)| value) - { - array.clone() - } else { - RowIdTreeMap::default() - } + metrics.record_comparisons(1); + self.null_map.clone() } SargableQuery::FullTextSearch(_) => { return Err(Error::NotSupported { @@ -207,7 +245,11 @@ impl ScalarIndex for BitmapIndex { } }; - Ok(row_ids) + Ok(SearchResult::Exact(row_ids)) + } + + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + true } async fn load(store: Arc) -> Result> { @@ -241,7 +283,7 @@ impl ScalarIndex for BitmapIndex { (key.0.clone(), bitmap) }) .collect::>(); - write_bitmap_index(state, dest_store).await + write_bitmap_index(state, dest_store, &self.value_type).await } /// Add the new data into the index, creating an updated version of the index in `dest_store` @@ -250,11 +292,17 @@ impl ScalarIndex for BitmapIndex { new_data: SendableRecordBatchStream, dest_store: &dyn IndexStore, ) -> Result<()> { - let state = self + let mut state = self .index_map .iter() .map(|(key, bitmap)| (key.0.clone(), bitmap.clone())) .collect::>(); + + // Also insert the null map + let ex_null = new_null_array(&self.value_type, 1); + let ex_null = ScalarValue::try_from_array(ex_null.as_ref(), 0)?; + state.insert(ex_null, self.null_map.clone()); + do_train_bitmap_index(new_data, state, dest_store).await } } @@ -293,9 +341,14 @@ where async fn write_bitmap_index( state: HashMap, index_store: &dyn IndexStore, + value_type: &DataType, ) -> Result<()> { let keys_iter = state.keys().cloned(); - let keys_array = ScalarValue::iter_to_array(keys_iter)?; + let keys_array = if state.is_empty() { + new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(keys_iter)? + }; let values_iter = state.into_values(); let binary_bitmap_array = get_bitmaps_from_iter(values_iter); @@ -315,6 +368,7 @@ async fn do_train_bitmap_index( mut state: HashMap, index_store: &dyn IndexStore, ) -> Result<()> { + let value_type = data_source.schema().field(0).data_type().clone(); while let Some(batch) = data_source.try_next().await? { debug_assert_eq!(batch.num_columns(), 2); debug_assert_eq!(*batch.column(1).data_type(), DataType::UInt64); @@ -333,7 +387,7 @@ async fn do_train_bitmap_index( } } - write_bitmap_index(state, index_store).await + write_bitmap_index(state, index_store, &value_type).await } pub async fn train_bitmap_index( diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index ce23f85d851..b155e998e77 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -10,46 +10,58 @@ use std::{ sync::Arc, }; -use arrow_array::{Array, RecordBatch, UInt32Array}; +use arrow_array::{new_empty_array, Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema, SortOptions}; use async_trait::async_trait; -use datafusion::{ - functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}, - physical_plan::{ - sorts::sort_preserving_merge::SortPreservingMergeExec, stream::RecordBatchStreamAdapter, - union::UnionExec, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, - }, +use datafusion::physical_plan::{ + sorts::sort_preserving_merge::SortPreservingMergeExec, stream::RecordBatchStreamAdapter, + union::UnionExec, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, }; use datafusion_common::{DataFusionError, ScalarValue}; -use datafusion_expr::Accumulator; -use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; -use deepsize::DeepSizeOf; +use datafusion_physical_expr::{expressions::Column, LexOrdering, PhysicalSortExpr}; +use deepsize::{Context, DeepSizeOf}; use futures::{ future::BoxFuture, stream::{self}, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, }; use lance_core::{ - utils::{mask::RowIdTreeMap, tokio::get_num_compute_intensive_cpus}, + utils::{ + mask::RowIdTreeMap, + tokio::get_num_compute_intensive_cpus, + tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}, + }, Error, Result, }; use lance_datafusion::{ chunker::chunk_concat_stream, exec::{execute_plan, LanceExecutionOptions, OneShotExec}, }; +use log::debug; +use moka::sync::Cache; use roaring::RoaringBitmap; use serde::{Serialize, Serializer}; -use snafu::{location, Location}; +use snafu::location; +use tracing::info; use crate::{Index, IndexType}; use super::{ - flat::FlatIndexMetadata, AnyQuery, IndexReader, IndexStore, IndexWriter, SargableQuery, - ScalarIndex, + flat::FlatIndexMetadata, AnyQuery, IndexReader, IndexStore, IndexWriter, MetricsCollector, + SargableQuery, ScalarIndex, SearchResult, }; const BTREE_LOOKUP_NAME: &str = "page_lookup.lance"; const BTREE_PAGES_NAME: &str = "page_data.lance"; +pub const DEFAULT_BTREE_BATCH_SIZE: u64 = 4096; +const BATCH_SIZE_META_KEY: &str = "batch_size"; + +lazy_static::lazy_static! { + static ref CACHE_SIZE: u64 = std::env::var("LANCE_BTREE_CACHE_SIZE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(512 * 1024 * 1024); +} /// Wraps a ScalarValue and implements Ord (ScalarValue only implements PartialOrd) #[derive(Clone, Debug)] @@ -530,7 +542,7 @@ impl Ord for OrderableScalarValue { } } -#[derive(Debug, DeepSizeOf)] +#[derive(Debug, DeepSizeOf, PartialEq, Eq)] struct PageRecord { max: OrderableScalarValue, page_number: u32, @@ -548,7 +560,7 @@ impl BTreeMapExt for BTreeMap { } /// An in-memory structure that can quickly satisfy scalar queries using a btree of ScalarValue -#[derive(Debug, DeepSizeOf)] +#[derive(Debug, DeepSizeOf, PartialEq, Eq)] pub struct BTreeLookup { tree: BTreeMap>, /// Pages where the value may be null @@ -560,20 +572,13 @@ impl BTreeLookup { Self { tree, null_pages } } - fn all_page_ids(&self) -> Vec { - let mut ids = self - .tree - .iter() - .flat_map(|(_, pages)| pages) - .map(|page| page.page_number) - .collect::>(); - ids.dedup(); - ids - } - // All pages that could have a value equal to val fn pages_eq(&self, query: &OrderableScalarValue) -> Vec { - self.pages_between((Bound::Included(query), Bound::Excluded(query))) + if query.0.is_null() { + self.pages_null() + } else { + self.pages_between((Bound::Included(query), Bound::Excluded(query))) + } } // All pages that could have a value equal to one of the values @@ -631,6 +636,25 @@ impl BTreeLookup { // matches an upper bound. This will all be moot if/when we merge pages. Bound::Excluded(upper) => Bound::Included(upper), }; + + match (lower_bound, upper_bound) { + (Bound::Excluded(lower), Bound::Excluded(upper)) + | (Bound::Excluded(lower), Bound::Included(upper)) + | (Bound::Included(lower), Bound::Excluded(upper)) => { + // It's not really clear what (Included(5), Excluded(5)) would mean so we + // interpret it as an empty range which matches rust's BTreeMap behavior + if lower >= upper { + return vec![]; + } + } + (Bound::Included(lower), Bound::Included(upper)) => { + if lower > upper { + return vec![]; + } + } + _ => {} + } + let candidates = self .tree .range((lower_bound, upper_bound)) @@ -653,6 +677,42 @@ impl BTreeLookup { } } +// Caches btree pages in memory +#[derive(Debug)] +struct BTreeCache(Cache>); + +impl DeepSizeOf for BTreeCache { + fn deep_size_of_children(&self, _: &mut Context) -> usize { + self.0.iter().map(|(_, v)| v.deep_size_of()).sum() + } +} + +// We only need to open a file reader for pages if we need to load a page. If all +// pages are cached we don't open it. If we do open it we should only open it once. +#[derive(Clone)] +struct LazyIndexReader { + index_reader: Arc>>>, + store: Arc, +} + +impl LazyIndexReader { + fn new(store: Arc) -> Self { + Self { + index_reader: Arc::new(tokio::sync::Mutex::new(None)), + store, + } + } + + async fn get(&self) -> Result> { + let mut reader = self.index_reader.lock().await; + if reader.is_none() { + let index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; + *reader = Some(index_reader); + } + Ok(reader.as_ref().unwrap().clone()) + } +} + /// A btree index satisfies scalar queries using a b tree /// /// The upper layers of the btree are expected to be cached and, when unloaded, @@ -671,8 +731,10 @@ impl BTreeLookup { #[derive(Clone, Debug, DeepSizeOf)] pub struct BTreeIndex { page_lookup: Arc, + page_cache: Arc, store: Arc, sub_index: Arc, + batch_size: u64, } impl BTreeIndex { @@ -681,39 +743,77 @@ impl BTreeIndex { null_pages: Vec, store: Arc, sub_index: Arc, + batch_size: u64, ) -> Self { let page_lookup = Arc::new(BTreeLookup::new(tree, null_pages)); + let page_cache = Arc::new(BTreeCache( + Cache::builder() + .max_capacity(*CACHE_SIZE) + .weigher(|_, v: &Arc| v.deep_size_of() as u32) + .build(), + )); Self { page_lookup, + page_cache, store, sub_index, + batch_size, } } + async fn lookup_page( + &self, + page_number: u32, + index_reader: LazyIndexReader, + metrics: &dyn MetricsCollector, + ) -> Result> { + if let Some(cached) = self.page_cache.0.get(&page_number) { + return Ok(cached); + } + metrics.record_part_load(); + info!(target: TRACE_IO_EVENTS, type=IO_TYPE_LOAD_SCALAR_PART, index_type="btree", part_id=page_number); + let index_reader = index_reader.get().await?; + let serialized_page = index_reader + .read_record_batch(page_number as u64, self.batch_size) + .await?; + let subindex = self.sub_index.load_subindex(serialized_page).await?; + self.page_cache.0.insert(page_number, subindex.clone()); + Ok(subindex) + } + async fn search_page( &self, query: &SargableQuery, page_number: u32, - index_reader: Arc, + index_reader: LazyIndexReader, + metrics: &dyn MetricsCollector, ) -> Result { - let serialized_page = index_reader.read_record_batch(page_number).await?; - let subindex = self.sub_index.load_subindex(serialized_page).await?; + let subindex = self.lookup_page(page_number, index_reader, metrics).await?; // TODO: If this is an IN query we can perhaps simplify the subindex query by restricting it to the // values that might be in the page. E.g. if we are searching for X IN [5, 3, 7] and five is in pages // 1 and 2 and three is in page 2 and seven is in pages 8 and 9 then when we search page 2 we only need // to search for X IN [5, 3] - subindex.search(query).await + match subindex.search(query, metrics).await? { + SearchResult::Exact(map) => Ok(map), + _ => Err(Error::Internal { + message: "BTree sub-indices need to return exact results".to_string(), + location: location!(), + }), + } } - fn try_from_serialized(data: RecordBatch, store: Arc) -> Result { + fn try_from_serialized( + data: RecordBatch, + store: Arc, + batch_size: u64, + ) -> Result { let mut map = BTreeMap::>::new(); let mut null_pages = Vec::::new(); if data.num_rows() == 0 { - return Err(Error::Internal { - message: "attempt to load btree index from empty stats batch".into(), - location: location!(), - }); + let data_type = data.column(0).data_type().clone(); + let sub_index = Arc::new(FlatIndexMetadata::new(data_type)); + return Ok(Self::new(map, null_pages, store, sub_index, batch_size)); } let mins = data.column(0); @@ -735,9 +835,13 @@ impl BTreeIndex { let null_count = null_counts.values()[idx]; let page_number = page_numbers.values()[idx]; - map.entry(min) - .or_default() - .push(PageRecord { max, page_number }); + // If the page is entirely null don't even bother putting it in the tree + if !max.0.is_null() { + map.entry(min) + .or_default() + .push(PageRecord { max, page_number }); + } + if null_count > 0 { null_pages.push(page_number); } @@ -751,22 +855,18 @@ impl BTreeIndex { // TODO: Support other page types? let sub_index = Arc::new(FlatIndexMetadata::new(data_type.clone())); - Ok(Self::new(map, null_pages, store, sub_index)) + Ok(Self::new(map, null_pages, store, sub_index, batch_size)) } /// Create a stream of all the data in the index, in the same format used to train the index async fn into_data_stream(self) -> Result { let reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; - let pages = self.page_lookup.all_page_ids(); let schema = self.sub_index.schema().clone(); - let batches = IndexReaderStream { - reader, - pages, - idx: 0, - } - .map(|fut| fut.map_err(DataFusionError::from)) - .buffered(self.store.io_parallelism()) - .boxed(); + let reader_stream = IndexReaderStream::new(reader, self.batch_size).await; + let batches = reader_stream + .map(|fut| fut.map_err(DataFusionError::from)) + .buffered(self.store.io_parallelism()) + .boxed(); Ok(RecordBatchStreamAdapter::new(schema, batches)) } } @@ -816,6 +916,11 @@ impl Index for BTreeIndex { }) } + async fn prewarm(&self) -> Result<()> { + // TODO: BTree can (and should) support pre-warming by loading the pages into memory + Ok(()) + } + fn index_type(&self) -> IndexType { IndexType::BTree } @@ -843,8 +948,10 @@ impl Index for BTreeIndex { let mut frag_ids = RoaringBitmap::default(); let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; - for page_number in self.page_lookup.all_page_ids() { - let serialized = sub_index_reader.read_record_batch(page_number).await?; + let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size) + .await + .buffered(self.store.io_parallelism()); + while let Some(serialized) = reader_stream.try_next().await? { let page = self.sub_index.load_subindex(serialized).await?; frag_ids |= page.calculate_included_frags().await?; } @@ -855,7 +962,11 @@ impl Index for BTreeIndex { #[async_trait] impl ScalarIndex for BTreeIndex { - async fn search(&self, query: &dyn AnyQuery) -> Result { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); let pages = match query { SargableQuery::Equals(val) => self @@ -873,28 +984,44 @@ impl ScalarIndex for BTreeIndex { )), SargableQuery::IsNull() => self.page_lookup.pages_null(), }; - let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; + let lazy_index_reader = LazyIndexReader::new(self.store.clone()); let page_tasks = pages .into_iter() .map(|page_index| { - self.search_page(query, page_index, sub_index_reader.clone()) + self.search_page(query, page_index, lazy_index_reader.clone(), metrics) .boxed() }) .collect::>(); - stream::iter(page_tasks) + debug!("Searching {} btree pages", page_tasks.len()); + let row_ids = stream::iter(page_tasks) // I/O and compute mixed here but important case is index in cache so // use compute intensive thread count .buffered(get_num_compute_intensive_cpus()) .try_collect::() - .await + .await?; + Ok(SearchResult::Exact(row_ids)) + } + + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + true } async fn load(store: Arc) -> Result> { let page_lookup_file = store.open_index_file(BTREE_LOOKUP_NAME).await?; - let serialized_lookup = page_lookup_file.read_record_batch(0).await?; + let num_rows_in_lookup = page_lookup_file.num_rows(); + let serialized_lookup = page_lookup_file + .read_range(0..num_rows_in_lookup, None) + .await?; + let file_schema = page_lookup_file.schema(); + let batch_size = file_schema + .metadata + .get(BATCH_SIZE_META_KEY) + .map(|bs| bs.parse().unwrap_or(DEFAULT_BTREE_BATCH_SIZE)) + .unwrap_or(DEFAULT_BTREE_BATCH_SIZE); Ok(Arc::new(Self::try_from_serialized( serialized_lookup, store, + batch_size, )?)) } @@ -909,13 +1036,11 @@ impl ScalarIndex for BTreeIndex { .await?; let sub_index_reader = self.store.open_index_file(BTREE_PAGES_NAME).await?; - - for page_number in self.page_lookup.all_page_ids() { - let old_serialized = sub_index_reader.read_record_batch(page_number).await?; - let remapped = self - .sub_index - .remap_subindex(old_serialized, mapping) - .await?; + let mut reader_stream = IndexReaderStream::new(sub_index_reader, self.batch_size) + .await + .buffered(self.store.io_parallelism()); + while let Some(serialized) = reader_stream.try_next().await? { + let remapped = self.sub_index.remap_subindex(serialized, mapping).await?; sub_index_file.write_record_batch(remapped).await?; } @@ -934,7 +1059,13 @@ impl ScalarIndex for BTreeIndex { ) -> Result<()> { // Merge the existing index data with the new data and then retrain the index on the merged stream let merged_data_source = Box::new(BTreeUpdater::new(self.clone(), new_data)); - train_btree_index(merged_data_source, self.sub_index.as_ref(), dest_store).await + train_btree_index( + merged_data_source, + self.sub_index.as_ref(), + dest_store, + DEFAULT_BTREE_BATCH_SIZE as u32, + ) + .await } } @@ -944,39 +1075,24 @@ struct BatchStats { null_count: u32, } -// See https://github.com/apache/arrow-datafusion/issues/8031 for the underlying issue. We use -// MinAccumulator / MaxAccumulator to retrieve the min/max values and these are unreliable in the -// presence of NaN -fn check_for_nan(value: ScalarValue) -> Result { - match value { - ScalarValue::Float32(Some(val)) if val.is_nan() => Err(Error::NotSupported { - source: "Scalar indices cannot currently be created on columns with NaN values".into(), - location: location!(), - }), - ScalarValue::Float64(Some(val)) if val.is_nan() => Err(Error::NotSupported { - source: "Scalar indices cannot currently be created on columns with NaN values".into(), +fn analyze_batch(batch: &RecordBatch) -> Result { + let values = batch.column(0); + if values.is_empty() { + return Err(Error::Internal { + message: "received an empty batch in btree training".to_string(), location: location!(), - }), - _ => Ok(value), + }); } -} - -fn min_val(array: &Arc) -> Result { - let mut acc = MinAccumulator::try_new(array.data_type())?; - acc.update_batch(&[array.clone()])?; - check_for_nan(acc.evaluate()?) -} - -fn max_val(array: &Arc) -> Result { - let mut acc = MaxAccumulator::try_new(array.data_type())?; - acc.update_batch(&[array.clone()])?; - check_for_nan(acc.evaluate()?) -} + let min = ScalarValue::try_from_array(&values, 0).map_err(|e| Error::Internal { + message: format!("failed to get min value from batch: {}", e), + location: location!(), + })?; + let max = + ScalarValue::try_from_array(&values, values.len() - 1).map_err(|e| Error::Internal { + message: format!("failed to get max value from batch: {}", e), + location: location!(), + })?; -fn analyze_batch(batch: &RecordBatch) -> Result { - let values = batch.column(0); - let min = min_val(values)?; - let max = max_val(values)?; Ok(BatchStats { min, max, @@ -1032,9 +1148,17 @@ async fn train_btree_page( }) } -fn btree_stats_as_batch(stats: Vec) -> Result { - let mins = ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.min.clone()))?; - let maxs = ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.max.clone()))?; +fn btree_stats_as_batch(stats: Vec, value_type: &DataType) -> Result { + let mins = if stats.is_empty() { + new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.min.clone()))? + }; + let maxs = if stats.is_empty() { + new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(stats.iter().map(|stat| stat.stats.max.clone()))? + }; let null_counts = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.stats.null_count)); let page_numbers = UInt32Array::from_iter_values(stats.iter().map(|stat| stat.page_number)); @@ -1092,13 +1216,15 @@ pub async fn train_btree_index( data_source: Box, sub_index_trainer: &dyn BTreeSubIndex, index_store: &dyn IndexStore, + batch_size: u32, ) -> Result<()> { let mut sub_index_file = index_store .new_index_file(BTREE_PAGES_NAME, sub_index_trainer.schema().clone()) .await?; let mut encoded_batches = Vec::new(); let mut batch_idx = 0; - let mut batches_source = data_source.scan_ordered_chunks(4096).await?; + let mut batches_source = data_source.scan_ordered_chunks(batch_size).await?; + let value_type = batches_source.schema().field(0).data_type().clone(); while let Some(batch) = batches_source.try_next().await? { debug_assert_eq!(batch.num_columns(), 2); debug_assert_eq!(*batch.column(1).data_type(), DataType::UInt64); @@ -1108,9 +1234,13 @@ pub async fn train_btree_index( batch_idx += 1; } sub_index_file.finish().await?; - let record_batch = btree_stats_as_batch(encoded_batches)?; + let record_batch = btree_stats_as_batch(encoded_batches, &value_type)?; + let mut file_schema = record_batch.schema().as_ref().clone(); + file_schema + .metadata + .insert(BATCH_SIZE_META_KEY.to_string(), batch_size.to_string()); let mut btree_index_file = index_store - .new_index_file(BTREE_LOOKUP_NAME, record_batch.schema()) + .new_index_file(BTREE_LOOKUP_NAME, Arc::new(file_schema)) .await?; btree_index_file.write_record_batch(record_batch).await?; btree_index_file.finish().await?; @@ -1144,6 +1274,13 @@ impl TrainingSource for BTreeUpdater { self: Box, chunk_size: u32, ) -> Result { + let data_type = self.new_data.schema().field(0).data_type().clone(); + // Datafusion currently has bugs with spilling on string columns + // See https://github.com/apache/datafusion/issues/10073 + // + // One we upgrade we can remove this + let use_spilling = !matches!(data_type, DataType::Utf8 | DataType::LargeUtf8); + let new_input = Arc::new(OneShotExec::new(self.new_data)); let old_input = Self::into_old_input(self.index); debug_assert_eq!( @@ -1160,11 +1297,15 @@ impl TrainingSource for BTreeUpdater { // The UnionExec creates multiple partitions but the SortPreservingMergeExec merges // them back into a single partition. let all_data = Arc::new(UnionExec::new(vec![old_input, new_input])); - let ordered = Arc::new(SortPreservingMergeExec::new(vec![sort_expr], all_data)); + let ordered = Arc::new(SortPreservingMergeExec::new( + LexOrdering::new(vec![sort_expr]), + all_data, + )); + let unchunked = execute_plan( ordered, LanceExecutionOptions { - use_spilling: true, + use_spilling, ..Default::default() }, )?; @@ -1185,8 +1326,21 @@ impl TrainingSource for BTreeUpdater { /// This is used for updating the index struct IndexReaderStream { reader: Arc, - pages: Vec, - idx: usize, + batch_size: u64, + num_batches: u32, + batch_idx: u32, +} + +impl IndexReaderStream { + async fn new(reader: Arc, batch_size: u64) -> Self { + let num_batches = reader.num_batches(batch_size).await; + Self { + reader, + batch_size, + num_batches, + batch_idx: 0, + } + } } impl Stream for IndexReaderStream { @@ -1197,28 +1351,56 @@ impl Stream for IndexReaderStream { _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.get_mut(); - let idx = this.idx; - if idx >= this.pages.len() { + if this.batch_idx >= this.num_batches { return std::task::Poll::Ready(None); } - let page_number = this.pages[idx]; - this.idx += 1; + let batch_num = this.batch_idx; + this.batch_idx += 1; let reader_copy = this.reader.clone(); - let read_task = async move { reader_copy.read_record_batch(page_number).await }.boxed(); + let batch_size = this.batch_size; + let read_task = async move { + reader_copy + .read_record_batch(batch_num as u64, batch_size) + .await + } + .boxed(); std::task::Poll::Ready(Some(read_task)) } } #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; - use arrow::datatypes::Int32Type; + use arrow::datatypes::{Float32Type, Float64Type, Int32Type, UInt64Type}; use arrow_array::FixedSizeListArray; - use datafusion_common::ScalarValue; + use arrow_schema::DataType; + use datafusion::{ + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{sorts::sort::SortExec, stream::RecordBatchStreamAdapter, ExecutionPlan}, + }; + use datafusion_common::{DataFusionError, ScalarValue}; + use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use deepsize::DeepSizeOf; - - use super::OrderableScalarValue; + use futures::TryStreamExt; + use lance_core::{cache::FileMetadataCache, utils::mask::RowIdTreeMap}; + use lance_datafusion::{chunker::break_stream, datagen::DatafusionDatagenExt}; + use lance_datagen::{array, gen, ArrayGeneratorExt, BatchCount, RowCount}; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use tempfile::tempdir; + + use crate::{ + metrics::NoOpMetricsCollector, + scalar::{ + btree::{BTreeIndex, BTREE_PAGES_NAME, DEFAULT_BTREE_BATCH_SIZE}, + flat::FlatIndexMetadata, + lance_format::{tests::MockTrainingSource, LanceIndexStore}, + IndexStore, SargableQuery, ScalarIndex, SearchResult, + }, + }; + + use super::{train_btree_index, OrderableScalarValue}; #[test] fn test_scalar_value_size() { @@ -1235,4 +1417,125 @@ mod tests { assert!(size_of_i32 > 4); assert!(size_of_many_i32 > 128 * 4); } + + #[tokio::test] + async fn test_null_ids() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + // Generate 50,000 rows of random data with 80% nulls + let stream = gen() + .col( + "value", + array::rand::().with_nulls(&[true, false, false, false, false]), + ) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(5000), BatchCount::from(10)); + let data_source = Box::new(MockTrainingSource::from(stream)); + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); + + train_btree_index( + data_source, + &sub_index_trainer, + test_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE as u32, + ) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone()).await.unwrap(); + + assert_eq!(index.page_lookup.null_pages.len(), 10); + + let remap_dir = Arc::new(tempdir().unwrap()); + let remap_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(remap_dir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + // Remap with a no-op mapping. The remapped index should be identical to the original + index + .remap(&HashMap::default(), remap_store.as_ref()) + .await + .unwrap(); + + let remap_index = BTreeIndex::load(remap_store.clone()).await.unwrap(); + + assert_eq!(remap_index.page_lookup, index.page_lookup); + + let original_pages = test_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + let remapped_pages = remap_store.open_index_file(BTREE_PAGES_NAME).await.unwrap(); + + assert_eq!(original_pages.num_rows(), remapped_pages.num_rows()); + + let original_data = original_pages + .read_record_batch(0, original_pages.num_rows() as u64) + .await + .unwrap(); + let remapped_data = remapped_pages + .read_record_batch(0, remapped_pages.num_rows() as u64) + .await + .unwrap(); + + assert_eq!(original_data, remapped_data); + } + + #[tokio::test] + async fn test_nan_ordering() { + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + let values = vec![ + 0.0, + 1.0, + 2.0, + 3.0, + f64::NAN, + f64::NEG_INFINITY, + f64::INFINITY, + ]; + + // This is a bit overkill but we've had bugs in the past where DF's sort + // didn't agree with Arrow's sort so we do an end-to-end test here + // and use DF to sort the data like we would in a real dataset. + let data = gen() + .col("value", array::cycle::(values.clone())) + .col("_rowid", array::step::()) + .into_df_exec(RowCount::from(10), BatchCount::from(100)); + let schema = data.schema(); + let sort_expr = PhysicalSortExpr::new_default(col("value", schema.as_ref()).unwrap()); + let plan = Arc::new(SortExec::new(LexOrdering::new(vec![sort_expr]), data)); + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let stream = break_stream(stream, 64); + let stream = stream.map_err(DataFusionError::from); + let stream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream; + let data_source = Box::new(MockTrainingSource::from(stream)); + + let sub_index_trainer = FlatIndexMetadata::new(DataType::Float64); + + train_btree_index(data_source, &sub_index_trainer, test_store.as_ref(), 64) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store).await.unwrap(); + + for (idx, value) in values.into_iter().enumerate() { + let query = SargableQuery::Equals(ScalarValue::Float64(Some(value))); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert_eq!( + result, + SearchResult::Exact(RowIdTreeMap::from_iter(((idx as u64)..1000).step_by(7))) + ); + } + } } diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 24bbbd7cc0d..17b65441844 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use datafusion_common::ScalarValue; use datafusion_expr::{ expr::{InList, ScalarFunction}, - Between, BinaryExpr, Expr, Operator, ScalarUDF, + Between, BinaryExpr, Expr, Operator, ReturnTypeArgs, ScalarUDF, }; use futures::join; @@ -18,7 +18,9 @@ use lance_core::{utils::mask::RowIdMask, Result}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; use tracing::instrument; -use super::{AnyQuery, LabelListQuery, SargableQuery, ScalarIndex}; +use super::{ + AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, +}; /// An indexed expression consists of a scalar index query with a post-scan filter /// @@ -214,6 +216,57 @@ impl ScalarQueryParser for LabelListQueryParser { } } +#[derive(Debug, Default, Clone)] +pub struct TextQueryParser {} + +impl ScalarQueryParser for TextQueryParser { + fn visit_between(&self, _: &str, _: ScalarValue, _: ScalarValue) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: Vec) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison(&self, _: &str, _: ScalarValue, _: &Operator) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + if args.len() != 2 { + return None; + } + let scalar = maybe_scalar(&args[1], data_type)?; + if let ScalarValue::Utf8(Some(scalar_str)) = scalar { + if func.name() == "contains" { + let query = TextQuery::StringContains(scalar_str); + Some(IndexedExpression::index_query( + column.to_string(), + Arc::new(query), + )) + } else { + None + } + } else { + None + } + } +} + impl IndexedExpression { /// Create an expression that only does refine fn refine_only(refine_expr: Expr) -> Self { @@ -341,7 +394,11 @@ impl IndexedExpression { #[async_trait] pub trait ScalarIndexLoader: Send + Sync { /// Load the index with the given name - async fn load_index(&self, name: &str) -> Result>; + async fn load_index( + &self, + name: &str, + metrics: &dyn MetricsCollector, + ) -> Result>; } /// This represents a lookup into one or more scalar indices @@ -379,6 +436,19 @@ impl std::fmt::Display for ScalarIndexExpr { } } +pub enum IndexExprResult { + // The answer is exactly the rows in the allow list minus the rows in the block list + Exact(RowIdMask), + // The answer is at most the rows in the allow list minus the rows in the block list + // Some of the rows in the allow list may not be in the result and will need to be filtered + // by a recheck. Every row in the block list is definitely not in the result. + AtMost(RowIdMask), + // The answer is at least the rows in the allow list minus the rows in the block list + // Some of the rows in the block list might be in the result. Every row in the allow list is + // definitely in the result. + AtLeast(RowIdMask), +} + impl ScalarIndexExpr { /// Evaluates the scalar index expression /// @@ -388,31 +458,110 @@ impl ScalarIndexExpr { /// any situations where the session cache has been disabled. #[async_recursion] #[instrument(level = "debug", skip_all)] - pub async fn evaluate(&self, index_loader: &dyn ScalarIndexLoader) -> Result { + pub async fn evaluate( + &self, + index_loader: &dyn ScalarIndexLoader, + metrics: &dyn MetricsCollector, + ) -> Result { match self { Self::Not(inner) => { - let result = inner.evaluate(index_loader).await?; - Ok(!result) + let result = inner.evaluate(index_loader, metrics).await?; + match result { + IndexExprResult::Exact(mask) => Ok(IndexExprResult::Exact(!mask)), + IndexExprResult::AtMost(mask) => Ok(IndexExprResult::AtLeast(!mask)), + IndexExprResult::AtLeast(mask) => Ok(IndexExprResult::AtMost(!mask)), + } } Self::And(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader); - let rhs_result = rhs.evaluate(index_loader); + let lhs_result = lhs.evaluate(index_loader, metrics); + let rhs_result = rhs.evaluate(index_loader, metrics); let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - Ok(lhs_result? & rhs_result?) + match (lhs_result?, rhs_result?) { + (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { + Ok(IndexExprResult::Exact(lhs & rhs)) + } + (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) + | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { + Ok(IndexExprResult::AtMost(lhs & rhs)) + } + (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(_)) => { + // We could do better here, elements in both lhs and rhs are known + // to be true and don't require a recheck. We only need to recheck + // elements in lhs that are not in rhs + Ok(IndexExprResult::AtMost(lhs)) + } + (IndexExprResult::AtLeast(_), IndexExprResult::Exact(rhs)) => { + // We could do better here (see above) + Ok(IndexExprResult::AtMost(rhs)) + } + (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { + Ok(IndexExprResult::AtMost(lhs & rhs)) + } + (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { + Ok(IndexExprResult::AtLeast(lhs & rhs)) + } + (IndexExprResult::AtLeast(_), IndexExprResult::AtMost(rhs)) => { + Ok(IndexExprResult::AtMost(rhs)) + } + (IndexExprResult::AtMost(lhs), IndexExprResult::AtLeast(_)) => { + Ok(IndexExprResult::AtMost(lhs)) + } + } } Self::Or(lhs, rhs) => { - let lhs_result = lhs.evaluate(index_loader); - let rhs_result = rhs.evaluate(index_loader); + let lhs_result = lhs.evaluate(index_loader, metrics); + let rhs_result = rhs.evaluate(index_loader, metrics); let (lhs_result, rhs_result) = join!(lhs_result, rhs_result); - Ok(lhs_result? | rhs_result?) + match (lhs_result?, rhs_result?) { + (IndexExprResult::Exact(lhs), IndexExprResult::Exact(rhs)) => { + Ok(IndexExprResult::Exact(lhs | rhs)) + } + (IndexExprResult::Exact(lhs), IndexExprResult::AtMost(rhs)) + | (IndexExprResult::AtMost(lhs), IndexExprResult::Exact(rhs)) => { + // We could do better here. Elements in the exact side don't need + // re-check. We only need to recheck elements exclusively in the + // at-most side + Ok(IndexExprResult::AtMost(lhs | rhs)) + } + (IndexExprResult::Exact(lhs), IndexExprResult::AtLeast(rhs)) => { + Ok(IndexExprResult::AtLeast(lhs | rhs)) + } + (IndexExprResult::AtLeast(lhs), IndexExprResult::Exact(rhs)) => { + Ok(IndexExprResult::AtLeast(lhs | rhs)) + } + (IndexExprResult::AtMost(lhs), IndexExprResult::AtMost(rhs)) => { + Ok(IndexExprResult::AtMost(lhs | rhs)) + } + (IndexExprResult::AtLeast(lhs), IndexExprResult::AtLeast(rhs)) => { + Ok(IndexExprResult::AtLeast(lhs | rhs)) + } + (IndexExprResult::AtLeast(lhs), IndexExprResult::AtMost(_)) => { + Ok(IndexExprResult::AtLeast(lhs)) + } + (IndexExprResult::AtMost(_), IndexExprResult::AtLeast(rhs)) => { + Ok(IndexExprResult::AtLeast(rhs)) + } + } } Self::Query(column, query) => { - let index = index_loader.load_index(column).await?; - let matching_row_ids = index.search(query.as_ref()).await?; - Ok(RowIdMask { - block_list: None, - allow_list: Some(matching_row_ids), - }) + let index = index_loader.load_index(column, metrics).await?; + let search_result = index.search(query.as_ref(), metrics).await?; + match search_result { + SearchResult::Exact(matching_row_ids) => { + Ok(IndexExprResult::Exact(RowIdMask { + block_list: None, + allow_list: Some(matching_row_ids), + })) + } + SearchResult::AtMost(row_ids) => Ok(IndexExprResult::AtMost(RowIdMask { + block_list: None, + allow_list: Some(row_ids), + })), + SearchResult::AtLeast(row_ids) => Ok(IndexExprResult::AtLeast(RowIdMask { + block_list: None, + allow_list: Some(row_ids), + })), + } } } } @@ -433,6 +582,14 @@ impl ScalarIndexExpr { Self::Query(column, query) => query.to_expr(column.clone()), } } + + pub fn needs_recheck(&self) -> bool { + match self { + Self::Not(inner) => inner.needs_recheck(), + Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(), + Self::Query(_, query) => query.needs_recheck(), + } + } } // Extract a column from the expression, if it is a column, or None @@ -457,6 +614,44 @@ fn maybe_indexed_column<'a, 'b>( fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option { match expr { Expr::Literal(value) => safe_coerce_scalar(value, expected_type), + // Some literals can't be expressed in datafusion's SQL and can only be expressed with + // a cast. For example, there is no way to express a fixed-size-binary literal (which is + // commonly used for UUID). As a result the expression could look like... + // + // col = arrow_cast(value, 'fixed_size_binary(16)') + // + // In this case we need to extract the value, apply the cast, and then test the casted value + Expr::Cast(cast) => match cast.expr.as_ref() { + Expr::Literal(value) => { + let casted = value.cast_to(&cast.data_type).ok()?; + safe_coerce_scalar(&casted, expected_type) + } + _ => None, + }, + Expr::ScalarFunction(scalar_function) => { + if scalar_function.name() == "arrow_cast" { + if scalar_function.args.len() != 2 { + return None; + } + match (&scalar_function.args[0], &scalar_function.args[1]) { + (Expr::Literal(value), Expr::Literal(cast_type)) => { + let target_type = scalar_function + .func + .return_type_from_args(ReturnTypeArgs { + arg_types: &[value.data_type(), cast_type.data_type()], + scalar_arguments: &[Some(value), Some(cast_type)], + nullables: &[false, false], + }) + .ok()?; + let casted = value.cast_to(target_type.return_type()).ok()?; + safe_coerce_scalar(&casted, expected_type) + } + _ => None, + } + } else { + None + } + } _ => None, } } @@ -564,9 +759,66 @@ fn visit_comparison( let scalar = maybe_scalar(&expr.right, col_type)?; query_parser.visit_comparison(column, scalar, &expr.op) } else { - let (column, col_type, query_parser) = maybe_indexed_column(&expr.right, index_info)?; - let scalar = maybe_scalar(&expr.left, col_type)?; - query_parser.visit_comparison(column, scalar, &expr.op) + // Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we + // do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides + None + } +} + +fn maybe_between(expr: &BinaryExpr) -> Option { + let left_comparison = match expr.left.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + let right_comparison = match expr.right.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + + match (left_comparison.op, right_comparison.op) { + (Operator::GtEq, Operator::LtEq) => { + // We have x >= y && a <= b. + // If x == a then it is a between query + // if y == b then it is a between query + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: left_comparison.right.clone(), + high: right_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: right_comparison.left.clone(), + high: left_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + (Operator::LtEq, Operator::GtEq) => { + // Same logic as above we just switch the low/high + if left_comparison.left == right_comparison.left { + Some(Between { + expr: left_comparison.left.clone(), + low: right_comparison.right.clone(), + high: left_comparison.right.clone(), + negated: false, + }) + } else if left_comparison.right == right_comparison.right { + Some(Between { + expr: left_comparison.right.clone(), + low: left_comparison.left.clone(), + high: right_comparison.left.clone(), + negated: false, + }) + } else { + None + } + } + _ => None, } } @@ -574,6 +826,17 @@ fn visit_and( expr: &BinaryExpr, index_info: &dyn IndexInformationProvider, ) -> Option { + // Many scalar indices can efficiently handle a BETWEEN query as a single search and this + // can be much more efficient than two separate range queries. As an optimization we check + // to see if this is a between query and, if so, we handle it as a single query + // + // Note: We can't rely on users writing the SQL BETWEEN operator because: + // * Some users won't realize it's an option or a good idea + // * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries + if let Some(between) = maybe_between(expr) { + return visit_between(&between, index_info); + } + let left = visit_node(&expr.left, index_info); let right = visit_node(&expr.right, index_info); match (left, right) { @@ -664,6 +927,8 @@ pub fn apply_scalar_indices( #[derive(Default, Debug)] pub struct FilterPlan { pub index_query: Option, + /// True if the index query is guaranteed to return exact results + pub skip_recheck: bool, pub refine_expr: Option, pub full_expr: Option, } @@ -716,14 +981,20 @@ impl PlannerIndexExt for Planner { let logical_expr = self.optimize_expr(filter)?; if use_scalar_index { let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info); + let mut skip_recheck = false; + if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() { + skip_recheck = !scalar_query.needs_recheck(); + } Ok(FilterPlan { index_query: indexed_expr.scalar_query, refine_expr: indexed_expr.refine_expr, full_expr: Some(logical_expr), + skip_recheck, }) } else { Ok(FilterPlan { index_query: None, + skip_recheck: true, refine_expr: Some(logical_expr.clone()), full_expr: Some(logical_expr), }) @@ -736,12 +1007,8 @@ mod tests { use std::collections::HashMap; use arrow_schema::{Field, Schema}; - use datafusion::error::Result as DFResult; - use datafusion_common::{config::ConfigOptions, TableReference}; + use datafusion::prelude::SessionContext; use datafusion_common::{Column, DFSchema}; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; - use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; - use datafusion_sql::sqlparser::{dialect::PostgreSqlDialect, parser::Parser}; use super::*; @@ -780,47 +1047,6 @@ mod tests { } } - struct MockContextProvider {} - - // We're just compiling simple expressions (not entire statements) and so this is unused - impl ContextProvider for MockContextProvider { - fn get_table_source(&self, _name: TableReference) -> DFResult> { - todo!() - } - - fn get_function_meta(&self, _: &str) -> Option> { - todo!() - } - - fn get_aggregate_meta(&self, _: &str) -> Option> { - todo!() - } - - fn get_window_meta(&self, _: &str) -> Option> { - todo!() - } - - fn get_variable_type(&self, _: &[String]) -> Option { - todo!() - } - - fn options(&self) -> &ConfigOptions { - todo!() - } - - fn udf_names(&self) -> Vec { - todo!() - } - - fn udaf_names(&self) -> Vec { - todo!() - } - - fn udwf_names(&self) -> Vec { - todo!() - } - } - fn check( index_info: &dyn IndexInformationProvider, expr: &str, @@ -833,16 +1059,11 @@ mod tests { Field::new("on_sale", DataType::Boolean, false), Field::new("price", DataType::Float32, false), ]); - let dialect = PostgreSqlDialect {}; - let mut parser = Parser::new(&dialect).try_with_sql(expr).unwrap(); - let expr = parser.parse_expr().unwrap(); - let context_provider = MockContextProvider {}; - let planner = SqlToRel::new(&context_provider); let df_schema: DFSchema = schema.try_into().unwrap(); - let mut planner_context = PlannerContext::new(); - let expr = planner - .sql_to_expr(expr, &df_schema, &mut planner_context) - .unwrap(); + + let ctx = SessionContext::default(); + let state = ctx.state(); + let expr = state.create_logical_expr(expr, &df_schema).unwrap(); let actual = apply_scalar_indices(expr.clone(), index_info); if let Some(expected) = expected { @@ -912,6 +1133,14 @@ mod tests { ]); check_no_index(&index_info, "size BETWEEN 5 AND 10"); + // Cast case. We will cast 5 (an int64) to Int16 and then coerce to UInt32 + check_simple( + &index_info, + "aisle = arrow_cast(5, 'Int16')", + "aisle", + SargableQuery::Equals(ScalarValue::UInt32(Some(5))), + ); + // 5 different ways of writing BETWEEN (all should be recognized) check_simple( &index_info, "aisle BETWEEN 5 AND 10", @@ -921,6 +1150,45 @@ mod tests { Bound::Included(ScalarValue::UInt32(Some(10))), ), ); + check_simple( + &index_info, + "aisle >= 5 AND aisle <= 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "aisle <= 10 AND aisle >= 5", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "5 <= aisle AND 10 >= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_simple( + &index_info, + "10 >= aisle AND 5 <= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); check_simple( &index_info, "on_sale IS TRUE", @@ -1023,6 +1291,10 @@ mod tests { Bound::Unbounded, ), ); + // In the future we can handle this case if we need to. For + // now let's make sure we don't accidentally do the wrong thing + // (we were getting this backwards in the past) + check_no_index(&index_info, "10 > aisle"); check_simple( &index_info, "aisle >= 10", diff --git a/rust/lance-index/src/scalar/flat.rs b/rust/lance-index/src/scalar/flat.rs index 66a69e95e53..5aae43f374a 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -17,12 +17,12 @@ use lance_core::utils::address::RowAddress; use lance_core::utils::mask::RowIdTreeMap; use lance_core::{Error, Result}; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use crate::{Index, IndexType}; use super::{btree::BTreeSubIndex, IndexStore, ScalarIndex}; -use super::{AnyQuery, SargableQuery}; +use super::{AnyQuery, MetricsCollector, SargableQuery, SearchResult}; /// A flat index is just a batch of value/row-id pairs /// @@ -33,6 +33,7 @@ use super::{AnyQuery, SargableQuery}; #[derive(Debug)] pub struct FlatIndex { data: Arc, + has_nulls: bool, } impl DeepSizeOf for FlatIndex { @@ -132,8 +133,10 @@ impl BTreeSubIndex for FlatIndexMetadata { } async fn load_subindex(&self, serialized: RecordBatch) -> Result> { + let has_nulls = serialized.column(0).null_count() > 0; Ok(Arc::new(FlatIndex { data: Arc::new(serialized), + has_nulls, })) } @@ -171,6 +174,11 @@ impl Index for FlatIndex { IndexType::Scalar } + async fn prewarm(&self) -> Result<()> { + // There is nothing to pre-warm + Ok(()) + } + fn statistics(&self) -> Result { Ok(serde_json::json!({ "num_values": self.data.num_rows(), @@ -192,17 +200,32 @@ impl Index for FlatIndex { #[async_trait] impl ScalarIndex for FlatIndex { - async fn search(&self, query: &dyn AnyQuery) -> Result { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { + metrics.record_comparisons(self.data.num_rows()); let query = query.as_any().downcast_ref::().unwrap(); // Since we have all the values in memory we can use basic arrow-rs compute // functions to satisfy scalar queries. - let predicate = match query { - SargableQuery::Equals(value) => arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)?, + let mut predicate = match query { + SargableQuery::Equals(value) => { + if value.is_null() { + arrow::compute::is_null(self.values())? + } else { + arrow_ord::cmp::eq(self.values(), &value.to_scalar()?)? + } + } SargableQuery::IsNull() => arrow::compute::is_null(self.values())?, SargableQuery::IsIn(values) => { + let mut has_null = false; let choices = values .iter() - .map(|val| lit(val.clone())) + .map(|val| { + has_null |= val.is_null(); + lit(val.clone()) + }) .collect::>(); let in_list_expr = in_list( Arc::new(Column::new("values", 0)), @@ -211,12 +234,20 @@ impl ScalarIndex for FlatIndex { &self.data.schema(), )?; let result_col = in_list_expr.evaluate(&self.data)?; - result_col + let predicate = result_col .into_array(self.data.num_rows())? .as_any() .downcast_ref::() .expect("InList evaluation should return boolean array") - .clone() + .clone(); + + // Arrow's in_list does not handle nulls so we need to join them in here if user asked for them + if has_null && self.has_nulls { + let nulls = arrow::compute::is_null(self.values())?; + arrow::compute::or(&predicate, &nulls)? + } else { + predicate + } } SargableQuery::Range(lower_bound, upper_bound) => match (lower_bound, upper_bound) { (Bound::Unbounded, Bound::Unbounded) => { @@ -256,12 +287,24 @@ impl ScalarIndex for FlatIndex { location!(), )), }; + if self.has_nulls && matches!(query, SargableQuery::Range(_, _)) { + // Arrow's comparison kernels do not return false for nulls. They consider nulls to + // be less than any value. So we need to filter out the nulls manually. + let valid_values = arrow::compute::is_not_null(self.values())?; + predicate = arrow::compute::and(&valid_values, &predicate)?; + } let matching_ids = arrow_select::filter::filter(self.ids(), &predicate)?; let matching_ids = matching_ids .as_any() .downcast_ref::() .expect("Result of arrow_select::filter::filter did not match input type"); - Ok(RowIdTreeMap::from_iter(matching_ids.values())) + Ok(SearchResult::Exact(RowIdTreeMap::from_iter( + matching_ids.values(), + ))) + } + + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + true } // Note that there is no write/train method for flat index at the moment and so it isn't @@ -269,9 +312,12 @@ impl ScalarIndex for FlatIndex { // data as a single batch named data.lance async fn load(store: Arc) -> Result> { let batches = store.open_index_file("data.lance").await?; - let batch = batches.read_record_batch(0).await?; + let num_rows = batches.num_rows(); + let batch = batches.read_range(0..num_rows, None).await?; + let has_nulls = batch.column(0).null_count() > 0; Ok(Arc::new(Self { data: Arc::new(batch), + has_nulls, })) } @@ -302,6 +348,8 @@ impl ScalarIndex for FlatIndex { #[cfg(test)] mod tests { + use crate::metrics::NoOpMetricsCollector; + use super::*; use arrow_array::types::Int32Type; use datafusion_common::ScalarValue; @@ -319,14 +367,18 @@ mod tests { FlatIndex { data: Arc::new(batch), + has_nulls: false, } } async fn check_index(query: &SargableQuery, expected: &[u64]) { let index = example_index(); - let actual = index.search(query).await.unwrap(); + let actual = index.search(query, &NoOpMetricsCollector).await.unwrap(); + let SearchResult::Exact(actual_row_ids) = actual else { + panic! {"Expected exact search result"} + }; let expected = RowIdTreeMap::from_iter(expected); - assert_eq!(actual, expected); + assert_eq!(actual_row_ids, expected); } #[tokio::test] diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 32371773a3c..a1114d48b66 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -1,8 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -mod builder; +pub mod builder; mod index; +pub mod query; mod tokenizer; mod wand; diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 6b99e7c9175..2132325b3af 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -12,8 +12,8 @@ use crate::scalar::{IndexReader, IndexStore, IndexWriter, InvertedIndexParams}; use crate::vector::graph::OrderedFloat; use arrow::array::{ArrayBuilder, AsArray, Int32Builder, StringBuilder}; use arrow::datatypes; -use arrow_array::{Int32Array, RecordBatch, StringArray}; -use arrow_schema::SchemaRef; +use arrow_array::{Array, Int32Array, RecordBatch, StringArray, UInt64Array}; +use arrow_schema::{Field, Schema, SchemaRef}; use crossbeam_queue::ArrayQueue; use datafusion::execution::SendableRecordBatchStream; use deepsize::DeepSizeOf; @@ -22,10 +22,11 @@ use itertools::Itertools; use lance_arrow::iter_str_array; use lance_core::cache::FileMetadataCache; use lance_core::utils::tokio::{get_num_compute_intensive_cpus, CPU_RUNTIME}; -use lance_core::{Result, ROW_ID}; +use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use lance_io::object_store::ObjectStore; use lazy_static::lazy_static; use object_store::path::Path; +use snafu::location; use tempfile::{tempdir, TempDir}; use tracing::instrument; @@ -53,7 +54,7 @@ lazy_static! { // it doesn't mean higher value will result in better performance, // because the bottleneck can be the IO once the number of shards is large enough, // it's 8 by default - static ref LANCE_FTS_NUM_SHARDS: usize = std::env::var("LANCE_FTS_NUM_SHARDS") + pub static ref LANCE_FTS_NUM_SHARDS: usize = std::env::var("LANCE_FTS_NUM_SHARDS") .unwrap_or_else(|_| "8".to_string()) .parse() .expect("failed to parse LANCE_FTS_NUM_SHARDS"); @@ -108,15 +109,38 @@ impl InvertedIndexBuilder { #[instrument(level = "debug", skip_all)] async fn update_index(&mut self, stream: SendableRecordBatchStream) -> Result<()> { + let flatten_stream = stream.map(|batch| { + let batch = batch?; + let doc_col = batch.column(0); + match doc_col.data_type() { + datatypes::DataType::Utf8 | datatypes::DataType::LargeUtf8 => Ok(batch), + datatypes::DataType::List(_) => { + flatten_string_list::(&batch, doc_col) + } + datatypes::DataType::LargeList(_) => { + flatten_string_list::(&batch, doc_col) + } + _ => { + Err(Error::Index { message: format!("expect data type String, LargeString or List of String/LargeString, but got {}", doc_col.data_type()), location: location!() }) + } + } + }); + let num_shards = *LANCE_FTS_NUM_SHARDS; // init the token maps let mut token_maps = vec![HashMap::new(); num_shards]; - for (token, token_id) in self.tokens.tokens.iter() { - let mut hasher = DefaultHasher::new(); - hasher.write(token.as_bytes()); - let shard = hasher.finish() as usize % num_shards; - token_maps[shard].insert(token.clone(), *token_id); + + match self.tokens.tokens { + TokenMap::HashMap(ref tokens) => { + for (token, token_id) in tokens.iter() { + let mut hasher = DefaultHasher::new(); + hasher.write(token.as_bytes()); + let shard = hasher.finish() as usize % num_shards; + token_maps[shard].insert(token.clone(), *token_id); + } + } + _ => unreachable!("tokens must be HashMap at indexing"), } // spawn `num_shards` workers to build the index, @@ -153,13 +177,15 @@ impl InvertedIndexBuilder { for _ in 0..num_shards { let _ = tokenizer_pool.push(tokenizer.clone()); } - let mut stream = stream + let mut stream = flatten_stream .map(move |batch| { let senders = senders.clone(); let tokenizer_pool = tokenizer_pool.clone(); CPU_RUNTIME.spawn_blocking(move || { let batch = batch?; - let doc_iter = iter_str_array(batch.column(0)); + + let doc_col = batch.column(0); + let doc_iter = iter_str_array(doc_col); let row_id_col = batch[ROW_ID].as_primitive::(); let docs = doc_iter .zip(row_id_col.values().iter()) @@ -285,8 +311,7 @@ impl InvertedIndexBuilder { Result::Ok((batch, max_score)) } }); - let mut stream = - stream::iter(batches).buffer_unordered(get_num_compute_intensive_cpus()); + let mut stream = stream::iter(batches).buffered(get_num_compute_intensive_cpus()); let mut offsets = Vec::new(); let mut max_scores = Vec::new(); let mut num_rows = 0; @@ -421,7 +446,7 @@ impl IndexWorker { async fn new(existing_tokens: HashMap, with_position: bool) -> Result { let tmpdir = tempdir()?; let store = Arc::new(LanceIndexStore::new( - ObjectStore::local(), + Arc::new(ObjectStore::local()), Path::from_filesystem_path(tmpdir.path())?, FileMetadataCache::no_cache(), )); @@ -717,214 +742,41 @@ pub fn inverted_list_schema(with_position: bool) -> SchemaRef { Arc::new(arrow_schema::Schema::new(fields)) } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow_array::{Array, ArrayRef, GenericStringArray, RecordBatch, UInt64Array}; - use datafusion::physical_plan::stream::RecordBatchStreamAdapter; - use futures::stream; - use lance_core::cache::{CapacityMode, FileMetadataCache}; - use lance_core::ROW_ID_FIELD; - use lance_io::object_store::ObjectStore; - use object_store::path::Path; - - use crate::scalar::inverted::TokenizerConfig; - use crate::scalar::lance_format::LanceIndexStore; - use crate::scalar::{FullTextSearchQuery, SargableQuery, ScalarIndex}; - - use super::InvertedIndex; - - async fn create_index( - with_position: bool, - tokenizer: TokenizerConfig, - ) -> Arc { - let tempdir = tempfile::tempdir().unwrap(); - let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); - let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes); - let store = LanceIndexStore::new(ObjectStore::local(), index_dir, cache); - - let mut params = super::InvertedIndexParams::default().with_position(with_position); - params.tokenizer_config = tokenizer; - let mut invert_index = super::InvertedIndexBuilder::new(params); - let doc_col = GenericStringArray::::from(vec![ - "lance database the search", - "lance database", - "lance search", - "database search", - "unrelated doc", - "unrelated", - "mots accentués", - ]); - let row_id_col = UInt64Array::from(Vec::from_iter(0..doc_col.len() as u64)); - let batch = RecordBatch::try_new( - arrow_schema::Schema::new(vec![ - arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), false), - ROW_ID_FIELD.clone(), - ]) - .into(), - vec![ - Arc::new(doc_col) as ArrayRef, - Arc::new(row_id_col) as ArrayRef, - ], - ) - .unwrap(); - let stream = RecordBatchStreamAdapter::new(batch.schema(), stream::iter(vec![Ok(batch)])); - let stream = Box::pin(stream); - - invert_index - .update(stream, &store) - .await - .expect("failed to update invert index"); - - super::InvertedIndex::load(Arc::new(store)).await.unwrap() - } - - async fn test_inverted_index() { - let invert_index = create_index::(false, TokenizerConfig::default()).await; - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("lance".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(3)); - assert!(row_ids.contains(0)); - assert!(row_ids.contains(1)); - assert!(row_ids.contains(2)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("database".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(3)); - assert!(row_ids.contains(0)); - assert!(row_ids.contains(1)); - assert!(row_ids.contains(3)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("unknown null".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(0)); - - // test phrase query - // for non-phrasal query, the order of the tokens doesn't matter - // so there should be 4 documents that contain "database" or "lance" - - // we built the index without position, so the phrase query will not work - let results = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"unknown null\"".to_owned()).limit(Some(3)), - )) - .await; - assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position")); - let results = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"lance database\"".to_owned()).limit(Some(10)), - )) - .await; - assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position")); - - // recreate the index with position - let invert_index = create_index::(true, TokenizerConfig::default()).await; - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(4)); - assert!(row_ids.contains(0)); - assert!(row_ids.contains(1)); - assert!(row_ids.contains(2)); - assert!(row_ids.contains(3)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"lance database\"".to_owned()).limit(Some(10)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(2)); - assert!(row_ids.contains(0)); - assert!(row_ids.contains(1)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"database lance\"".to_owned()).limit(Some(10)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(0)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"lance unknown\"".to_owned()).limit(Some(10)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(0)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("\"unknown null\"".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(0)); - } - - #[tokio::test] - async fn test_inverted_index_with_string() { - test_inverted_index::().await; - } - - #[tokio::test] - async fn test_inverted_index_with_large_string() { - test_inverted_index::().await; - } - - #[tokio::test] - async fn test_accented_chars() { - let invert_index = create_index::(false, TokenizerConfig::default()).await; - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(1)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(0)); - - // with ascii folding enabled, the search should be accent-insensitive - let invert_index = - create_index::(true, TokenizerConfig::default().ascii_folding(true)).await; - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(1)); - - let row_ids = invert_index - .search(&SargableQuery::FullTextSearch( - FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)), - )) - .await - .unwrap(); - assert_eq!(row_ids.len(), Some(1)); - } +fn flatten_string_list( + batch: &RecordBatch, + doc_col: &Arc, +) -> Result { + let docs = doc_col.as_list::(); + let row_ids = batch[ROW_ID].as_primitive::(); + + let row_ids = row_ids + .values() + .iter() + .zip(docs.iter()) + .flat_map(|(row_id, doc)| std::iter::repeat_n(*row_id, doc.map(|d| d.len()).unwrap_or(0))); + + let row_ids = Arc::new(UInt64Array::from_iter_values(row_ids)); + let docs = match docs.value_type() { + datatypes::DataType::Utf8 | datatypes::DataType::LargeUtf8 => docs.values().clone(), + _ => { + return Err(Error::Index { + message: format!( + "expect data type String or LargeString but got {}", + docs.value_type() + ), + location: location!(), + }); + } + }; + + let schema = Schema::new(vec![ + Field::new( + batch.schema().field(0).name(), + docs.data_type().clone(), + true, + ), + ROW_ID_FIELD.clone(), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![docs, row_ids])?; + Ok(batch) } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 1987e3a0daf..2cfba54971a 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::cmp::min; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::sync::Arc; @@ -11,8 +12,8 @@ use arrow::array::{ use arrow::buffer::ScalarBuffer; use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type}; use arrow_array::{ - Array, ArrayRef, Float32Array, ListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, - UInt32Array, UInt64Array, + Array, ArrayRef, BooleanArray, Float32Array, ListArray, OffsetSizeTrait, PrimitiveArray, + RecordBatch, UInt32Array, UInt64Array, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; @@ -20,25 +21,26 @@ use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::DataFusionError; use deepsize::DeepSizeOf; +use fst::{IntoStreamer, Streamer}; use futures::stream::repeat_with; use futures::{stream, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_arrow::{iter_str_array, RecordBatchExt}; -use lance_core::utils::mask::RowIdTreeMap; -use lance_core::utils::tokio::get_num_compute_intensive_cpus; +use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}; use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use lazy_static::lazy_static; use moka::future::Cache; use roaring::RoaringBitmap; -use snafu::{location, Location}; -use tracing::instrument; +use snafu::location; +use tracing::{info, instrument}; use super::builder::inverted_list_schema; +use super::query::*; use super::{wand::*, InvertedIndexBuilder, TokenizerConfig}; -use crate::prefilter::{NoFilter, PreFilter}; +use crate::prefilter::PreFilter; use crate::scalar::{ - AnyQuery, FullTextSearchQuery, IndexReader, IndexStore, InvertedIndexParams, SargableQuery, - ScalarIndex, + AnyQuery, IndexReader, IndexStore, InvertedIndexParams, MetricsCollector, SargableQuery, + ScalarIndex, SearchResult, }; use crate::Index; @@ -63,7 +65,7 @@ pub const K1: f32 = 1.2; pub const B: f32 = 0.75; lazy_static! { - static ref CACHE_SIZE: usize = std::env::var("LANCE_INVERTED_CACHE_SIZE") + pub static ref CACHE_SIZE: usize = std::env::var("LANCE_INVERTED_CACHE_SIZE") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(512 * 1024 * 1024); @@ -71,6 +73,7 @@ lazy_static! { #[derive(Clone)] pub struct InvertedIndex { + io_parallelism: usize, params: InvertedIndexParams, tokenizer: tantivy::tokenizer::TextAnalyzer, tokens: TokenSet, @@ -107,55 +110,83 @@ impl InvertedIndex { .collect() } - #[instrument(level = "debug", skip_all)] - pub async fn full_text_search( + fn to_builder(&self) -> InvertedIndexBuilder { + let tokens = self.tokens.clone().into_mut(); + let inverted_list = self.inverted_list.clone(); + let docs = self.docs.clone(); + InvertedIndexBuilder::from_existing_index(self.params.clone(), tokens, inverted_list, docs) + } + + pub fn tokenizer(&self) -> tantivy::tokenizer::TextAnalyzer { + self.tokenizer.clone() + } + + pub fn expand_fuzzy( &self, - query: &FullTextSearchQuery, - prefilter: Arc, - ) -> Result> { - let mut tokenizer = self.tokenizer.clone(); - let tokens = collect_tokens(&query.query, &mut tokenizer, None); - let token_ids = self.map(&tokens).into_iter(); - let token_ids = if !is_phrase_query(&query.query) { - token_ids.sorted_unstable().dedup().collect() - } else { - if !self.inverted_list.has_positions() { - return Err(Error::Index { message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), location: location!() }); - } - let token_ids = token_ids.collect::>(); - // for phrase query, all tokens must be present - if token_ids.len() != tokens.len() { - return Ok(Vec::new()); + tokens: Vec, + fuzziness: Option, + max_expansions: usize, + ) -> Result> { + let mut new_tokens = Vec::with_capacity(min(tokens.len(), max_expansions)); + for token in tokens { + let fuzziness = match fuzziness { + Some(fuzziness) => fuzziness, + None => MatchQuery::auto_fuzziness(&token), + }; + let lev = + fst::automaton::Levenshtein::new(&token, fuzziness).map_err(|e| Error::Index { + message: format!("failed to construct the fuzzy query: {}", e), + location: location!(), + })?; + if let TokenMap::Fst(ref map) = self.tokens.tokens { + let mut stream = map.search(lev).into_stream(); + while let Some((token, _)) = stream.next() { + new_tokens.push(String::from_utf8_lossy(token).into_owned()); + if new_tokens.len() >= max_expansions { + break; + } + } + } else { + return Err(Error::Index { + message: "tokens is not fst, which is not expected".to_owned(), + location: location!(), + }); } - token_ids - }; - self.bm25_search(token_ids, query, prefilter).await + } + Ok(new_tokens) } // search the documents that contain the query // return the row ids of the documents sorted by bm25 score // ref: https://en.wikipedia.org/wiki/Okapi_BM25 #[instrument(level = "debug", skip_all)] - async fn bm25_search( + pub async fn bm25_search( &self, - token_ids: Vec, - query: &FullTextSearchQuery, + tokens: &[String], + params: &FtsSearchParams, + operator: Operator, + is_phrase_query: bool, prefilter: Arc, - ) -> Result> { - let limit = query - .limit - .map(|limit| limit as usize) - .unwrap_or(usize::MAX); - let wand_factor = query.wand_factor.unwrap_or(1.0); + metrics: &dyn MetricsCollector, + ) -> Result<(Vec, Vec)> { + metrics.record_comparisons(tokens.len()); let mask = prefilter.mask(); - let is_phrase_query = is_phrase_query(&query.query); + let token_ids = self.map(tokens); + if token_ids.is_empty() { + return Ok((Vec::new(), Vec::new())); + } + if is_phrase_query && token_ids.len() != tokens.len() { + return Ok((Vec::new(), Vec::new())); + } + let postings = stream::iter(token_ids) .enumerate() - .zip(repeat_with(|| (self.inverted_list.clone(), mask.clone()))) - .map(|((position, token_id), (inverted_list, mask))| async move { - let posting = inverted_list - .posting_list(token_id, is_phrase_query) + .zip(repeat_with(|| mask.clone())) + .map(|((position, token_id), mask)| async move { + let posting = self + .inverted_list + .posting_list(token_id, is_phrase_query, metrics) .await?; Result::Ok(PostingIterator::new( token_id, @@ -165,26 +196,23 @@ impl InvertedIndex { mask, )) }) - // Use compute count since data hopefully cached - .buffered(get_num_compute_intensive_cpus()) + .buffer_unordered(self.io_parallelism) .try_collect::>() .await?; - let mut wand = Wand::new(self.docs.len(), postings.into_iter()); - wand.search(is_phrase_query, limit, wand_factor, |doc, freq| { - let doc_norm = - K1 * (1.0 - B + B * self.docs.num_tokens(doc) as f32 / self.docs.average_length()); - freq / (freq + doc_norm) - }) + let mut wand = Wand::new(self.docs.len(), operator, postings.into_iter()); + wand.search( + is_phrase_query, + params.limit.unwrap_or(usize::MAX), + params.wand_factor, + |doc, freq| { + let doc_norm = K1 + * (1.0 - B + B * self.docs.num_tokens(doc) as f32 / self.docs.average_length()); + freq / (freq + doc_norm) + }, + ) .await } - - fn to_builder(&self) -> InvertedIndexBuilder { - let tokens = self.tokens.clone(); - let inverted_list = self.inverted_list.clone(); - let docs = self.docs.clone(); - InvertedIndexBuilder::from_existing_index(self.params.clone(), tokens, inverted_list, docs) - } } #[async_trait] @@ -206,11 +234,16 @@ impl Index for InvertedIndex { fn statistics(&self) -> Result { Ok(serde_json::json!({ + "params": self.params, "num_tokens": self.tokens.tokens.len(), "num_docs": self.docs.token_count.len(), })) } + async fn prewarm(&self) -> Result<()> { + self.inverted_list.prewarm().await + } + fn index_type(&self) -> crate::IndexType { crate::IndexType::Inverted } @@ -224,23 +257,20 @@ impl Index for InvertedIndex { impl ScalarIndex for InvertedIndex { // return the row ids of the documents that contain the query #[instrument(level = "debug", skip_all)] - async fn search(&self, query: &dyn AnyQuery) -> Result { + async fn search( + &self, + query: &dyn AnyQuery, + _metrics: &dyn MetricsCollector, + ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); - let row_ids = match query { - SargableQuery::FullTextSearch(query) => self - .full_text_search(query, Arc::new(NoFilter)) - .await? - .into_iter() - .map(|(row_id, _)| row_id), - query => { - return Err(Error::invalid_input( - format!("unsupported query {:?} for inverted index", query), - location!(), - )) - } - }; + return Err(Error::invalid_input( + format!("unsupported query {:?} for inverted index", query), + location!(), + )); + } - Ok(RowIdTreeMap::from_iter(row_ids)) + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + true } async fn load(store: Arc) -> Result> @@ -289,6 +319,7 @@ impl ScalarIndex for InvertedIndex { tokenizer_config, }; Ok(Arc::new(Self { + io_parallelism: store.io_parallelism(), params, tokenizer, tokens, @@ -314,24 +345,93 @@ impl ScalarIndex for InvertedIndex { } } +// at indexing, we use HashMap because we need it to be mutable, +// at searching, we use fst::Map because it's more efficient +#[derive(Debug, Clone)] +pub enum TokenMap { + HashMap(HashMap), + Fst(fst::Map>), +} + +impl Default for TokenMap { + fn default() -> Self { + Self::HashMap(HashMap::new()) + } +} + +impl DeepSizeOf for TokenMap { + fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize { + match self { + Self::HashMap(map) => map.deep_size_of_children(ctx), + Self::Fst(map) => map.as_fst().size(), + } + } +} + +impl TokenMap { + pub fn len(&self) -> usize { + match self { + Self::HashMap(map) => map.len(), + Self::Fst(map) => map.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + // TokenSet is a mapping from tokens to token ids -// it also records the frequency of each token #[derive(Debug, Clone, Default, DeepSizeOf)] pub struct TokenSet { - // token -> (token_id, frequency) - pub(crate) tokens: HashMap, + // token -> token_id + pub(crate) tokens: TokenMap, pub(crate) next_id: u32, total_length: usize, } impl TokenSet { + pub fn into_mut(self) -> Self { + let tokens = match self.tokens { + TokenMap::HashMap(map) => map, + TokenMap::Fst(map) => { + let mut new_map = HashMap::with_capacity(map.len()); + let mut stream = map.into_stream(); + while let Some((token, token_id)) = stream.next() { + new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32); + } + + new_map + } + }; + + Self { + tokens: TokenMap::HashMap(tokens), + next_id: self.next_id, + total_length: self.total_length, + } + } + + pub fn num_tokens(&self) -> usize { + self.tokens.len() + } + pub fn to_batch(self) -> Result { let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length); let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len()); - for (token, token_id) in self.tokens.into_iter().sorted_unstable() { - token_builder.append_value(token); - token_id_builder.append_value(token_id); + + if let TokenMap::HashMap(map) = self.tokens { + for (token, token_id) in map.into_iter().sorted_unstable() { + token_builder.append_value(&token); + token_id_builder.append_value(token_id); + } + } else { + return Err(Error::Index { + message: "tokens is not a HashMap".to_owned(), + location: location!(), + }); } + let token_col = token_builder.finish(); let token_id_col = token_id_builder.finish(); @@ -353,21 +453,29 @@ impl TokenSet { pub async fn load(reader: Arc) -> Result { let mut next_id = 0; let mut total_length = 0; - let mut tokens = HashMap::new(); + let mut tokens = fst::MapBuilder::memory(); let batch = reader.read_range(0..reader.num_rows(), None).await?; let token_col = batch[TOKEN_COL].as_string::(); let token_id_col = batch[TOKEN_ID_COL].as_primitive::(); for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) { - let token = token.unwrap(); + let token = token.ok_or(Error::Index { + message: "found null token in token set".to_owned(), + location: location!(), + })?; next_id = next_id.max(token_id + 1); total_length += token.len(); - tokens.insert(token.to_owned(), token_id); + tokens + .insert(token, token_id as u64) + .map_err(|e| Error::Index { + message: format!("failed to insert token {}: {}", token, e), + location: location!(), + })?; } Ok(Self { - tokens, + tokens: TokenMap::Fst(tokens.into_map()), next_id, total_length, }) @@ -376,7 +484,10 @@ impl TokenSet { pub fn add(&mut self, token: String) -> u32 { let next_id = self.next_id(); let len = token.len(); - let token_id = *self.tokens.entry(token).or_insert(next_id); + let token_id = match self.tokens { + TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id), + _ => unreachable!("tokens must be HashMap while indexing"), + }; // add token if it doesn't exist if token_id == next_id { @@ -388,7 +499,10 @@ impl TokenSet { } pub fn get(&self, token: &str) -> Option { - self.tokens.get(token).copied() + match self.tokens { + TokenMap::HashMap(ref map) => map.get(token).copied(), + TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32), + } } pub fn next_id(&self) -> u32 { @@ -493,18 +607,21 @@ impl InvertedListReader { Ok(batch) } - #[instrument(level = "debug", skip(self))] + #[instrument(level = "debug", skip(self, metrics))] pub(crate) async fn posting_list( &self, token_id: u32, is_phrase_query: bool, + metrics: &dyn MetricsCollector, ) -> Result { let mut posting = self .posting_cache .try_get_with(token_id, async move { + metrics.record_part_load(); + info!(target: TRACE_IO_EVENTS, type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id); let batch = self.posting_batch(token_id, false).await?; - let row_ids = batch[ROW_ID].as_primitive::().clone(); - let frequencies = batch[FREQUENCY_COL].as_primitive::().clone(); + let row_ids = batch[ROW_ID].as_primitive::(); + let frequencies = batch[FREQUENCY_COL].as_primitive::(); Result::Ok(PostingList::new( row_ids.values().clone(), frequencies.values().clone(), @@ -525,6 +642,34 @@ impl InvertedListReader { Ok(posting) } + async fn prewarm(&self) -> Result<()> { + let batch = self + .reader + .read_range(0..self.reader.num_rows(), Some(&[ROW_ID, FREQUENCY_COL])) + .await?; + for token_id in 0..self.offsets.len() { + let offset = self.offsets[token_id]; + let length = self.posting_len(token_id as u32); + let batch = batch.slice(offset, length); + let row_ids = batch[ROW_ID].as_primitive::(); + let frequencies = batch[FREQUENCY_COL].as_primitive::(); + self.posting_cache + .insert( + token_id as u32, + PostingList::new( + row_ids.values().clone(), + frequencies.values().clone(), + self.max_scores + .as_ref() + .map(|max_scores| max_scores[token_id]), + ), + ) + .await; + } + + Ok(()) + } + async fn read_positions(&self, token_id: u32) -> Result { self.position_cache.try_get_with(token_id, async move { let length = self.posting_len(token_id); @@ -533,10 +678,16 @@ impl InvertedListReader { let batch = self .reader .read_range(offset..offset + length, Some(&[POSITION_COL])) - .await?; - Result::Ok(batch - .column_by_name(POSITION_COL) - .ok_or(Error::Index { message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), location: location!() })? + .await.map_err(|e| { + match e { + Error::Schema { .. } => Error::Index { + message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), + location: location!(), + }, + e => e + } + })?; + Result::Ok(batch[POSITION_COL] .as_list::() .clone()) }).await.map_err(|e| Error::io(e.to_string(), location!())) @@ -1063,19 +1214,19 @@ pub fn flat_bm25_search( let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef; let batch = batch - .drop_column(doc_col)? - .try_with_column(SCORE_FIELD.clone(), score_col)?; + .try_with_column(SCORE_FIELD.clone(), score_col)? + .project_by_schema(&FTS_SCHEMA)?; // the scan node would probably scan some extra columns for prefilter, drop them here Ok(batch) } pub fn flat_bm25_search_stream( input: SendableRecordBatchStream, doc_col: String, - query: FullTextSearchQuery, + query: String, index: &InvertedIndex, ) -> SendableRecordBatchStream { let mut tokenizer = index.tokenizer.clone(); - let query_token_ids = collect_tokens(&query.query, &mut tokenizer, None) + let query_token_ids = collect_tokens(&query, &mut tokenizer, None) .into_iter() .dedup() .map(|token| { @@ -1090,7 +1241,7 @@ pub fn flat_bm25_search_stream( let stream = input.map(move |batch| { let batch = batch?; - flat_bm25_search( + let scored_batch = flat_bm25_search( batch, &doc_col, inverted_list.as_ref(), @@ -1099,30 +1250,22 @@ pub fn flat_bm25_search_stream( &mut tokenizer, avgdl, num_docs, - ) + )?; + + // filter out rows with score 0 + let score_col = scored_batch[SCORE_COL].as_primitive::(); + let mask = score_col + .iter() + .map(|score| score.is_some_and(|score| score > 0.0)) + .collect::>(); + let mask = BooleanArray::from(mask); + let batch = arrow::compute::filter_record_batch(&scored_batch, &mask)?; + Ok(batch) }); Box::pin(RecordBatchStreamAdapter::new(FTS_SCHEMA.clone(), stream)) as SendableRecordBatchStream } -pub fn collect_tokens( - text: &str, - tokenizer: &mut tantivy::tokenizer::TextAnalyzer, - inclusive: Option<&HashSet>, -) -> Vec { - let mut stream = tokenizer.token_stream(text); - let mut tokens = Vec::new(); - while let Some(token) = stream.next() { - if let Some(inclusive) = inclusive { - if !inclusive.contains(&token.text) { - continue; - } - } - tokens.push(token.text.to_owned()); - } - tokens -} - pub fn is_phrase_query(query: &str) -> bool { query.starts_with('\"') && query.ends_with('\"') } diff --git a/rust/lance-index/src/scalar/inverted/query.rs b/rust/lance-index/src/scalar/inverted/query.rs new file mode 100644 index 00000000000..e08d133cd61 --- /dev/null +++ b/rust/lance-index/src/scalar/inverted/query.rs @@ -0,0 +1,587 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::collections::HashSet; + +use lance_core::{Error, Result}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Serialize}; +use snafu::location; + +#[derive(Debug, Clone)] +pub struct FtsSearchParams { + pub limit: Option, + pub wand_factor: f32, +} + +impl FtsSearchParams { + pub fn new() -> Self { + Self { + limit: None, + wand_factor: 1.0, + } + } + + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + pub fn with_wand_factor(mut self, factor: f32) -> Self { + self.wand_factor = factor; + self + } +} + +impl Default for FtsSearchParams { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub enum Operator { + And, + Or, +} + +impl Default for Operator { + fn default() -> Self { + Self::Or + } +} + +impl TryFrom<&str> for Operator { + type Error = Error; + fn try_from(value: &str) -> Result { + match value.to_ascii_uppercase().as_str() { + "AND" => Ok(Self::And), + "OR" => Ok(Self::Or), + _ => Err(Error::invalid_input( + format!("Invalid operator: {}", value), + location!(), + )), + } + } +} + +pub trait FtsQueryNode { + fn columns(&self) -> HashSet; +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FtsQuery { + // leaf queries + Match(MatchQuery), + Phrase(PhraseQuery), + + // compound queries + Boost(BoostQuery), + MultiMatch(MultiMatchQuery), +} + +impl std::fmt::Display for FtsQuery { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Match(query) => write!(f, "Match({:?})", query), + Self::Phrase(query) => write!(f, "Phrase({:?})", query), + Self::Boost(query) => write!( + f, + "Boosting(positive={}, negative={}, negative_boost={})", + query.positive, query.negative, query.negative_boost + ), + Self::MultiMatch(query) => write!(f, "MultiMatch({:?})", query), + } + } +} + +impl FtsQueryNode for FtsQuery { + fn columns(&self) -> HashSet { + match self { + Self::Match(query) => query.columns(), + Self::Phrase(query) => query.columns(), + Self::Boost(query) => { + let mut columns = query.positive.columns(); + columns.extend(query.negative.columns()); + columns + } + Self::MultiMatch(query) => { + let mut columns = HashSet::new(); + for match_query in &query.match_queries { + columns.extend(match_query.columns()); + } + columns + } + } + } +} + +impl FtsQuery { + pub fn query(&self) -> String { + match self { + Self::Match(query) => query.terms.clone(), + Self::Phrase(query) => format!("\"{}\"", query.terms), // Phrase queries are quoted + Self::Boost(query) => query.positive.query(), + Self::MultiMatch(query) => query.match_queries[0].terms.clone(), + } + } + + pub fn is_missing_column(&self) -> bool { + match self { + Self::Match(query) => query.column.is_none(), + Self::Phrase(query) => query.column.is_none(), + Self::Boost(query) => { + query.positive.is_missing_column() || query.negative.is_missing_column() + } + Self::MultiMatch(query) => query.match_queries.iter().any(|q| q.column.is_none()), + } + } + + pub fn with_column(self, column: String) -> Self { + match self { + Self::Match(query) => Self::Match(query.with_column(Some(column))), + Self::Phrase(query) => Self::Phrase(query.with_column(Some(column))), + Self::Boost(query) => { + let positive = query.positive.with_column(column.clone()); + let negative = query.negative.with_column(column); + Self::Boost(BoostQuery { + positive: Box::new(positive), + negative: Box::new(negative), + negative_boost: query.negative_boost, + }) + } + Self::MultiMatch(query) => { + let match_queries = query + .match_queries + .into_iter() + .map(|q| q.with_column(Some(column.clone()))) + .collect(); + Self::MultiMatch(MultiMatchQuery { match_queries }) + } + } + } +} + +impl From for FtsQuery { + fn from(query: MatchQuery) -> Self { + Self::Match(query) + } +} + +impl From for FtsQuery { + fn from(query: PhraseQuery) -> Self { + Self::Phrase(query) + } +} + +impl From for FtsQuery { + fn from(query: BoostQuery) -> Self { + Self::Boost(query) + } +} + +impl From for FtsQuery { + fn from(query: MultiMatchQuery) -> Self { + Self::MultiMatch(query) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MatchQuery { + // The column to search in. + // If None, it will be determined at query time. + pub column: Option, + pub terms: String, + + // literal default is not supported so we set it by function + #[serde(default = "MatchQuery::default_boost")] + pub boost: f32, + + // The max edit distance for fuzzy matching. + // If Some(0), it will be exact match. + // If None, it will be determined automatically by the rules: + // - 0 for terms with length <= 2 + // - 1 for terms with length <= 5 + // - 2 for terms with length > 5 + pub fuzziness: Option, + + /// The maximum number of terms to expand for fuzzy matching. + /// Default to 50. + #[serde(default = "MatchQuery::default_max_expansions")] + pub max_expansions: usize, + + /// The operator to use for combining terms. + /// This can be either `And` or `Or`, it's 'Or' by default. + /// - `And`: All terms must match. + /// - `Or`: At least one term must match. + #[serde(default)] + pub operator: Operator, +} + +impl MatchQuery { + pub fn new(terms: String) -> Self { + Self { + column: None, + terms, + boost: 1.0, + fuzziness: Some(0), + max_expansions: 50, + operator: Operator::Or, + } + } + + fn default_boost() -> f32 { + 1.0 + } + + fn default_max_expansions() -> usize { + 50 + } + + pub fn with_column(mut self, column: Option) -> Self { + self.column = column; + self + } + + pub fn with_boost(mut self, boost: f32) -> Self { + self.boost = boost; + self + } + + pub fn with_fuzziness(mut self, fuzziness: Option) -> Self { + self.fuzziness = fuzziness; + self + } + + pub fn with_max_expansions(mut self, max_expansions: usize) -> Self { + self.max_expansions = max_expansions; + self + } + + pub fn with_operator(mut self, operator: Operator) -> Self { + self.operator = operator; + self + } + + pub fn auto_fuzziness(token: &str) -> u32 { + match token.len() { + 0..=2 => 0, + 3..=5 => 1, + _ => 2, + } + } +} + +impl FtsQueryNode for MatchQuery { + fn columns(&self) -> HashSet { + let mut columns = HashSet::new(); + if let Some(column) = &self.column { + columns.insert(column.clone()); + } + columns + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PhraseQuery { + // The column to search in. + // If None, it will be determined at query time. + pub column: Option, + pub terms: String, +} + +impl PhraseQuery { + pub fn new(terms: String) -> Self { + Self { + column: None, + terms, + } + } + + pub fn with_column(mut self, column: Option) -> Self { + self.column = column; + self + } +} + +impl FtsQueryNode for PhraseQuery { + fn columns(&self) -> HashSet { + let mut columns = HashSet::new(); + if let Some(column) = &self.column { + columns.insert(column.clone()); + } + columns + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct BoostQuery { + pub positive: Box, + pub negative: Box, + #[serde(default = "BoostQuery::default_negative_boost")] + pub negative_boost: f32, +} + +impl BoostQuery { + pub fn new(positive: FtsQuery, negative: FtsQuery, negative_boost: Option) -> Self { + Self { + positive: Box::new(positive), + negative: Box::new(negative), + negative_boost: negative_boost.unwrap_or(0.5), + } + } + + fn default_negative_boost() -> f32 { + 0.5 + } +} + +impl FtsQueryNode for BoostQuery { + fn columns(&self) -> HashSet { + let mut columns = self.positive.columns(); + columns.extend(self.negative.columns()); + columns + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct MultiMatchQuery { + // each query must be a match query with specified column + pub match_queries: Vec, +} + +impl Serialize for MultiMatchQuery { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(Some(3))?; + + let query = self.match_queries.first().ok_or(serde::ser::Error::custom( + "MultiMatchQuery must have at least one MatchQuery".to_string(), + ))?; + map.serialize_entry("query", &query.terms)?; + let columns = self + .match_queries + .iter() + .map(|q| q.column.as_ref().unwrap().clone()) + .collect::>(); + map.serialize_entry("columns", &columns)?; + let boosts = self + .match_queries + .iter() + .map(|q| q.boost) + .collect::>(); + map.serialize_entry("boost", &boosts)?; + map.end() + } +} + +impl<'de> Deserialize<'de> for MultiMatchQuery { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct MultiMatchQueryData { + query: String, + columns: Vec, + boost: Option>, + } + + let data = MultiMatchQueryData::deserialize(deserializer)?; + let boosts = data.boost.unwrap_or(vec![1.0; data.columns.len()]); + + Self::try_new(data.query, data.columns) + .map_err(serde::de::Error::custom)? + .try_with_boosts(boosts) + .map_err(serde::de::Error::custom) + } +} + +impl MultiMatchQuery { + pub fn try_new(query: String, columns: Vec) -> Result { + if columns.is_empty() { + return Err(Error::invalid_input( + "Cannot create MultiMatchQuery with no columns".to_string(), + location!(), + )); + } + + let match_queries = columns + .into_iter() + .map(|column| MatchQuery::new(query.clone()).with_column(Some(column))) + .collect(); + Ok(Self { match_queries }) + } + + pub fn try_with_boosts(mut self, boosts: Vec) -> Result { + if boosts.len() != self.match_queries.len() { + return Err(Error::invalid_input( + "The number of boosts must match the number of queries".to_string(), + location!(), + )); + } + + for (query, boost) in self.match_queries.iter_mut().zip(boosts) { + query.boost = boost; + } + Ok(self) + } + + pub fn with_operator(mut self, operator: Operator) -> Self { + for query in &mut self.match_queries { + query.operator = operator; + } + self + } +} + +impl FtsQueryNode for MultiMatchQuery { + fn columns(&self) -> HashSet { + let mut columns = HashSet::with_capacity(self.match_queries.len()); + for query in &self.match_queries { + columns.extend(query.columns()); + } + columns + } +} + +pub fn collect_tokens( + text: &str, + tokenizer: &mut tantivy::tokenizer::TextAnalyzer, + inclusive: Option<&HashSet>, +) -> Vec { + let mut stream = tokenizer.token_stream(text); + let mut tokens = Vec::new(); + while let Some(token) = stream.next() { + if let Some(inclusive) = inclusive { + if !inclusive.contains(&token.text) { + continue; + } + } + tokens.push(token.text.to_owned()); + } + tokens +} + +pub fn fill_fts_query_column( + query: &FtsQuery, + columns: &[String], + replace: bool, +) -> Result { + if !query.is_missing_column() && !replace { + return Ok(query.clone()); + } + match query { + FtsQuery::Match(match_query) => { + match columns.len() { + 0 => { + Err(Error::invalid_input( + "Cannot perform full text search unless an INVERTED index has been created on at least one column".to_string(), + location!(), + )) + } + 1 => { + let column = columns[0].clone(); + let query = match_query.clone().with_column(Some(column)); + Ok(FtsQuery::Match(query)) + } + _ => { + // if there are multiple columns, we need to create a MultiMatch query + let multi_match_query = + MultiMatchQuery::try_new(match_query.terms.clone(), columns.to_vec())?; + Ok(FtsQuery::MultiMatch(multi_match_query)) + } + } + } + FtsQuery::Phrase(phrase_query) => { + match columns.len() { + 0 => { + Err(Error::invalid_input( + "Cannot perform full text search unless an INVERTED index has been created on at least one column".to_string(), + location!(), + )) + } + 1 => { + let column = columns[0].clone(); + let query = phrase_query.clone().with_column(Some(column)); + Ok(FtsQuery::Phrase(query)) + } + _ => { + Err(Error::invalid_input( + "the column must be specified in the query".to_string(), + location!(), + )) + } + } + } + FtsQuery::Boost(boost_query) => { + let positive = fill_fts_query_column(&boost_query.positive, columns, replace)?; + let negative = fill_fts_query_column(&boost_query.negative, columns, replace)?; + Ok(FtsQuery::Boost(BoostQuery { + positive: Box::new(positive), + negative: Box::new(negative), + negative_boost: boost_query.negative_boost, + })) + } + FtsQuery::MultiMatch(multi_match_query) => { + let match_queries = multi_match_query + .match_queries + .iter() + .map(|query| fill_fts_query_column(&FtsQuery::Match(query.clone()), columns, replace)) + .map(|result| { + result.map(|query| { + if let FtsQuery::Match(match_query) = query { + match_query + } else { + unreachable!("Expected MatchQuery") + } + }) + }) + .collect::>>()?; + Ok(FtsQuery::MultiMatch(MultiMatchQuery { match_queries })) + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_match_query_serde() { + use super::*; + use serde_json::json; + + let query = MatchQuery::new("hello world".to_string()) + .with_column(Some("text".to_string())) + .with_boost(2.0) + .with_fuzziness(Some(1)) + .with_max_expansions(10) + .with_operator(Operator::And); + + let serialized = serde_json::to_value(&query).unwrap(); + let expected = json!({ + "column": "text", + "terms": "hello world", + "boost": 2.0, + "fuzziness": 1, + "max_expansions": 10, + "operator": "And" + }); + assert_eq!(serialized, expected); + + let expected = json!({ + "column": "text", + "terms": "hello world", + "fuzziness": 0, + }); + let query = serde_json::from_str::(&expected.to_string()).unwrap(); + assert_eq!(query.column, Some("text".to_owned())); + assert_eq!(query.terms, "hello world"); + assert_eq!(query.boost, 1.0); + assert_eq!(query.fuzziness, Some(0)); + assert_eq!(query.max_expansions, 50); + assert_eq!(query.operator, Operator::Or); + } +} diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs index 440def7a5a1..6a3a8323e5a 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer.rs @@ -1,9 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::{env, path::PathBuf}; + use lance_core::{Error, Result}; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; + +#[cfg(feature = "tokenizer-lindera")] +mod lindera; + +#[cfg(feature = "tokenizer-jieba")] +mod jieba; /// Tokenizer configs #[derive(Debug, Clone, Serialize, Deserialize)] @@ -12,6 +20,8 @@ pub struct TokenizerConfig { /// - `simple`: splits tokens on whitespace and punctuation /// - `whitespace`: splits tokens on whitespace /// - `raw`: no tokenization + /// - `lindera/*`: Lindera tokenizer + /// - `jieba/*`: Jieba tokenizer /// /// `simple` is recommended for most cases and the default value base_tokenizer: String, @@ -141,9 +151,70 @@ fn build_base_tokenizer_builder(name: &str) -> Result { + let Some(home) = language_model_home() else { + return Err(Error::invalid_input( + format!("unknown base tokenizer {}", name), + location!(), + )); + }; + lindera::LinderaBuilder::load(&home.join(s))?.build() + } + #[cfg(feature = "tokenizer-jieba")] + s if s.starts_with("jieba/") || s == "jieba" => { + let s = if s == "jieba" { "jieba/default" } else { s }; + let Some(home) = language_model_home() else { + return Err(Error::invalid_input( + format!("unknown base tokenizer {}", name), + location!(), + )); + }; + jieba::JiebaBuilder::load(&home.join(s))?.build() + } _ => Err(Error::invalid_input( format!("unknown base tokenizer {}", name), location!(), )), } } + +pub const LANCE_LANGUAGE_MODEL_HOME_ENV_KEY: &str = "LANCE_LANGUAGE_MODEL_HOME"; + +pub const LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY: &str = "lance/language_models"; + +pub const LANCE_LANGUAGE_MODEL_CONFIG_FILE: &str = "config.json"; + +pub fn language_model_home() -> Option { + match env::var(LANCE_LANGUAGE_MODEL_HOME_ENV_KEY) { + Ok(p) => Some(PathBuf::from(p)), + Err(_) => dirs::data_local_dir().map(|p| p.join(LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY)), + } +} + +#[cfg(feature = "tokenizer-common")] +trait TokenizerBuilder: Sized { + type Config: serde::de::DeserializeOwned + Default; + fn load(p: &std::path::Path) -> Result { + if !p.is_dir() { + return Err(Error::io( + format!("{} is not a valid directory", p.display()), + location!(), + )); + } + use std::{fs::File, io::BufReader}; + let config_path = p.join(LANCE_LANGUAGE_MODEL_CONFIG_FILE); + let config = if config_path.exists() { + let file = File::open(config_path)?; + let reader = BufReader::new(file); + serde_json::from_reader::, Self::Config>(reader)? + } else { + Self::Config::default() + }; + Self::new(config, p) + } + + fn new(config: Self::Config, root: &std::path::Path) -> Result; + + fn build(&self) -> Result; +} diff --git a/rust/lance-index/src/scalar/inverted/tokenizer/jieba.rs b/rust/lance-index/src/scalar/inverted/tokenizer/jieba.rs new file mode 100644 index 00000000000..9d6152a060e --- /dev/null +++ b/rust/lance-index/src/scalar/inverted/tokenizer/jieba.rs @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::path::{Path, PathBuf}; + +use super::TokenizerBuilder; +use lance_core::{Error, Result}; +use serde::{Deserialize, Serialize}; +use snafu::location; + +#[derive(Serialize, Deserialize, Default)] +pub struct JiebaConfig { + main: Option, + users: Option>, +} + +pub struct JiebaBuilder { + root: PathBuf, + config: JiebaConfig, +} + +impl JiebaBuilder { + fn main_dict_path(&self) -> PathBuf { + if let Some(p) = &self.config.main { + return self.root.join(p); + } + self.root.join("dict.txt") + } + + fn user_dict_paths(&self) -> Vec { + let Some(users) = &self.config.users else { + return vec![]; + }; + users.iter().map(|p| self.root.join(p)).collect() + } +} + +impl TokenizerBuilder for JiebaBuilder { + type Config = JiebaConfig; + + fn new(config: Self::Config, root: &Path) -> Result { + Ok(Self { + config, + root: root.to_path_buf(), + }) + } + + fn build(&self) -> Result { + let main_dict_path = &self.main_dict_path(); + let file = std::fs::File::open(main_dict_path)?; + let mut f = std::io::BufReader::new(file); + let mut jieba = jieba_rs::Jieba::with_dict(&mut f).map_err(|e| { + Error::io( + format!( + "load jieba tokenizer dictionary {}, error: {}", + main_dict_path.display(), + e + ), + location!(), + ) + })?; + for user_dict_path in &self.user_dict_paths() { + let file = std::fs::File::open(user_dict_path)?; + let mut f = std::io::BufReader::new(file); + jieba.load_dict(&mut f).map_err(|e| { + Error::io( + format!( + "load jieba tokenizer user dictionary {}, error: {}", + user_dict_path.display(), + e + ), + location!(), + ) + })? + } + let tokenizer = JiebaTokenizer { jieba }; + Ok(tantivy::tokenizer::TextAnalyzer::builder(tokenizer).dynamic()) + } +} + +#[derive(Clone)] +struct JiebaTokenizer { + jieba: jieba_rs::Jieba, +} + +struct JiebaTokenStream { + tokens: Vec, + index: usize, +} + +impl tantivy::tokenizer::TokenStream for JiebaTokenStream { + fn advance(&mut self) -> bool { + if self.index < self.tokens.len() { + self.index += 1; + true + } else { + false + } + } + + fn token(&self) -> &tantivy::tokenizer::Token { + &self.tokens[self.index - 1] + } + + fn token_mut(&mut self) -> &mut tantivy::tokenizer::Token { + &mut self.tokens[self.index - 1] + } +} + +#[cfg(feature = "tokenizer-jieba")] +impl tantivy::tokenizer::Tokenizer for JiebaTokenizer { + type TokenStream<'a> = JiebaTokenStream; + + fn token_stream(&mut self, text: &str) -> JiebaTokenStream { + let mut indices = text.char_indices().collect::>(); + indices.push((text.len(), '\0')); + let orig_tokens = self + .jieba + .tokenize(text, jieba_rs::TokenizeMode::Search, true); + let mut tokens = Vec::new(); + for token in orig_tokens { + tokens.push(tantivy::tokenizer::Token { + offset_from: indices[token.start].0, + offset_to: indices[token.end].0, + position: token.start, + text: String::from(&text[(indices[token.start].0)..(indices[token.end].0)]), + position_length: token.end - token.start, + }); + } + JiebaTokenStream { tokens, index: 0 } + } +} diff --git a/rust/lance-index/src/scalar/inverted/tokenizer/lindera.rs b/rust/lance-index/src/scalar/inverted/tokenizer/lindera.rs new file mode 100644 index 00000000000..ad88027753a --- /dev/null +++ b/rust/lance-index/src/scalar/inverted/tokenizer/lindera.rs @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::path::{Path, PathBuf}; + +use super::TokenizerBuilder; +use lance_core::{Error, Result}; +use lindera::{ + dictionary::{ + load_dictionary_from_path, load_user_dictionary_from_config, UserDictionaryConfig, + }, + mode::Mode, + segmenter::Segmenter, +}; +use lindera_tantivy::tokenizer::LinderaTokenizer; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use snafu::location; + +#[derive(Serialize, Deserialize, Default)] +pub struct LinderaConfig { + main: Option, + user: Option, + user_kind: Option, +} + +pub struct LinderaBuilder { + root: PathBuf, + config: LinderaConfig, +} + +impl LinderaBuilder { + fn main_dict_path(&self) -> PathBuf { + if let Some(p) = &self.config.main { + return self.root.join(p); + } + self.root.join("main") + } + + fn user_dict_config(&self) -> Result> { + let Some(user_dict_path) = &self.config.user else { + return Ok(None); + }; + let mut conf = Map::::new(); + let user_path = self.root.join(user_dict_path); + let Some(p) = user_path.to_str() else { + return Err(Error::io( + format!( + "invalid lindera tokenizer user dictionary path: {}", + user_path.display() + ), + location!(), + )); + }; + conf.insert(String::from("path"), Value::String(String::from(p))); + if let Some(kind) = &self.config.user_kind { + conf.insert(String::from("kind"), Value::String(kind.clone())); + } + Ok(Some(Value::Object(conf))) + } +} + +impl TokenizerBuilder for LinderaBuilder { + type Config = LinderaConfig; + + fn new(config: Self::Config, root: &Path) -> Result { + Ok(Self { + config, + root: root.to_path_buf(), + }) + } + + fn build(&self) -> Result { + let main_path = self.main_dict_path(); + let dictionary = load_dictionary_from_path(main_path.as_path()).map_err(|e| { + Error::io( + format!( + "load lindera tokenizer main dictionary from {}, error: {}", + main_path.display(), + e + ), + location!(), + ) + })?; + let user_dictionary = match self.user_dict_config()? { + Some(conf) => { + let user_dictionary = load_user_dictionary_from_config(&conf).map_err(|e| { + Error::io( + format!("load lindera tokenizer user dictionary, conf:{conf}, err: {e}"), + location!(), + ) + })?; + Some(user_dictionary) + } + None => None, + }; + let mode = Mode::Normal; + let segmenter = Segmenter::new(mode, dictionary, user_dictionary); + let tokenizer = LinderaTokenizer::from_segmenter(segmenter); + Ok(tantivy::tokenizer::TextAnalyzer::builder(tokenizer).dynamic()) + } +} diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index d5e16316717..9121d54e42d 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -7,13 +7,13 @@ use std::sync::Arc; use arrow::datatypes::Int32Type; use arrow_array::PrimitiveArray; -use itertools::Itertools; use lance_core::utils::mask::RowIdMask; use lance_core::Result; use tracing::instrument; use super::builder::OrderedDoc; use super::index::{idf, K1}; +use super::query::Operator; use super::{DocInfo, PostingList}; #[derive(Clone)] @@ -43,7 +43,10 @@ impl PartialOrd for PostingIterator { impl Ord for PostingIterator { fn cmp(&self, other: &Self) -> std::cmp::Ordering { match (self.doc(), other.doc()) { - (Some(doc1), Some(doc2)) => doc1.cmp(&doc2), + (Some(doc1), Some(doc2)) => doc1.cmp(&doc2).then( + self.approximate_upper_bound + .total_cmp(&other.approximate_upper_bound), + ), (Some(_), None) => std::cmp::Ordering::Less, (None, Some(_)) => std::cmp::Ordering::Greater, (None, None) => std::cmp::Ordering::Equal, @@ -118,17 +121,32 @@ pub struct Wand { cur_doc: Option, num_docs: usize, postings: Vec, - candidates: BinaryHeap>, } impl Wand { - pub(crate) fn new(num_docs: usize, postings: impl Iterator) -> Self { + pub(crate) fn new( + num_docs: usize, + operator: Operator, + postings: impl Iterator, + ) -> Self { + let mut posting_lists = postings.collect::>(); + posting_lists.sort_unstable(); + let threshold = match operator { + Operator::Or => 0.0, + Operator::And => posting_lists + .iter() + .map(|posting| posting.approximate_upper_bound()) + .sum::(), + }; + Self { - threshold: 0.0, + threshold, cur_doc: None, num_docs, - postings: postings.filter(|posting| posting.doc().is_some()).collect(), - candidates: BinaryHeap::new(), + postings: posting_lists + .into_iter() + .filter(|posting| posting.doc().is_some()) + .collect(), } } @@ -139,46 +157,32 @@ impl Wand { limit: usize, factor: f32, scorer: impl Fn(u64, f32) -> f32, - ) -> Result> { + ) -> Result<(Vec, Vec)> { if limit == 0 { - return Ok(vec![]); + return Ok((vec![], vec![])); } - let num_query_tokens = self.postings.len(); + let mut candidates = BinaryHeap::new(); while let Some(doc) = self.next().await? { - if is_phrase_query { - // all the tokens should be in the same document cause it's a phrase query - if self.postings.len() != num_query_tokens { - break; - } - if let Some(last) = self.postings.last() { - if last.doc().unwrap().row_id != doc { - continue; - } - } - - if !self.check_positions() { - continue; - } + if is_phrase_query && !self.check_positions() { + continue; } let score = self.score(doc, &scorer); - if self.candidates.len() < limit { - self.candidates.push(Reverse(OrderedDoc::new(doc, score))); - } else if score > self.threshold { - self.candidates.pop(); - self.candidates.push(Reverse(OrderedDoc::new(doc, score))); - self.threshold = self.candidates.peek().unwrap().0.score.0 * factor; + if candidates.len() < limit { + candidates.push(Reverse(OrderedDoc::new(doc, score))); + } else if score > candidates.peek().unwrap().0.score.0 { + candidates.pop(); + candidates.push(Reverse(OrderedDoc::new(doc, score))); + self.threshold = candidates.peek().unwrap().0.score.0 * factor; } } - Ok(self - .candidates - .iter() - .map(|doc| (doc.0.row_id, doc.0.score)) - .sorted_unstable() - .map(|(row_id, score)| (row_id, score.0)) - .collect()) + Ok(candidates + .into_sorted_vec() + .into_iter() + .map(|Reverse(doc)| (doc.row_id, doc.score.0)) + .unzip()) } // calculate the score of the document @@ -201,7 +205,6 @@ impl Wand { // find the next doc candidate #[instrument(level = "debug", name = "wand_next", skip_all)] async fn next(&mut self) -> Result> { - self.postings.sort_unstable(); while let Some(pivot_posting) = self.find_pivot_term() { let doc = pivot_posting .doc() @@ -211,7 +214,7 @@ impl Wand { if self.cur_doc.is_some() && doc.row_id <= cur_doc { self.move_term(cur_doc + 1); } else if self.postings[0].doc().unwrap().row_id == doc.row_id { - // all the posting iterators have reached this doc id, + // all the posting iterators preceding pivot have reached this doc id, // so that means the sum of upper bound of all terms is not less than the threshold, // this document is a candidate self.cur_doc = Some(doc.row_id); @@ -264,6 +267,8 @@ impl Wand { if doc.row_id >= least_id { break; } + // a shorter posting list means this term is rare and more likely to skip more documents, + // so we prefer the term with a shorter posting list. if posting.list.len() < least_length { least_length = posting.list.len(); pick_index = i; diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index 0d487e59361..a453a1a08db 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -14,7 +14,7 @@ use deepsize::DeepSizeOf; use futures::{stream::BoxStream, StreamExt, TryStream, TryStreamExt}; use lance_core::{utils::mask::RowIdTreeMap, Error, Result}; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use crate::{Index, IndexType}; @@ -23,10 +23,27 @@ use super::{bitmap::train_bitmap_index, SargableQuery}; use super::{ bitmap::BitmapIndex, btree::TrainingSource, AnyQuery, IndexStore, LabelListQuery, ScalarIndex, }; +use super::{MetricsCollector, SearchResult}; pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance"; -trait LabelListSubIndex: ScalarIndex + DeepSizeOf {} +#[async_trait] +trait LabelListSubIndex: ScalarIndex + DeepSizeOf { + async fn search_exact( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { + let result = self.search(query, metrics).await?; + match result { + SearchResult::Exact(row_ids) => Ok(row_ids), + _ => Err(Error::Internal { + message: "Label list sub-index should return exact results".to_string(), + location: location!(), + }), + } + } +} impl LabelListSubIndex for T {} @@ -61,6 +78,10 @@ impl Index for LabelListIndex { }) } + async fn prewarm(&self) -> Result<()> { + self.values_index.prewarm().await + } + fn index_type(&self) -> IndexType { IndexType::LabelList } @@ -78,11 +99,12 @@ impl LabelListIndex { fn search_values<'a>( &'a self, values: &'a Vec, - ) -> BoxStream> { + metrics: &'a dyn MetricsCollector, + ) -> BoxStream<'a, Result> { futures::stream::iter(values) .then(move |value| { let value_query = SargableQuery::Equals(value.clone()); - async move { self.values_index.search(&value_query).await } + async move { self.values_index.search_exact(&value_query, metrics).await } }) .boxed() } @@ -120,21 +142,30 @@ impl LabelListIndex { #[async_trait] impl ScalarIndex for LabelListIndex { - #[instrument(skip(self), level = "debug")] - async fn search(&self, query: &dyn AnyQuery) -> Result { + #[instrument(skip_all, level = "debug")] + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); - match query { + let row_ids = match query { LabelListQuery::HasAllLabels(labels) => { - let values_results = self.search_values(labels); + let values_results = self.search_values(labels, metrics); self.set_intersection(values_results, labels.len() == 1) .await } LabelListQuery::HasAnyLabel(labels) => { - let values_results = self.search_values(labels); + let values_results = self.search_values(labels, metrics); self.set_union(values_results, labels.len() == 1).await } - } + }?; + Ok(SearchResult::Exact(row_ids)) + } + + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + true } async fn load(store: Arc) -> Result> { @@ -158,7 +189,9 @@ impl ScalarIndex for LabelListIndex { new_data: SendableRecordBatchStream, dest_store: &dyn IndexStore, ) -> Result<()> { - self.values_index.update(new_data, dest_store).await + self.values_index + .update(unnest_chunks(new_data)?, dest_store) + .await } } diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs index 75639db33e9..39c38bbadd2 100644 --- a/rust/lance-index/src/scalar/lance_format.rs +++ b/rust/lance-index/src/scalar/lance_format.rs @@ -16,7 +16,6 @@ use lance_core::{cache::FileMetadataCache, Error, Result}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::v2; use lance_file::v2::reader::FileReaderOptions; -use lance_file::writer::FileWriterOptions; use lance_file::{ reader::FileReader, writer::{FileWriter, ManifestProvider}, @@ -24,7 +23,6 @@ use lance_file::{ use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; use lance_io::{object_store::ObjectStore, ReadBatchParams}; use lance_table::format::SelfDescribingFileReader; -use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; use super::{IndexReader, IndexStore, IndexWriter}; @@ -40,7 +38,6 @@ pub struct LanceIndexStore { index_dir: Path, metadata_cache: FileMetadataCache, scheduler: Arc, - use_legacy_format: bool, } impl DeepSizeOf for LanceIndexStore { @@ -54,11 +51,10 @@ impl DeepSizeOf for LanceIndexStore { impl LanceIndexStore { /// Create a new index store at the given directory pub fn new( - object_store: ObjectStore, + object_store: Arc, index_dir: Path, metadata_cache: FileMetadataCache, ) -> Self { - let object_store = Arc::new(object_store); let scheduler = ScanScheduler::new( object_store.clone(), SchedulerConfig::max_bandwidth(&object_store), @@ -68,14 +64,8 @@ impl LanceIndexStore { index_dir, metadata_cache, scheduler, - use_legacy_format: false, } } - - pub fn with_legacy_format(mut self, use_legacy_format: bool) -> Self { - self.use_legacy_format = use_legacy_format; - self - } } #[async_trait] @@ -119,7 +109,7 @@ impl IndexWriter for v2::writer::FileWriter { #[async_trait] impl IndexReader for FileReader { - async fn read_record_batch(&self, offset: u32) -> Result { + async fn read_record_batch(&self, offset: u64, _batch_size: u64) -> Result { self.read_batch(offset as i32, ReadBatchParams::RangeFull, self.schema()) .await } @@ -136,7 +126,7 @@ impl IndexReader for FileReader { self.read_range(range, &projection).await } - async fn num_batches(&self) -> u32 { + async fn num_batches(&self, _batch_size: u64) -> u32 { self.num_batches() as u32 } @@ -151,8 +141,11 @@ impl IndexReader for FileReader { #[async_trait] impl IndexReader for v2::reader::FileReader { - async fn read_record_batch(&self, _offset: u32) -> Result { - unimplemented!("v2 format has no concept of row groups") + async fn read_record_batch(&self, offset: u64, batch_size: u64) -> Result { + let start = offset * batch_size; + let end = start + batch_size; + let end = end.min(self.num_rows()); + self.read_range(start as usize..end as usize, None).await } async fn read_range( @@ -189,8 +182,8 @@ impl IndexReader for v2::reader::FileReader { // V2 format has removed the row group concept, // so here we assume each batch is with 4096 rows. - async fn num_batches(&self) -> u32 { - unimplemented!("v2 format has no concept of row groups") + async fn num_batches(&self, batch_size: u64) -> u32 { + Self::num_rows(self).div_ceil(batch_size) as u32 } fn num_rows(&self) -> usize { @@ -219,24 +212,13 @@ impl IndexStore for LanceIndexStore { ) -> Result> { let path = self.index_dir.child(name); let schema = schema.as_ref().try_into()?; - if self.use_legacy_format { - let writer = FileWriter::::try_new( - &self.object_store, - &path, - schema, - &FileWriterOptions::default(), - ) - .await?; - Ok(Box::new(writer)) - } else { - let writer = self.object_store.create(&path).await?; - let writer = v2::writer::FileWriter::try_new( - writer, - schema, - v2::writer::FileWriterOptions::default(), - )?; - Ok(Box::new(writer)) - } + let writer = self.object_store.create(&path).await?; + let writer = v2::writer::FileWriter::try_new( + writer, + schema, + v2::writer::FileWriterOptions::default(), + )?; + Ok(Box::new(writer)) } async fn open_index_file(&self, name: &str) -> Result> { @@ -296,16 +278,29 @@ impl IndexStore for LanceIndexStore { Ok(()) } } + + async fn rename_index_file(&self, name: &str, new_name: &str) -> Result<()> { + let path = self.index_dir.child(name); + let new_path = self.index_dir.child(new_name); + self.object_store.copy(&path, &new_path).await?; + self.object_store.delete(&path).await + } + + async fn delete_index_file(&self, name: &str) -> Result<()> { + let path = self.index_dir.child(name); + self.object_store.delete(&path).await + } } #[cfg(test)] -mod tests { +pub mod tests { use std::{collections::HashMap, ops::Bound, path::Path}; + use crate::metrics::NoOpMetricsCollector; use crate::scalar::{ bitmap::{train_bitmap_index, BitmapIndex}, - btree::{train_btree_index, BTreeIndex, TrainingSource}, + btree::{train_btree_index, BTreeIndex, TrainingSource, DEFAULT_BTREE_BATCH_SIZE}, flat::FlatIndexMetadata, label_list::{train_label_list_index, LabelListIndex}, LabelListQuery, SargableQuery, ScalarIndex, @@ -315,7 +310,7 @@ mod tests { use arrow::{buffer::ScalarBuffer, datatypes::UInt8Type}; use arrow_array::{ cast::AsArray, - types::{Float32Type, Int32Type, UInt64Type}, + types::{Int32Type, UInt64Type}, RecordBatchIterator, RecordBatchReader, StringArray, UInt64Array, }; use arrow_schema::Schema as ArrowSchema; @@ -323,6 +318,7 @@ mod tests { use arrow_select::take::TakeOptions; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::ScalarValue; + use futures::FutureExt; use lance_core::{cache::CapacityMode, utils::mask::RowIdTreeMap}; use lance_datagen::{array, gen, ArrayGeneratorExt, BatchCount, ByteCount, RowCount}; use tempfile::{tempdir, TempDir}; @@ -330,31 +326,32 @@ mod tests { fn test_store(tempdir: &TempDir) -> Arc { let test_path: &Path = tempdir.path(); let (object_store, test_path) = - ObjectStore::from_path(test_path.as_os_str().to_str().unwrap()).unwrap(); + ObjectStore::from_uri(test_path.as_os_str().to_str().unwrap()) + .now_or_never() + .unwrap() + .unwrap(); let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes); Arc::new(LanceIndexStore::new(object_store, test_path, cache)) } - fn legacy_test_store(tempdir: &TempDir) -> Arc { - let test_path: &Path = tempdir.path(); - let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes); - let (object_store, test_path) = - ObjectStore::from_path(test_path.as_os_str().to_str().unwrap()).unwrap(); - Arc::new(LanceIndexStore::new(object_store, test_path, cache).with_legacy_format(true)) - } - - struct MockTrainingSource { + pub struct MockTrainingSource { data: SendableRecordBatchStream, } impl MockTrainingSource { - async fn new(data: impl RecordBatchReader + Send + 'static) -> Self { + pub async fn new(data: impl RecordBatchReader + Send + 'static) -> Self { Self { data: lance_datafusion::utils::reader_to_stream(Box::new(data)), } } } + impl From for MockTrainingSource { + fn from(data: SendableRecordBatchStream) -> Self { + Self { data } + } + } + #[async_trait] impl TrainingSource for MockTrainingSource { async fn scan_ordered_chunks( @@ -376,64 +373,88 @@ mod tests { index_store: &Arc, data: impl RecordBatchReader + Send + Sync + 'static, value_type: DataType, + custom_batch_size: Option, ) { let sub_index_trainer = FlatIndexMetadata::new(value_type); let data = Box::new(MockTrainingSource::new(data).await); - train_btree_index(data, &sub_index_trainer, index_store.as_ref()) - .await - .unwrap(); + let batch_size = custom_batch_size.unwrap_or(DEFAULT_BTREE_BATCH_SIZE); + train_btree_index( + data, + &sub_index_trainer, + index_store.as_ref(), + batch_size as u32, + ) + .await + .unwrap(); } #[tokio::test] async fn test_basic_btree() { let tempdir = tempdir().unwrap(); - let index_store = legacy_test_store(&tempdir); + let index_store = test_store(&tempdir); let data = gen() .col("values", array::step::()) .col("row_ids", array::step::()) .into_reader_rows(RowCount::from(4096), BatchCount::from(100)); - train_index(&index_store, data, DataType::Int32).await; + train_index(&index_store, data, DataType::Int32, None).await; let index = BTreeIndex::load(index_store).await.unwrap(); - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(10000)))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(10000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); - let row_ids = index - .search(&SargableQuery::Range( - Bound::Unbounded, - Bound::Excluded(ScalarValue::Int32(Some(-100))), - )) + let result = index + .search( + &SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(-100))), + ), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); + assert_eq!(Some(0), row_ids.len()); - let row_ids = index - .search(&SargableQuery::Range( - Bound::Unbounded, - Bound::Excluded(ScalarValue::Int32(Some(100))), - )) + let result = index + .search( + &SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(100))), + ), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); + assert_eq!(Some(100), row_ids.len()); } #[tokio::test] async fn test_btree_update() { let index_dir = tempdir().unwrap(); - let index_store = legacy_test_store(&index_dir); + let index_store = test_store(&index_dir); let data = gen() .col("values", array::step::()) .col("row_ids", array::step::()) .into_reader_rows(RowCount::from(4096), BatchCount::from(100)); - train_index(&index_store, data, DataType::Int32).await; + train_index(&index_store, data, DataType::Int32, None).await; let index = BTreeIndex::load(index_store).await.unwrap(); let data = gen() @@ -442,7 +463,7 @@ mod tests { .into_reader_rows(RowCount::from(4096), BatchCount::from(100)); let updated_index_dir = tempdir().unwrap(); - let updated_index_store = legacy_test_store(&updated_index_dir); + let updated_index_store = test_store(&updated_index_dir); index .update( lance_datafusion::utils::reader_to_stream(Box::new(data)), @@ -452,33 +473,46 @@ mod tests { .unwrap(); let updated_index = BTreeIndex::load(updated_index_store).await.unwrap(); - let row_ids = updated_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(10000)))) + let result = updated_index + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(10000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); + assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); - let row_ids = updated_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(500_000)))) + let result = updated_index + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(500_000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); + assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(500_000)); } async fn check(index: &BTreeIndex, query: SargableQuery, expected: &[u64]) { - let results = index.search(&query).await.unwrap(); + let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert!(results.is_exact()); let expected_arr = RowIdTreeMap::from_iter(expected); - assert_eq!(results, expected_arr); + assert_eq!(results.row_ids(), &expected_arr); } #[tokio::test] async fn test_btree_with_gaps() { let tempdir = tempdir().unwrap(); - let index_store = legacy_test_store(&tempdir); + let index_store = test_store(&tempdir); let batch_one = gen() .col("values", array::cycle::(vec![0, 1, 4, 5])) .col("row_ids", array::cycle::(vec![0, 1, 2, 3])) @@ -507,7 +541,7 @@ mod tests { Field::new("row_ids", DataType::UInt64, false), ])); let data = RecordBatchIterator::new(batches, schema); - train_index(&index_store, data, DataType::Int32).await; + train_index(&index_store, data, DataType::Int32, Some(4)).await; let index = BTreeIndex::load(index_store).await.unwrap(); // The above should create four pages @@ -698,12 +732,13 @@ mod tests { DataType::Date32, DataType::Time64(TimeUnit::Nanosecond), DataType::Time32(TimeUnit::Second), + DataType::FixedSizeBinary(16), // Not supported today, error from datafusion: // Min/max accumulator not implemented for Duration(Nanosecond) // DataType::Duration(TimeUnit::Nanosecond), ] { let tempdir = tempdir().unwrap(); - let index_store = legacy_test_store(&tempdir); + let index_store = test_store(&tempdir); let data: RecordBatch = gen() .col("values", array::rand_type(data_type)) .col("row_ids", array::step::()) @@ -742,14 +777,17 @@ mod tests { data.schema().clone(), ); - train_index(&index_store, training_data, data_type.clone()).await; + train_index(&index_store, training_data, data_type.clone(), None).await; let index = BTreeIndex::load(index_store).await.unwrap(); - let row_ids = index - .search(&SargableQuery::Equals(sample_value)) + let result = index + .search(&SargableQuery::Equals(sample_value), &NoOpMetricsCollector) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); + // The random data may have had duplicates so there might be more than 1 result // but even for boolean we shouldn't match the entire thing assert!(!row_ids.is_empty()); @@ -758,36 +796,10 @@ mod tests { } } - #[tokio::test] - async fn btree_reject_nan() { - let tempdir = tempdir().unwrap(); - let index_store = legacy_test_store(&tempdir); - let batch = gen() - .col("values", array::cycle::(vec![0.0, f32::NAN])) - .col("row_ids", array::cycle::(vec![0, 1])) - .into_batch_rows(RowCount::from(2)); - let batches = vec![batch]; - let schema = Arc::new(Schema::new(vec![ - Field::new("values", DataType::Float32, false), - Field::new("row_ids", DataType::UInt64, false), - ])); - let data = RecordBatchIterator::new(batches, schema); - let sub_index_trainer = FlatIndexMetadata::new(DataType::Float32); - - let data = Box::new(MockTrainingSource::new(data).await); - // Until DF handles NaN reliably we need to make sure we reject input - // containing NaN - assert!( - train_btree_index(data, &sub_index_trainer, index_store.as_ref()) - .await - .is_err() - ); - } - #[tokio::test] async fn btree_entire_null_page() { let tempdir = tempdir().unwrap(); - let index_store = legacy_test_store(&tempdir); + let index_store = test_store(&tempdir); let batch = gen() .col( "values", @@ -805,21 +817,36 @@ mod tests { let sub_index_trainer = FlatIndexMetadata::new(DataType::Utf8); let data = Box::new(MockTrainingSource::new(data).await); - train_btree_index(data, &sub_index_trainer, index_store.as_ref()) - .await - .unwrap(); + train_btree_index( + data, + &sub_index_trainer, + index_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE as u32, + ) + .await + .unwrap(); let index = BTreeIndex::load(index_store).await.unwrap(); - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::Utf8(Some( - "foo".to_string(), - )))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::Utf8(Some("foo".to_string()))), + &NoOpMetricsCollector, + ) .await .unwrap(); + + assert!(result.is_exact()); + let row_ids = result.row_ids(); + assert!(row_ids.is_empty()); - let row_ids = index.search(&SargableQuery::IsNull()).await.unwrap(); + let result = index + .search(&SargableQuery::IsNull(), &NoOpMetricsCollector) + .await + .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(row_ids.len(), Some(4096)); } @@ -871,21 +898,29 @@ mod tests { let index = BitmapIndex::load(index_store).await.unwrap(); - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::Utf8(None))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::Utf8(None)), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(2)); - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::Utf8(Some( - "abcd".to_string(), - )))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::Utf8(Some("abcd".to_string()))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(3), row_ids.len()); assert!(row_ids.contains(1)); assert!(row_ids.contains(3)); @@ -903,39 +938,55 @@ mod tests { train_bitmap(&index_store, data).await; let index = BitmapIndex::load(index_store).await.unwrap(); - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(10000)))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(10000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(10000)); - let row_ids = index - .search(&SargableQuery::Range( - Bound::Unbounded, - Bound::Excluded(ScalarValue::Int32(Some(-100))), - )) + let result = index + .search( + &SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(-100))), + ), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert!(row_ids.is_empty()); - let row_ids = index - .search(&SargableQuery::Range( - Bound::Unbounded, - Bound::Excluded(ScalarValue::Int32(Some(100))), - )) + let result = index + .search( + &SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::Int32(Some(100))), + ), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(100), row_ids.len()); } async fn check_bitmap(index: &BitmapIndex, query: SargableQuery, expected: &[u64]) { - let results = index.search(&query).await.unwrap(); + let results = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert!(results.is_exact()); let expected_arr = RowIdTreeMap::from_iter(expected); - assert_eq!(results, expected_arr); + assert_eq!(results.row_ids(), &expected_arr); } #[tokio::test] @@ -1175,11 +1226,16 @@ mod tests { .unwrap(); let updated_index = BitmapIndex::load(updated_index_store).await.unwrap(); - let row_ids = updated_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(5000)))) + let result = updated_index + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(5000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); assert_eq!(Some(1), row_ids.len()); assert!(row_ids.contains(5000)); } @@ -1218,21 +1274,33 @@ mod tests { // Remapped to new value assert!(remapped_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(5)))) + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(5))), + &NoOpMetricsCollector + ) .await .unwrap() + .row_ids() .contains(65)); // Deleted assert!(remapped_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(7)))) + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(7))), + &NoOpMetricsCollector + ) .await .unwrap() + .row_ids() .is_empty()); // Not remapped assert!(remapped_index - .search(&SargableQuery::Equals(ScalarValue::Int32(Some(3)))) + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(3))), + &NoOpMetricsCollector + ) .await .unwrap() + .row_ids() .contains(3)); } @@ -1277,7 +1345,9 @@ mod tests { let data = data.clone(); async move { let index = LabelListIndex::load(index_store).await.unwrap(); - let row_ids = index.search(&query).await.unwrap(); + let result = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + assert!(result.is_exact()); + let row_ids = result.row_ids(); let row_ids_set = row_ids .row_ids() diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs new file mode 100644 index 00000000000..8a780d4e179 --- /dev/null +++ b/rust/lance-index/src/scalar/ngram.rs @@ -0,0 +1,1576 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::any::Any; +use std::collections::BTreeMap; +use std::iter::once; +use std::time::Instant; +use std::{collections::HashMap, sync::Arc}; + +use arrow::array::{AsArray, UInt32Builder}; +use arrow::datatypes::{UInt32Type, UInt64Type}; +use arrow_array::{BinaryArray, RecordBatch, UInt32Array}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use deepsize::DeepSizeOf; +use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; +use lance_core::cache::FileMetadataCache; +use lance_core::error::LanceOptionExt; +use lance_core::utils::address::RowAddress; +use lance_core::utils::tokio::get_num_compute_intensive_cpus; +use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}; +use lance_core::Result; +use lance_core::{utils::mask::RowIdTreeMap, Error}; +use lance_io::object_store::ObjectStore; +use log::info; +use moka::future::Cache; +use object_store::path::Path; +use roaring::{RoaringBitmap, RoaringTreemap}; +use serde::Serialize; +use snafu::location; +use tantivy::tokenizer::TextAnalyzer; +use tempfile::{tempdir, TempDir}; +use tracing::instrument; + +use crate::metrics::NoOpMetricsCollector; +use crate::scalar::inverted::CACHE_SIZE; +use crate::vector::VectorIndex; +use crate::{Index, IndexType}; + +use super::btree::TrainingSource; +use super::lance_format::LanceIndexStore; +use super::{ + AnyQuery, IndexReader, IndexStore, IndexWriter, MetricsCollector, ScalarIndex, SearchResult, + TextQuery, +}; + +const TOKENS_COL: &str = "tokens"; +const POSTING_LIST_COL: &str = "posting_list"; +const POSTINGS_FILENAME: &str = "ngram_postings.lance"; + +lazy_static::lazy_static! { + pub static ref TOKENS_FIELD: Field = Field::new(TOKENS_COL, DataType::UInt32, true); + pub static ref POSTINGS_FIELD: Field = Field::new(POSTING_LIST_COL, DataType::Binary, false); + pub static ref POSTINGS_SCHEMA: SchemaRef = Arc::new(Schema::new(vec![TOKENS_FIELD.clone(), POSTINGS_FIELD.clone()])); + pub static ref TEXT_PREPPER: TextAnalyzer = TextAnalyzer::builder(tantivy::tokenizer::RawTokenizer::default()) + .filter(tantivy::tokenizer::LowerCaser) + .filter(tantivy::tokenizer::AsciiFoldingFilter) + .build(); + /// Currently we ALWAYS use trigrams with ascii folding and lower casing. We may want to make this configurable in the future. + pub static ref NGRAM_TOKENIZER: TextAnalyzer = TextAnalyzer::builder(tantivy::tokenizer::NgramTokenizer::all_ngrams(3, 3).unwrap()) + .filter(tantivy::tokenizer::AlphaNumOnlyFilter) + .build(); +} + +// Helper function to apply a function to each token in a text +fn tokenize_visitor(tokenizer: &TextAnalyzer, text: &str, mut visitor: impl FnMut(&String)) { + // The token_stream method is mutable. As far as I can tell this is to enforce exclusivity and not + // true mutability. For example, the object returned by `token_stream` has thread-local state but + // it is reset each time `token_stream` is called. + // + // However, I don't see this documented anywhere and I'm not sure about relying on it. For now, we + // make a clone as that seems to be the safer option. All the tokenizers we use here should be trivially + // cloneable (although it requires a heap allocation so may be worth investigating in the future) + let mut prepper = TEXT_PREPPER.clone(); + let mut tokenizer = tokenizer.clone(); + let mut raw_stream = prepper.token_stream(text); + while raw_stream.advance() { + let mut token_stream = tokenizer.token_stream(&raw_stream.token().text); + while token_stream.advance() { + visitor(&token_stream.token().text); + } + } +} + +const ALPHA_SPAN: usize = 37; +const MAX_TOKEN: usize = ALPHA_SPAN.pow(2) + ALPHA_SPAN; +const MIN_TOKEN: usize = 0; +const NGRAM_N: usize = 3; + +// Convert an ngram (string) to a token (u32). This helps avoid heap allocations +// and it makes it easier to partition the tokens for shuffling +// +// There are 36 alphanumeric values and we add 1 for the NULL token giving us 37^3 +// potential tokens. +// +// "" => 0 +// "?" => 37^2 * ? +// "?$" => 37^2 * ? + 37 * $ +// "?$#" => 37^2 * ? + 37 * $ + # +// ... +// +// The ?,$,# represent the position in the alphabet (+1 to distinguish from NULL) +// +// Small strings get the larger multipliers because those ngrams are +// less likely to be unique and will have larger bitmaps. We want to +// spread those out. +// +// NOTE: Today we hard-code trigrams and we do not include 1-grams or 2-grams so this +// function is more general than it needs to be...just in case. +fn ngram_to_token(ngram: &str, ngram_length: usize) -> u32 { + let mut token = 0; + // Empty string will get 0 + for (idx, byte) in ngram.bytes().enumerate() { + let pos = if byte <= b'9' { + byte - b'0' + } else if byte <= b'z' { + byte - b'a' + 10 + } else { + unreachable!() + } + 1; + debug_assert!(pos < ALPHA_SPAN as u8); + let mult = ALPHA_SPAN.pow(ngram_length as u32 - idx as u32 - 1) as u32; + token += pos as u32 * mult; + } + token +} + +/// Basic stats about an ngram index +#[derive(Serialize)] +struct NGramStatistics { + num_ngrams: usize, +} + +/// The row ids that contain a given ngram +#[derive(Debug)] +struct NGramPostingList { + bitmap: RoaringTreemap, +} + +impl DeepSizeOf for NGramPostingList { + fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { + self.bitmap.serialized_size() + } +} + +impl NGramPostingList { + fn try_from_batch(batch: RecordBatch) -> Result { + let bitmap_bytes = batch.column(0).as_binary::().value(0); + let bitmap = + RoaringTreemap::deserialize_from(bitmap_bytes).map_err(|e| Error::Internal { + message: format!("Error deserializing ngram list: {}", e), + location: location!(), + })?; + Ok(Self { bitmap }) + } + + fn intersect<'a>(lists: impl IntoIterator) -> RoaringTreemap { + let mut iter = lists.into_iter(); + let mut result = iter + .next() + .map(|list| list.bitmap.clone()) + .unwrap_or_default(); + for list in iter { + result &= &list.bitmap; + } + result + } +} + +/// Reads on-demand ngram posting lists from storage (and stores them in a cache) +struct NGramPostingListReader { + reader: Arc, + /// The cache key is the row_offset + cache: Cache>, +} + +impl DeepSizeOf for NGramPostingListReader { + fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { + self.cache.weighted_size() as usize + } +} + +impl std::fmt::Debug for NGramPostingListReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NGramListReader") + .field("cache_entry_count", &self.cache.entry_count()) + .finish() + } +} + +impl NGramPostingListReader { + #[instrument(level = "debug", skip(self, metrics))] + pub async fn ngram_list( + &self, + row_offset: u32, + metrics: &dyn MetricsCollector, + ) -> Result> { + self.cache + .try_get_with(row_offset, async move { + metrics.record_part_load(); + tracing::info!(target: TRACE_IO_EVENTS, type=IO_TYPE_LOAD_SCALAR_PART, index_type="ngram", part_id=row_offset); + let batch = self + .reader + .read_range( + row_offset as usize..row_offset as usize + 1, + Some(&[POSTING_LIST_COL]), + ) + .await?; + Result::Ok(Arc::new(NGramPostingList::try_from_batch(batch)?)) + }) + .await + .map_err(|e| Error::io(e.to_string(), location!())) + } +} + +/// An ngram index +/// +/// At a high level this is an inverted index that maps ngrams (small fixed size substrings) to the +/// row ids that contain them. +/// +/// As a simple example consider a 1-gram index. It would basically be a mapping from +/// each letter to the row ids that contain that letter. Then, if the user searches for +/// "cat", the index would look up the row ids for "c", "a", and "t", and return the intersection +/// of those row ids because only rows have at least one c, a, and t could possible contain "cat". +/// +/// This is an in-exact index, similar to a bloom filter. It can return false positives and a +/// recheck step is needed to confirm the results. +/// +/// Note that it cannot return false negatives. +pub struct NGramIndex { + /// The mapping from tokens to row offsets + tokens: HashMap, + /// The reader for the posting lists + list_reader: Arc, + /// The tokenizer used to tokenize text. Note: not all tokenizers can be used with this index. For + /// example, a stemming tokenizer would not work well because "dozing" would stem to "doze" and if the + /// search term is "zing" it would not match. As a result, this tokenizer is not as configurable as the + /// tokenizers used in an inverted index. + tokenizer: TextAnalyzer, + io_parallelism: usize, + /// The store that owns the index + store: Arc, +} + +impl std::fmt::Debug for NGramIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NGramIndex") + .field("tokens", &self.tokens) + .field("list_reader", &self.list_reader) + .finish() + } +} + +impl DeepSizeOf for NGramIndex { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.tokens.deep_size_of_children(context) + self.list_reader.deep_size_of_children(context) + } +} + +impl NGramIndex { + async fn from_store(store: Arc) -> Result { + let tokens = store.open_index_file(POSTINGS_FILENAME).await?; + let tokens = tokens + .read_range(0..tokens.num_rows(), Some(&[TOKENS_COL])) + .await?; + + let tokens_map = HashMap::from_iter( + tokens + .column(0) + .as_primitive::() + .values() + .iter() + .copied() + .enumerate() + .map(|(idx, token)| (token, idx as u32)), + ); + + let posting_reader = Arc::new(NGramPostingListReader { + reader: store.open_index_file(POSTINGS_FILENAME).await?, + cache: Cache::builder() + .max_capacity(*CACHE_SIZE as u64) + .weigher(|_, posting: &Arc| posting.deep_size_of() as u32) + .build(), + }); + + Ok(Self { + io_parallelism: store.io_parallelism(), + tokens: tokens_map, + list_reader: posting_reader, + tokenizer: NGRAM_TOKENIZER.clone(), + store, + }) + } + + fn remap_batch( + &self, + batch: RecordBatch, + mapping: &HashMap>, + ) -> Result { + let posting_lists_array = batch + .column_by_name(POSTING_LIST_COL) + .expect_ok()? + .as_binary::(); + + let new_posting_lists = posting_lists_array + .iter() + .map(|posting_list| { + let posting_list = posting_list.unwrap(); + let posting_list = RoaringTreemap::deserialize_from(posting_list)?; + let new_posting_list = + RoaringTreemap::from_iter(posting_list.into_iter().filter_map(|row_id| { + match mapping.get(&row_id) { + Some(Some(new_row_id)) => Some(*new_row_id), + Some(None) => None, + None => Some(row_id), + } + })); + let mut buf = Vec::with_capacity(new_posting_list.serialized_size()); + new_posting_list.serialize_into(&mut buf)?; + Ok(buf) + }) + .collect::>>()?; + + let new_posting_lists_array = BinaryArray::from_iter_values(new_posting_lists); + + Ok(RecordBatch::try_new( + POSTINGS_SCHEMA.clone(), + vec![ + batch.column_by_name(TOKENS_COL).expect_ok()?.clone(), + Arc::new(new_posting_lists_array), + ], + )?) + } +} + +#[async_trait] +impl Index for NGramIndex { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_index(self: Arc) -> Arc { + self + } + + fn as_vector_index(self: Arc) -> Result> { + Err(Error::InvalidInput { + source: "NGramIndex is not a vector index".into(), + location: location!(), + }) + } + + fn statistics(&self) -> Result { + let ngram_stats = NGramStatistics { + num_ngrams: self.tokens.len(), + }; + serde_json::to_value(ngram_stats).map_err(|e| Error::Internal { + message: format!("Error serializing statistics: {}", e), + location: location!(), + }) + } + + async fn prewarm(&self) -> Result<()> { + // TODO: NGram index can pre-warm by loading all posting lists into memory + Ok(()) + } + + fn index_type(&self) -> IndexType { + IndexType::NGram + } + + async fn calculate_included_frags(&self) -> Result { + let mut frag_ids = RoaringBitmap::new(); + for row_offset in self.tokens.values() { + let list = self + .list_reader + .ngram_list(*row_offset, &NoOpMetricsCollector) + .await?; + frag_ids.extend( + list.bitmap + .iter() + .map(|row_addr| RowAddress::from(row_addr).fragment_id()), + ); + } + Ok(frag_ids) + } +} + +#[async_trait] +impl ScalarIndex for NGramIndex { + async fn search( + &self, + query: &dyn AnyQuery, + metrics: &dyn MetricsCollector, + ) -> Result { + let query = + query + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::InvalidInput { + source: "Query is not a TextQuery".into(), + location: location!(), + })?; + match query { + TextQuery::StringContains(substr) => { + if substr.len() < NGRAM_N { + // We know nothing on short searches, need to recheck all + return Ok(SearchResult::AtLeast(RowIdTreeMap::new())); + } + + let mut row_offsets = Vec::with_capacity(substr.len() * 3); + let mut missing = false; + tokenize_visitor(&self.tokenizer, substr, |ngram| { + let token = ngram_to_token(ngram, NGRAM_N); + if let Some(row_offset) = self.tokens.get(&token) { + row_offsets.push(*row_offset); + } else { + missing = true; + } + }); + // At least one token was missing, so we know there are zero results + if missing { + return Ok(SearchResult::Exact(RowIdTreeMap::new())); + } + let posting_lists = futures::stream::iter( + row_offsets + .into_iter() + .map(|row_offset| self.list_reader.ngram_list(row_offset, metrics)), + ) + .buffer_unordered(self.io_parallelism) + .try_collect::>() + .await?; + metrics.record_comparisons(posting_lists.len()); + let list_refs = posting_lists.iter().map(|list| list.as_ref()); + let row_ids = NGramPostingList::intersect(list_refs); + Ok(SearchResult::AtMost(RowIdTreeMap::from(row_ids))) + } + } + } + + fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool { + false + } + + async fn load(store: Arc) -> Result> + where + Self: Sized, + { + Ok(Arc::new(Self::from_store(store).await?)) + } + + async fn remap( + &self, + mapping: &HashMap>, + dest_store: &dyn IndexStore, + ) -> Result<()> { + let reader = self.store.open_index_file(POSTINGS_FILENAME).await?; + let mut writer = dest_store + .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone()) + .await?; + + let mut offset = 0; + let num_rows = reader.num_rows(); + const BATCH_SIZE: usize = 64; + while offset < num_rows { + let batch_size = BATCH_SIZE.min(num_rows - offset); + let batch = reader.read_range(offset..offset + batch_size, None).await?; + let batch = self.remap_batch(batch, mapping)?; + writer.write_record_batch(batch).await?; + offset += BATCH_SIZE; + } + + writer.finish().await + } + + async fn update( + &self, + new_data: SendableRecordBatchStream, + dest_store: &dyn IndexStore, + ) -> Result<()> { + let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?; + let spill_files = builder.train(new_data).await?; + + builder + .write_index(dest_store, spill_files, Some(self.store.clone())) + .await?; + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct NGramIndexBuilderOptions { + tokens_per_spill: usize, +} + +lazy_static::lazy_static! { + // A higher value will use more RAM. A lower value will have to do more spilling + static ref DEFAULT_TOKENS_PER_SPILL: usize = std::env::var("LANCE_NGRAM_TOKENS_PER_SPILL") + .unwrap_or_else(|_| "1000000000".to_string()) + .parse() + .expect("failed to parse LANCE_NGRAM_TOKENS_PER_SPILL"); + // How many partitions to use for shuffling out the work. We slightly + // over-allocate this since the amount of work per-partition is not uniform. + // + // Increasing this may increase the performance but it could increase RAM (since we will spill less often) + // and could hurt performance (since there will be more files at the end for the final spill) + static ref DEFAULT_NUM_PARTITIONS: usize = std::env::var("LANCE_NGRAM_NUM_PARTITIONS").map(|s| s.parse().expect("failed to parse LANCE_NGRAM_PARALLELISM")).unwrap_or((get_num_compute_intensive_cpus() * 4).max(128)); + // Just enough so that tokenizing is faster than I/O + static ref DEFAULT_TOKENIZE_PARALLELISM: usize = std::env::var("LANCE_NGRAM_TOKENIZE_PARALLELISM").map(|s| s.parse().expect("failed to parse LANCE_NGRAM_TOKENIZE_PARALLELISM")).unwrap_or(8); +} + +impl Default for NGramIndexBuilderOptions { + fn default() -> Self { + Self { + tokens_per_spill: *DEFAULT_TOKENS_PER_SPILL, + } + } +} + +// An ordered list of tokens and bitmaps +// +// The `tokens` list is ordered by token value. This makes it easier to merge spill files. +struct NGramIndexSpillState { + tokens: UInt32Array, + bitmaps: Vec, +} + +impl NGramIndexSpillState { + fn try_from_batch(batch: RecordBatch) -> Result { + let tokens = batch + .column_by_name(TOKENS_COL) + .expect_ok()? + .as_primitive::() + .clone(); + let postings = batch + .column_by_name(POSTING_LIST_COL) + .expect_ok()? + .as_binary::(); + + let bitmaps = postings + .into_iter() + .map(|bytes| { + RoaringTreemap::deserialize_from(bytes.expect_ok()?).map_err(|e| Error::Internal { + message: format!("Error deserializing ngram list: {}", e), + location: location!(), + }) + }) + .collect::>>()?; + + Ok(Self { tokens, bitmaps }) + } + + fn try_into_batch(self) -> Result { + let bitmap_array = BinaryArray::from_iter_values(self.bitmaps.into_iter().map(|bitmap| { + let mut buf = Vec::with_capacity(bitmap.serialized_size()); + bitmap.serialize_into(&mut buf).unwrap(); + buf + })); + Ok(RecordBatch::try_new( + POSTINGS_SCHEMA.clone(), + vec![Arc::new(self.tokens), Arc::new(bitmap_array)], + )?) + } +} + +// As we're building we create a map from ngram to row ids. When this map gets too large +// we spill it to disk. +struct NGramIndexBuildState { + tokens_map: BTreeMap, +} + +impl NGramIndexBuildState { + fn starting() -> Self { + Self { + tokens_map: BTreeMap::new(), + } + } + + fn take(&mut self) -> Self { + let mut taken = Self::starting(); + std::mem::swap(&mut self.tokens_map, &mut taken.tokens_map); + taken + } + + fn into_spill(self) -> NGramIndexSpillState { + // We can rely on these being in token order because of BTreeMap + let tokens = UInt32Array::from_iter_values(self.tokens_map.keys().copied()); + let bitmaps = Vec::from_iter(self.tokens_map.into_values()); + + NGramIndexSpillState { bitmaps, tokens } + } +} + +/// A builder for an ngram index +/// +/// The builder is a small pipeline. First, we read in the data and tokenize it. This +/// stage uses fan-out parallelism to tokenize the data because tokenization may be a little +/// slower than I/O. +/// +/// The second stage fans out much wider. It partitions the tokens into a number of partitions. +/// Each partition has a BTreemap that maps tokens to row ids. The partitions then build up +/// roaring treemaps. When a partition gets too full it will spill to disk. +/// +/// Once all the data is processed we spill all the parititons to disk and then we merge the +/// spill files into a single index file. +pub struct NGramIndexBuilder { + tokenizer: TextAnalyzer, + options: NGramIndexBuilderOptions, + tmpdir: Arc, + spill_store: Arc, + + tokens_seen: usize, + worker_number: usize, + has_flushed: bool, + + state: NGramIndexBuildState, +} + +impl NGramIndexBuilder { + pub fn try_new(options: NGramIndexBuilderOptions) -> Result { + Self::from_state(NGramIndexBuildState::starting(), options) + } + + fn clone_worker(&self, worker_number: usize) -> Self { + let mut bitmaps = Vec::with_capacity(36 * 36 * 36 + 1); + // Token 0 is always the NULL bitmap + bitmaps.push(RoaringTreemap::new()); + Self { + tokenizer: self.tokenizer.clone(), + state: NGramIndexBuildState::starting(), + tmpdir: self.tmpdir.clone(), + spill_store: self.spill_store.clone(), + options: self.options.clone(), + tokens_seen: 0, + worker_number, + has_flushed: false, + } + } + + fn from_state(state: NGramIndexBuildState, options: NGramIndexBuilderOptions) -> Result { + let tokenizer = NGRAM_TOKENIZER.clone(); + + let tmpdir = Arc::new(tempdir()?); + let spill_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path())?, + FileMetadataCache::no_cache(), + )); + + Ok(Self { + tokenizer, + state, + tmpdir, + spill_store, + options, + tokens_seen: 0, + worker_number: 0, + has_flushed: false, + }) + } + + fn validate_schema(schema: &Schema) -> Result<()> { + if schema.fields().len() != 2 { + return Err(Error::InvalidInput { + source: "Ngram index schema must have exactly two fields".into(), + location: location!(), + }); + } + if *schema.field(0).data_type() != DataType::Utf8 { + return Err(Error::InvalidInput { + source: "First field in ngram index schema must be of type Utf8".into(), + location: location!(), + }); + } + if *schema.field(1).data_type() != DataType::UInt64 { + return Err(Error::InvalidInput { + source: "Second field in ngram index schema must be of type UInt64".into(), + location: location!(), + }); + } + Ok(()) + } + + async fn process_batch(&mut self, tokens_and_ids: Vec<(u32, u64)>) -> Result<()> { + let mut tokens_seen = 0; + for (token, row_id) in tokens_and_ids { + tokens_seen += 1; + // This would be a bit simpler with entry API but, at scale, the vast majority + // of cases will be a hit and we want to avoid cloning the string if we can. So + // for now we do the double-hash. We can simplify in the future with raw_entry + // when it stabilizes. + self.state + .tokens_map + .entry(token) + .or_default() + .insert(row_id); + } + self.tokens_seen += tokens_seen; + if self.tokens_seen >= self.options.tokens_per_spill { + let state = self.state.take(); + self.flush(state).await?; + } + Ok(()) + } + + fn spill_filename(id: usize) -> String { + format!("spill-{}.lance", id) + } + + fn tmp_spill_filename(id: usize) -> String { + format!("spill-{}.lance.tmp", id) + } + + async fn flush(&mut self, state: NGramIndexBuildState) -> Result { + if self.tokens_seen == 0 { + assert!(state.tokens_map.is_empty()); + return Ok(self.has_flushed); + } + self.tokens_seen = 0; + let spill_state = state.into_spill(); + let flush_start = Instant::now(); + // The primary builder should never flush + debug_assert_ne!(self.worker_number, 0); + if self.has_flushed { + info!("Merging flush for worker {}", self.worker_number); + // If we have flushed before then we need to merge with the spill file + let mut writer = self + .spill_store + .new_index_file( + &Self::tmp_spill_filename(self.worker_number), + POSTINGS_SCHEMA.clone(), + ) + .await?; + + let left_stream = stream::once(std::future::ready(Ok(spill_state))); + let right_stream = + Self::stream_spill(self.spill_store.clone(), self.worker_number).await?; + Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?; + drop(writer); + self.spill_store + .rename_index_file( + &Self::tmp_spill_filename(self.worker_number), + &Self::spill_filename(self.worker_number), + ) + .await?; + } else { + // If we haven't flushed before we can just write to the spill file + info!("Initial flush for worker {}", self.worker_number); + self.has_flushed = true; + let writer = self + .spill_store + .new_index_file( + &Self::spill_filename(self.worker_number), + POSTINGS_SCHEMA.clone(), + ) + .await?; + self.write(writer, spill_state).await?; + } + let flush_time = flush_start.elapsed(); + info!( + "Flushed worker {} in {}ms", + self.worker_number, + flush_time.as_millis() + ); + Ok(true) + } + + fn tokenize_and_partition( + tokenizer: &TextAnalyzer, + batch: RecordBatch, + num_workers: usize, + ) -> Vec> { + let text_col = batch.column(0).as_string::(); + let row_id_col = batch.column(1).as_primitive::(); + // Guessing 1000 tokens per row to at least avoid some of the earlier allocations + let mut partitions = vec![Vec::with_capacity(batch.num_rows() * 1000); num_workers]; + let divisor = (MAX_TOKEN - MIN_TOKEN) / num_workers; + for (text, row_id) in text_col.iter().zip(row_id_col.values()) { + if let Some(text) = text { + tokenize_visitor(tokenizer, text, |token| { + let token = ngram_to_token(token, NGRAM_N); + let partition_id = (token as usize).saturating_sub(MIN_TOKEN) / divisor; + partitions[partition_id % num_workers].push((token, *row_id)); + }); + } else { + partitions[0].push((0, *row_id)); + } + } + partitions + } + + pub async fn train(&mut self, data: SendableRecordBatchStream) -> Result> { + let schema = data.schema(); + Self::validate_schema(schema.as_ref())?; + + let num_workers = *DEFAULT_NUM_PARTITIONS; + let mut senders = Vec::with_capacity(num_workers); + let mut builders = Vec::with_capacity(num_workers); + for worker_idx in 0..num_workers { + let (send, mut recv) = tokio::sync::mpsc::channel(2); + senders.push(send); + + let mut builder = self.clone_worker(worker_idx + 1); + let future = tokio::spawn(async move { + while let Some(partition) = recv.recv().await { + builder.process_batch(partition).await?; + } + Result::Ok(builder) + }); + builders.push(future); + } + + let mut partitions_stream = data + .and_then(|batch| { + let tokenizer = self.tokenizer.clone(); + std::future::ready(Ok(tokio::task::spawn(async move { + Ok(Self::tokenize_and_partition(&tokenizer, batch, num_workers)) + }) + .map(|res| res.unwrap()))) + }) + .try_buffer_unordered(*DEFAULT_TOKENIZE_PARALLELISM); + + while let Some(partitions) = partitions_stream.try_next().await? { + for (part_idx, partition) in partitions.into_iter().enumerate() { + senders[part_idx].send(partition).await.unwrap(); + } + } + + std::mem::drop(senders); + let builders = futures::future::try_join_all(builders).await?; + + // Final flush is serialized. If we kick this off in parallel it can + // use a lot of memory. + + let mut to_spill = Vec::with_capacity(builders.len()); + + for builder in builders { + let mut builder = builder?; + let state = builder.state.take(); + if builder.flush(state).await? { + to_spill.push(builder.worker_number); + } + } + + Ok(to_spill) + } + + async fn write( + &mut self, + mut writer: Box, + state: NGramIndexSpillState, + ) -> Result<()> { + writer.write_record_batch(state.try_into_batch()?).await?; + writer.finish().await?; + + Ok(()) + } + + async fn stream_spill_reader( + reader: Arc, + ) -> Result>> { + let num_rows = reader.num_rows(); + + Ok(stream::try_unfold(0, move |offset| { + let reader = reader.clone(); + async move { + // These are small batches but, in the worst case scenario, each row could + // be massive (up to 128MB per row at 1B rows) and we end up breaking memory + let batch_size = std::cmp::min(num_rows - offset, 64); + if batch_size == 0 { + return Ok(None); + } + let batch = reader.read_range(offset..offset + batch_size, None).await?; + let state = NGramIndexSpillState::try_from_batch(batch)?; + let new_offset = offset + batch_size; + Ok(Some((state, new_offset))) + } + .boxed() + })) + } + + async fn stream_spill( + spill_store: Arc, + id: usize, + ) -> Result>> { + let reader = spill_store + .open_index_file(&Self::spill_filename(id)) + .await?; + Self::stream_spill_reader(reader).await + } + + fn merge_spill_states( + left_opt: &mut Option, + right_opt: &mut Option, + ) -> NGramIndexSpillState { + let left = left_opt.take().unwrap(); + let right = right_opt.take().unwrap(); + + let item_capacity = left.tokens.len() + right.tokens.len(); + let mut merged_tokens = UInt32Builder::with_capacity(item_capacity); + let mut merged_bitmaps = Vec::with_capacity(left.bitmaps.len() + right.bitmaps.len()); + + let mut left_tokens = left.tokens.values().iter().copied(); + let mut left_bitmaps = left.bitmaps.into_iter(); + let mut right_tokens = right.tokens.values().iter().copied(); + let mut right_bitmaps = right.bitmaps.into_iter(); + + let mut left_token = left_tokens.next(); + let mut left_bitmap = left_bitmaps.next(); + let mut right_token = right_tokens.next(); + let mut right_bitmap = right_bitmaps.next(); + + while left_token.is_some() && right_token.is_some() { + let left_token_val = left_token.unwrap(); + let right_token_val = right_token.unwrap(); + match left_token_val.cmp(&right_token_val) { + std::cmp::Ordering::Less => { + merged_tokens.append_value(left_token_val); + merged_bitmaps.push(left_bitmap.unwrap()); + left_token = left_tokens.next(); + left_bitmap = left_bitmaps.next(); + } + std::cmp::Ordering::Greater => { + merged_tokens.append_value(right_token_val); + merged_bitmaps.push(right_bitmap.unwrap()); + right_token = right_tokens.next(); + right_bitmap = right_bitmaps.next(); + } + std::cmp::Ordering::Equal => { + merged_tokens.append_value(left_token_val); + merged_bitmaps.push(left_bitmap.unwrap() | &right_bitmap.unwrap()); + left_token = left_tokens.next(); + left_bitmap = left_bitmaps.next(); + right_token = right_tokens.next(); + right_bitmap = right_bitmaps.next(); + } + } + } + + let collect_remaining = |cur_token, tokens, cur_bitmap, bitmaps| { + let tokens = UInt32Array::from_iter_values(once(cur_token).chain(tokens)); + let bitmaps = once(cur_bitmap).chain(bitmaps).collect::>(); + NGramIndexSpillState { tokens, bitmaps } + }; + + if left_token.is_some() { + *left_opt = Some(collect_remaining( + left_token.unwrap(), + left_tokens, + left_bitmap.unwrap(), + left_bitmaps, + )); + } else { + *left_opt = None; + } + if right_token.is_some() { + *right_opt = Some(collect_remaining( + right_token.unwrap(), + right_tokens, + right_bitmap.unwrap(), + right_bitmaps, + )); + } else { + *right_opt = None; + } + + NGramIndexSpillState { + tokens: merged_tokens.finish(), + bitmaps: merged_bitmaps, + } + } + + async fn merge_spill_streams( + mut left_stream: impl Stream> + Unpin, + mut right_stream: impl Stream> + Unpin, + writer: &mut dyn IndexWriter, + ) -> Result<()> { + let mut left_state = left_stream.try_next().await?; + let mut right_state = right_stream.try_next().await?; + + while left_state.is_some() || right_state.is_some() { + if left_state.is_none() { + // Left is done, full drain right + let state = right_state.take().expect_ok()?; + writer.write_record_batch(state.try_into_batch()?).await?; + while let Some(state) = right_stream.try_next().await? { + writer.write_record_batch(state.try_into_batch()?).await?; + } + } else if right_state.is_none() { + // Right is done, full drain left + let state = left_state.take().expect_ok()?; + writer.write_record_batch(state.try_into_batch()?).await?; + while let Some(state) = left_stream.try_next().await? { + writer.write_record_batch(state.try_into_batch()?).await?; + } + } else { + // There is a batch from both left and right. Need to merge them + let merged = Self::merge_spill_states(&mut left_state, &mut right_state); + writer.write_record_batch(merged.try_into_batch()?).await?; + if left_state.is_none() { + left_state = left_stream.try_next().await?; + } + if right_state.is_none() { + right_state = right_stream.try_next().await?; + } + } + } + + writer.finish().await + } + + async fn merge_spill_files( + spill_store: Arc, + index_of_left: usize, + index_of_right: usize, + output_index: usize, + ) -> Result<()> { + // We fully load the small file into memory and then stream the large file + info!( + "Merge spill files {} and {} into {}", + index_of_left, index_of_right, output_index + ); + + let mut writer = spill_store + .new_index_file(&Self::spill_filename(output_index), POSTINGS_SCHEMA.clone()) + .await?; + + let (left_stream, right_stream) = futures::try_join!( + Self::stream_spill(spill_store.clone(), index_of_left), + Self::stream_spill(spill_store.clone(), index_of_right) + )?; + + Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?; + + spill_store + .delete_index_file(&Self::spill_filename(index_of_left)) + .await?; + spill_store + .delete_index_file(&Self::spill_filename(index_of_right)) + .await?; + + Ok(()) + } + + // Can potentially parallelize in the future if this step becomes a bottleneck + // + // We can also merge in a more balanced fashion (e.g. binary tree) to reduce the size of + // intermediate files + // + // Note: worker indices start at 1 and not 0 (hence all the +1's) + async fn merge_spills(&mut self, mut spill_files: Vec) -> Result { + info!( + "Merging {} index files into one combined index", + spill_files.len() + ); + + let mut spill_counter = spill_files.iter().max().expect_ok()? + 1; + while spill_files.len() > 1 { + let mut new_spills = Vec::with_capacity(spill_files.len() / 2); + while spill_files.len() >= 2 { + let left = spill_files.pop().expect_ok()?; + let right = spill_files.pop().expect_ok()?; + new_spills.push(tokio::spawn(Self::merge_spill_files( + self.spill_store.clone(), + left, + right, + spill_counter + new_spills.len(), + ))); + } + for i in 0..new_spills.len() { + spill_files.push(spill_counter + i); + } + spill_counter += new_spills.len(); + futures::future::try_join_all(new_spills).await?; + } + + spill_files.pop().expect_ok() + } + + async fn merge_old_index( + &mut self, + new_data_num: usize, + old_index: Arc, + ) -> Result { + info!("Merging old index into new index"); + let final_num = new_data_num + 1; + + let mut writer = self + .spill_store + .new_index_file(&Self::spill_filename(final_num), POSTINGS_SCHEMA.clone()) + .await?; + + let left_stream = Self::stream_spill(self.spill_store.clone(), new_data_num).await?; + let old_reader = old_index.open_index_file(POSTINGS_FILENAME).await?; + let right_stream = Self::stream_spill_reader(old_reader).await?; + + Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?; + + self.spill_store + .delete_index_file(&Self::spill_filename(new_data_num)) + .await?; + + Ok(final_num) + } + + pub async fn write_index( + mut self, + store: &dyn IndexStore, + spill_files: Vec, + old_index: Option>, + ) -> Result<()> { + let mut writer = store + .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone()) + .await?; + + if spill_files.is_empty() { + if let Some(old_index) = old_index { + // An update with no new data, just copy the old index to the new store + old_index.copy_index_file(POSTINGS_FILENAME, store).await?; + } else { + // Training an index with no data, make an empty index + let mut writer = store + .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone()) + .await?; + writer.finish().await?; + } + return Ok(()); + } + + let mut index_to_copy = self.merge_spills(spill_files).await?; + + if let Some(old_index) = old_index { + index_to_copy = self.merge_old_index(index_to_copy, old_index).await?; + } + + let reader = self + .spill_store + .open_index_file(&Self::spill_filename(index_to_copy)) + .await?; + + let num_rows = reader.num_rows(); + let mut offset = 0; + + while offset < num_rows { + let batch_size = std::cmp::min(num_rows - offset, 64); + let batch = reader.read_range(offset..offset + batch_size, None).await?; + writer.write_record_batch(batch).await?; + offset += batch_size; + } + + writer.finish().await + } +} + +pub async fn train_ngram_index( + data_source: Box, + index_store: &dyn IndexStore, +) -> Result<()> { + let batches_source = data_source.scan_unordered_chunks(4096).await?; + let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?; + + let spill_files = builder.train(batches_source).await?; + + builder.write_index(index_store, spill_files, None).await +} + +#[cfg(test)] +mod tests { + use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + }; + + use arrow::datatypes::UInt64Type; + use arrow_array::{Array, RecordBatch, StringArray, UInt64Array}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::{ + execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter, + }; + use datafusion_common::DataFusionError; + use futures::{stream, TryStreamExt}; + use itertools::Itertools; + use lance_core::{cache::FileMetadataCache, utils::mask::RowIdTreeMap}; + use lance_datagen::{BatchCount, ByteCount, RowCount}; + use lance_io::object_store::ObjectStore; + use object_store::path::Path; + use tantivy::tokenizer::TextAnalyzer; + use tempfile::{tempdir, TempDir}; + + use crate::metrics::NoOpMetricsCollector; + use crate::scalar::{ + lance_format::LanceIndexStore, + ngram::{NGramIndex, NGramIndexBuilder, NGramIndexBuilderOptions}, + ScalarIndex, SearchResult, TextQuery, + }; + + use super::{ngram_to_token, tokenize_visitor, NGRAM_TOKENIZER}; + + fn collect_tokens(analyzer: &TextAnalyzer, text: &str) -> Vec { + let mut tokens = Vec::with_capacity(text.len() * 3); + tokenize_visitor(analyzer, text, |token| tokens.push(token.to_owned())); + tokens + } + + #[test] + fn test_tokenizer() { + let tokenizer = NGRAM_TOKENIZER.clone(); + + // ASCII folding + let tokens = collect_tokens(&tokenizer, "café"); + assert_eq!( + tokens, + vec!["caf", "afe"] // spellchecker:disable-line + ); + + // Allow numbers + let tokens = collect_tokens(&tokenizer, "a1b2"); + assert_eq!(tokens, vec!["a1b", "1b2"]); + + // Remove symbols and UTF-8 that doesn't map to characters + let tokens = collect_tokens(&tokenizer, "abcðŸ‘b!c24"); + + assert_eq!(tokens, vec!["abc", "c24"]); + + let tokens = collect_tokens(&tokenizer, "anstoß"); + + assert_eq!(tokens, vec!["ans", "nst", "sto", "tos", "oss"]); + + // Lower casing + let tokens = collect_tokens(&tokenizer, "ABC"); + assert_eq!(tokens, vec!["abc"]); + + // Duplicate tokens + let tokens = collect_tokens(&tokenizer, "ababab"); + // Confirming that the tokenizer doesn't deduplicate tokens (this can be taken into consideration + // when training the index) + assert_eq!( + tokens, + vec!["aba", "bab", "aba", "bab"] // spellchecker:disable-line + ); + } + + async fn do_train( + mut builder: NGramIndexBuilder, + data: SendableRecordBatchStream, + ) -> (NGramIndex, Arc) { + let spill_files = builder.train(data).await.unwrap(); + + let tmpdir = Arc::new(tempdir().unwrap()); + let test_store = LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + ); + + builder + .write_index(&test_store, spill_files, None) + .await + .unwrap(); + + ( + NGramIndex::from_store(Arc::new(test_store)).await.unwrap(), + tmpdir, + ) + } + + async fn get_posting_list_for_trigram(index: &NGramIndex, trigram: &str) -> Vec { + let token = ngram_to_token(trigram, 3); + let row_offset = index.tokens[&token]; + let list = index + .list_reader + .ngram_list(row_offset, &NoOpMetricsCollector) + .await + .unwrap(); + list.bitmap.iter().sorted().collect() + } + + async fn get_null_posting_list(index: &NGramIndex) -> Vec { + let row_offset = index.tokens[&0]; + let list = index + .list_reader + .ngram_list(row_offset, &NoOpMetricsCollector) + .await + .unwrap(); + list.bitmap.iter().sorted().collect() + } + + #[test_log::test(tokio::test)] + async fn test_basic_ngram_index() { + let data = StringArray::from_iter_values([ + "cat", + "dog", + "cat dog", + "dog cat", + "elephant", + "mouse", + "rhino", + "giraffe", + "rhinos nose", + ]); + let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64)); + let schema = Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, false), + Field::new("row_ids", DataType::UInt64, false), + ])); + let data = + RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap(); + let data = Box::pin(RecordBatchStreamAdapter::new( + schema, + stream::once(std::future::ready(Ok(data))), + )); + + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + + let (index, _tmpdir) = do_train(builder, data).await; + assert_eq!(index.tokens.len(), 21); + + // Basic search + let res = index + .search( + &TextQuery::StringContains("cat".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + + let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([0, 2, 3])); + + assert_eq!(expected, res); + + // Whitespace in query + let res = index + .search( + &TextQuery::StringContains("nos nos".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + assert_eq!(expected, res); + + // No matches + let res = index + .search( + &TextQuery::StringContains("tdo".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::Exact(RowIdTreeMap::new()); + assert_eq!(expected, res); + + // False positive + let res = index + .search( + &TextQuery::StringContains("inose".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + assert_eq!(expected, res); + + // Too short, don't know anything + let res = index + .search( + &TextQuery::StringContains("ab".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::AtLeast(RowIdTreeMap::new()); + assert_eq!(expected, res); + + // One short string but we still get at least one trigram, this is ok + let res = index + .search( + &TextQuery::StringContains("no nos".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([8])); + assert_eq!(expected, res); + } + + fn test_data_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, true), + Field::new("row_ids", DataType::UInt64, false), + ])) + } + + fn simple_data_with_nulls() -> SendableRecordBatchStream { + let data = StringArray::from_iter(&[Some("cat"), Some("dog"), None, None, Some("cat dog")]); + let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64)); + let schema = test_data_schema(); + let data = + RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap(); + Box::pin(RecordBatchStreamAdapter::new( + schema, + stream::once(std::future::ready(Ok(data))), + )) + } + + #[test_log::test(tokio::test)] + async fn test_ngram_nulls() { + let data = simple_data_with_nulls(); + + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + + let (index, _tmpdir) = do_train(builder, data).await; + assert_eq!(index.tokens.len(), 3); + + let res = index + .search( + &TextQuery::StringContains("cat".to_string()), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let expected = SearchResult::AtMost(RowIdTreeMap::from_iter([0, 4])); + assert_eq!(expected, res); + + let null_posting_list = get_null_posting_list(&index).await; + assert_eq!(null_posting_list, vec![2, 3]); + + // TODO: Support IS NULL queries + } + + fn empty_data() -> SendableRecordBatchStream { + Box::pin(RecordBatchStreamAdapter::new( + test_data_schema(), + stream::empty::>(), + )) + } + + #[test_log::test(tokio::test)] + async fn test_train_empty() { + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + + let (index, _tmpdir) = do_train(builder, empty_data()).await; + assert_eq!(index.tokens.len(), 0); + } + + #[test_log::test(tokio::test)] + async fn test_update_empty() { + let data = simple_data_with_nulls(); + + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + let (index, _tmpdir) = do_train(builder, empty_data()).await; + + let new_tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(new_tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + index.update(data, test_store.as_ref()).await.unwrap(); + + let index = NGramIndex::from_store(test_store).await.unwrap(); + assert_eq!(index.tokens.len(), 3); + } + + async fn row_ids_in_index(index: &NGramIndex) -> Vec { + let mut row_ids = HashSet::new(); + for row_offset in index.tokens.values() { + let list = index + .list_reader + .ngram_list(*row_offset, &NoOpMetricsCollector) + .await + .unwrap(); + row_ids.extend(list.bitmap.iter()); + } + row_ids.into_iter().sorted().collect() + } + + #[test_log::test(tokio::test)] + async fn test_ngram_index_remap() { + let data = simple_data_with_nulls(); + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + let (index, _tmpdir) = do_train(builder, data).await; + + let row_ids = row_ids_in_index(&index).await; + assert_eq!(row_ids, vec![0, 1, 2, 3, 4]); + + let new_tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(new_tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + let remapping = HashMap::from([(2, Some(100)), (3, None), (4, Some(101))]); + index.remap(&remapping, test_store.as_ref()).await.unwrap(); + + let index = NGramIndex::from_store(test_store).await.unwrap(); + let row_ids = row_ids_in_index(&index).await; + assert_eq!(row_ids, vec![0, 1, 100, 101]); + + let null_posting_list = get_null_posting_list(&index).await; + assert_eq!(null_posting_list, vec![100]); + } + + #[test_log::test(tokio::test)] + async fn test_ngram_index_merge() { + let data = simple_data_with_nulls(); + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap(); + let (index, _tmpdir) = do_train(builder, data).await; + + let data = StringArray::from_iter(&[Some("giraffe"), Some("cat"), None]); + let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64 + 100)); + let schema = Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, true), + Field::new("row_ids", DataType::UInt64, false), + ])); + let data = + RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap(); + let data = Box::pin(RecordBatchStreamAdapter::new( + schema, + stream::once(std::future::ready(Ok(data))), + )); + + let posting_list = get_posting_list_for_trigram(&index, "cat").await; + assert_eq!(posting_list, vec![0, 4]); + + let new_tmpdir = Arc::new(tempdir().unwrap()); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + Path::from_filesystem_path(new_tmpdir.path()).unwrap(), + FileMetadataCache::no_cache(), + )); + + index.update(data, test_store.as_ref()).await.unwrap(); + + let index = NGramIndex::from_store(test_store).await.unwrap(); + let row_ids = row_ids_in_index(&index).await; + assert_eq!(row_ids, vec![0, 1, 2, 3, 4, 100, 101, 102]); + + let posting_list = get_posting_list_for_trigram(&index, "cat").await; + assert_eq!(posting_list, vec![0, 4, 101]); + + let posting_list = get_posting_list_for_trigram(&index, "ffe").await; + assert_eq!(posting_list, vec![100]); + + let posting_list = get_null_posting_list(&index).await; + assert_eq!(posting_list, vec![2, 3, 102]); + } + + #[test_log::test(tokio::test)] + async fn test_ngram_index_with_spill() { + let (data, schema) = lance_datagen::gen() + .col( + "values", + lance_datagen::array::rand_utf8(ByteCount::from(50), false), + ) + .col("row_ids", lance_datagen::array::step::()) + .into_reader_stream(RowCount::from(128), BatchCount::from(32)); + + let data = Box::pin(RecordBatchStreamAdapter::new( + schema, + data.map_err(|arrow_err| DataFusionError::ArrowError(arrow_err, None)), + )); + + let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions { + tokens_per_spill: 100, + }) + .unwrap(); + + let (index, _tmpdir) = do_train(builder, data).await; + + assert_eq!(index.tokens.len(), 29012); + } +} diff --git a/rust/lance-index/src/traits.rs b/rust/lance-index/src/traits.rs index 5db6d188baa..b69f19c8493 100644 --- a/rust/lance-index/src/traits.rs +++ b/rust/lance-index/src/traits.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; use lance_core::Result; use crate::{optimize::OptimizeOptions, IndexParams, IndexType}; @@ -34,6 +35,26 @@ pub trait DatasetIndexExt { replace: bool, ) -> Result<()>; + /// Drop indices by name. + /// + /// Upon finish, a new dataset version is generated. + /// + /// Parameters: + /// + /// - `name`: the name of the index to drop. + async fn drop_index(&mut self, name: &str) -> Result<()>; + + /// Prewarm an index by name. + /// + /// This will load the index into memory and cache it. + /// + /// Generally, this should only be called when it is known the entire index will + /// fit into the index cache. + /// + /// This is a hint that is not enforced by all indices today. Some indices may choose + /// to ignore this hint. + async fn prewarm_index(&self, name: &str) -> Result<()>; + /// Read all indices of this Dataset version. /// /// The indices are lazy loaded and cached in memory within the [`Dataset`] instance. @@ -88,4 +109,11 @@ pub trait DatasetIndexExt { column: &str, index_id: Uuid, ) -> Result<()>; + + async fn read_index_partition( + &self, + index_name: &str, + partition_id: usize, + with_vector: bool, + ) -> Result; } diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 63ad4955f33..21465a57348 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -4,16 +4,22 @@ //! Vector Index //! +use std::any::Any; +use std::fmt::Debug; use std::{collections::HashMap, sync::Arc}; use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; use arrow_schema::Field; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use deepsize::DeepSizeOf; use ivf::storage::IvfModel; use lance_core::{Result, ROW_ID_FIELD}; +use lance_io::object_store::ObjectStore; use lance_io::traits::Reader; use lance_linalg::distance::DistanceType; use lazy_static::lazy_static; +use object_store::path::Path; use quantizer::{QuantizationType, Quantizer}; use v3::subindex::SubIndexType; @@ -33,6 +39,7 @@ pub mod utils; pub mod v3; use super::pb; +use crate::metrics::MetricsCollector; use crate::{prefilter::PreFilter, Index}; pub use residual::RESIDUAL_COLUMN; @@ -43,6 +50,7 @@ pub const INDEX_UUID_COLUMN: &str = "__index_uuid"; pub const PART_ID_COLUMN: &str = "__ivf_part_id"; pub const PQ_CODE_COLUMN: &str = "__pq_code"; pub const SQ_CODE_COLUMN: &str = "__sq_code"; +pub const LOSS_METADATA_KEY: &str = "_loss"; lazy_static! { pub static ref VECTOR_RESULT_SCHEMA: arrow_schema::SchemaRef = @@ -50,6 +58,8 @@ lazy_static! { Field::new(DIST_COL, arrow_schema::DataType::Float32, false), ROW_ID_FIELD.clone(), ])); + pub static ref PART_ID_FIELD: arrow_schema::Field = + arrow_schema::Field::new(PART_ID_COLUMN, arrow_schema::DataType::UInt32, true); } /// Query parameters for the vector indices @@ -64,6 +74,12 @@ pub struct Query { /// Top k results to return. pub k: usize, + /// The lower bound (inclusive) of the distance to be searched. + pub lower_bound: Option, + + /// The upper bound (exclusive) of the distance to be searched. + pub upper_bound: Option, + /// The number of probes to load and search. pub nprobes: usize, @@ -127,7 +143,12 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// /// *WARNINGS*: /// - Only supports `f32` now. Will add f64/f16 later. - async fn search(&self, query: &Query, pre_filter: Arc) -> Result; + async fn search( + &self, + query: &Query, + pre_filter: Arc, + metrics: &dyn MetricsCollector, + ) -> Result; fn find_partitions(&self, query: &Query) -> Result; @@ -136,6 +157,7 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { partition_id: usize, query: &Query, pre_filter: Arc, + metrics: &dyn MetricsCollector, ) -> Result; /// If the index is loadable by IVF, so it can be a sub-index that @@ -171,6 +193,21 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { self.load(reader, offset, length).await } + // for IVF only + async fn partition_reader( + &self, + _partition_id: usize, + _with_vector: bool, + _metrics: &dyn MetricsCollector, + ) -> Result { + unimplemented!("only for IVF") + } + + // for SubIndex only + async fn to_batch_stream(&self, with_vector: bool) -> Result; + + fn num_rows(&self) -> u64; + /// Return the IDs of rows in the index. fn row_ids(&self) -> Box + '_>; @@ -182,14 +219,33 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// /// If an old row id is not in the mapping then it should be /// left alone. - fn remap(&mut self, mapping: &HashMap>) -> Result<()>; + async fn remap(&mut self, mapping: &HashMap>) -> Result<()>; + + /// Remap the index according to mapping + /// + /// write the remapped index to the index_dir + /// this is available for only v3 index + async fn remap_to( + self: Arc, + _store: ObjectStore, + _mapping: &HashMap>, + _column: String, + _index_dir: Path, + ) -> Result<()> { + unimplemented!("only for v3 index") + } /// The metric type of this vector index. fn metric_type(&self) -> DistanceType; - fn ivf_model(&self) -> IvfModel; + fn ivf_model(&self) -> &IvfModel; fn quantizer(&self) -> Quantizer; /// the index type of this vector index. fn sub_index_type(&self) -> (SubIndexType, QuantizationType); } + +// it can be an IVF index or a partition of IVF index +pub trait VectorIndexCacheEntry: Debug + Send + Sync + DeepSizeOf { + fn as_any(&self) -> &dyn Any; +} diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 55127b510de..05495118433 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -10,7 +10,7 @@ use arrow_array::types::Float32Type; use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array}; use lance_core::{Error, Result}; use num_traits::Float; -use snafu::{location, Location}; +use snafu::location; #[derive(Clone, Default)] pub struct BinaryQuantization {} diff --git a/rust/lance-index/src/vector/flat.rs b/rust/lance-index/src/vector/flat.rs index d40c14f1e74..7149080a7db 100644 --- a/rust/lance-index/src/vector/flat.rs +++ b/rust/lance-index/src/vector/flat.rs @@ -4,18 +4,22 @@ //! Flat Vector Index. //! -use arrow_array::{make_array, Array, ArrayRef, RecordBatch}; +use std::sync::Arc; + +use arrow::{array::AsArray, buffer::NullBuffer}; +use arrow_array::{make_array, Array, ArrayRef, Float32Array, RecordBatch}; use arrow_schema::{DataType, Field as ArrowField}; use lance_arrow::*; use lance_core::{Error, Result, ROW_ID}; -use lance_linalg::distance::DistanceType; -use snafu::{location, Location}; +use lance_linalg::distance::{multivec_distance, DistanceType}; +use snafu::location; use tracing::instrument; use super::DIST_COL; pub mod index; pub mod storage; +pub mod transform; fn distance_field() -> ArrowField { ArrowField::new(DIST_COL, DataType::Float32, true) @@ -32,30 +36,44 @@ pub async fn compute_distance( // Ignore the distance calculated from inner vector index. batch = batch.drop_column(DIST_COL)?; } - let vectors = batch.column_by_name(column).ok_or_else(|| Error::Schema { - message: format!("column {} does not exist in dataset", column), - location: location!(), - })?; + let vectors = batch + .column_by_name(column) + .ok_or_else(|| Error::Schema { + message: format!("column {} does not exist in dataset", column), + location: location!(), + })? + .clone(); - // A selection vector may have been applied to _rowid column, so we need to - // push that onto vectors if possible. - let vectors = as_fixed_size_list_array(vectors.as_ref()).clone(); let validity_buffer = if let Some(rowids) = batch.column_by_name(ROW_ID) { - rowids.nulls().map(|nulls| nulls.buffer().clone()) + NullBuffer::union(rowids.nulls(), vectors.nulls()) } else { - None + vectors.nulls().cloned() }; - let vectors = vectors - .into_data() - .into_builder() - .null_bit_buffer(validity_buffer) - .build() - .map(make_array)?; - let vectors = as_fixed_size_list_array(vectors.as_ref()).clone(); - tokio::task::spawn_blocking(move || { - let distances = dt.arrow_batch_func()(key.as_ref(), &vectors)? as ArrayRef; + // A selection vector may have been applied to _rowid column, so we need to + // push that onto vectors if possible. + + let vectors = vectors + .into_data() + .into_builder() + .null_bit_buffer(validity_buffer.map(|b| b.buffer().clone())) + .build() + .map(make_array)?; + let distances = match vectors.data_type() { + DataType::FixedSizeList(_, _) => { + let vectors = vectors.as_fixed_size_list(); + dt.arrow_batch_func()(key.as_ref(), vectors)? as ArrayRef + } + DataType::List(_) => { + let vectors = vectors.as_list(); + let dists = multivec_distance(key.as_ref(), vectors, dt)?; + Arc::new(Float32Array::from(dists)) + } + _ => { + unreachable!() + } + }; batch .try_with_column(distance_field(), distances) diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index f50e995e4cb..581723423ad 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -4,6 +4,7 @@ //! Flat Vector Index. //! +use std::collections::HashMap; use std::sync::Arc; use arrow::array::AsArray; @@ -15,9 +16,10 @@ use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_file::reader::FileReader; use lance_linalg::distance::DistanceType; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use crate::{ + metrics::MetricsCollector, prefilter::PreFilter, vector::{ graph::{OrderedFloat, OrderedNode}, @@ -28,7 +30,7 @@ use crate::{ }, }; -use super::storage::{FlatStorage, FLAT_COLUMN}; +use super::storage::{FlatBinStorage, FlatFloatStorage, FLAT_COLUMN}; /// A Flat index is any index that stores no metadata, and /// during query, it simply scans over the storage and returns the top k results @@ -43,11 +45,17 @@ lazy_static::lazy_static! { } #[derive(Default)] -pub struct FlatQueryParams {} +pub struct FlatQueryParams { + lower_bound: Option, + upper_bound: Option, +} impl From<&Query> for FlatQueryParams { - fn from(_: &Query) -> Self { - Self {} + fn from(q: &Query) -> Self { + Self { + lower_bound: q.lower_bound, + upper_bound: q.upper_bound, + } } } @@ -71,50 +79,54 @@ impl IvfSubIndex for FlatIndex { &self, query: ArrayRef, k: usize, - _params: Self::QueryParams, + params: Self::QueryParams, storage: &impl VectorStore, prefilter: Arc, + metrics: &dyn MetricsCollector, ) -> Result { + let is_range_query = params.lower_bound.is_some() || params.upper_bound.is_some(); let dist_calc = storage.dist_calculator(query); - - let (row_ids, dists): (Vec, Vec) = match prefilter.is_empty() { - true => dist_calc - .distance_all() - .into_iter() - .zip(0..storage.len() as u32) - .map(|(dist, id)| OrderedNode { - id, - dist: OrderedFloat(dist), - }) - .sorted_unstable() - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip(), + metrics.record_comparisons(storage.len()); + + let res = match prefilter.is_empty() { + true => { + let iter = dist_calc + .distance_all(k) + .into_iter() + .zip(0..storage.len() as u32) + .map(|(dist, id)| OrderedNode::new(id, dist.into())); + if is_range_query { + let lower_bound = params.lower_bound.unwrap_or(f32::MIN); + let upper_bound = params.upper_bound.unwrap_or(f32::MAX); + iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound) + .sorted_unstable() + } else { + iter.sorted_unstable() + } + } false => { let row_id_mask = prefilter.mask(); - (0..storage.len()) + let iter = (0..storage.len()) .filter(|&id| row_id_mask.selected(storage.row_id(id as u32))) .map(|id| OrderedNode { id: id as u32, dist: OrderedFloat(dist_calc.distance(id as u32)), - }) - .sorted_unstable() - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip() + }); + if is_range_query { + let lower_bound = params.lower_bound.unwrap_or(f32::MIN); + let upper_bound = params.upper_bound.unwrap_or(f32::MAX); + iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound) + .sorted_unstable() + } else { + iter.sorted_unstable() + } } }; + let (row_ids, dists): (Vec<_>, Vec<_>) = res + .take(k) + .map(|r| (storage.row_id(r.id), r.dist.0)) + .unzip(); let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists)); Ok(RecordBatch::try_new( @@ -134,6 +146,10 @@ impl IvfSubIndex for FlatIndex { Ok(Self {}) } + fn remap(&self, _: &HashMap>) -> Result { + Ok(self.clone()) + } + fn to_batch(&self) -> Result { Ok(RecordBatch::new_empty(Schema::empty().into())) } @@ -166,13 +182,17 @@ impl FlatQuantizer { impl Quantization for FlatQuantizer { type BuildParams = (); type Metadata = FlatMetadata; - type Storage = FlatStorage; + type Storage = FlatFloatStorage; fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { let dim = data.as_fixed_size_list().value_length(); Ok(Self::new(dim as usize, distance_type)) } + fn retrain(&mut self, _: &dyn Array) -> Result<()> { + Ok(()) + } + fn code_dim(&self) -> usize { self.dim } @@ -228,3 +248,85 @@ impl TryFrom for FlatQuantizer { } } } + +#[derive(Debug, Clone, DeepSizeOf)] +pub struct FlatBinQuantizer { + dim: usize, + distance_type: DistanceType, +} + +impl FlatBinQuantizer { + pub fn new(dim: usize, distance_type: DistanceType) -> Self { + Self { dim, distance_type } + } +} + +impl Quantization for FlatBinQuantizer { + type BuildParams = (); + type Metadata = FlatMetadata; + type Storage = FlatBinStorage; + + fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { + let dim = data.as_fixed_size_list().value_length(); + Ok(Self::new(dim as usize, distance_type)) + } + + fn retrain(&mut self, _: &dyn Array) -> Result<()> { + Ok(()) + } + + fn code_dim(&self) -> usize { + self.dim + } + + fn column(&self) -> &'static str { + FLAT_COLUMN + } + + fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { + Ok(Quantizer::FlatBin(Self { + dim: metadata.dim, + distance_type, + })) + } + + fn metadata( + &self, + _: Option, + ) -> Result { + let metadata = FlatMetadata { dim: self.dim }; + Ok(serde_json::to_value(metadata)?) + } + + fn metadata_key() -> &'static str { + "flat" + } + + fn quantization_type() -> QuantizationType { + QuantizationType::Flat + } + + fn quantize(&self, vectors: &dyn Array) -> Result { + Ok(vectors.slice(0, vectors.len())) + } +} + +impl From for Quantizer { + fn from(value: FlatBinQuantizer) -> Self { + Self::FlatBin(value) + } +} + +impl TryFrom for FlatBinQuantizer { + type Error = Error; + + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::FlatBin(quantizer) => Ok(quantizer), + _ => Err(Error::invalid_input( + "quantizer is not FlatBinQuantizer", + location!(), + )), + } + } +} diff --git a/rust/lance-index/src/vector/flat/storage.rs b/rust/lance-index/src/vector/flat/storage.rs index b3bb11d02a0..d0ae227cd2f 100644 --- a/rust/lance-index/src/vector/flat/storage.rs +++ b/rust/lance-index/src/vector/flat/storage.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -//! In-memory graph representations. - use std::sync::Arc; use crate::vector::quantizer::QuantizerStorage; @@ -10,16 +8,19 @@ use crate::vector::storage::{DistCalculator, VectorStore}; use crate::vector::utils::do_prefetch; use arrow::array::AsArray; use arrow::compute::concat_batches; +use arrow::datatypes::UInt8Type; +use arrow_array::ArrowPrimitiveType; use arrow_array::{ types::{Float32Type, UInt64Type}, Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array, }; -use arrow_schema::{DataType, SchemaRef}; +use arrow_schema::SchemaRef; use deepsize::DeepSizeOf; use lance_core::{Error, Result, ROW_ID}; use lance_file::reader::FileReader; +use lance_linalg::distance::hamming::hamming; use lance_linalg::distance::DistanceType; -use snafu::{location, Location}; +use snafu::location; use super::index::FlatMetadata; @@ -27,7 +28,7 @@ pub const FLAT_COLUMN: &str = "flat"; /// All data are stored in memory #[derive(Debug, Clone)] -pub struct FlatStorage { +pub struct FlatFloatStorage { batch: RecordBatch, distance_type: DistanceType, @@ -36,14 +37,14 @@ pub struct FlatStorage { vectors: Arc, } -impl DeepSizeOf for FlatStorage { +impl DeepSizeOf for FlatFloatStorage { fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { self.batch.get_array_memory_size() } } #[async_trait::async_trait] -impl QuantizerStorage for FlatStorage { +impl QuantizerStorage for FlatFloatStorage { type Metadata = FlatMetadata; async fn load_partition( _: &FileReader, @@ -55,7 +56,7 @@ impl QuantizerStorage for FlatStorage { } } -impl FlatStorage { +impl FlatFloatStorage { // deprecated, use `try_from_batch` instead pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self { let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64)); @@ -80,8 +81,8 @@ impl FlatStorage { } } -impl VectorStore for FlatStorage { - type DistanceCalculator<'a> = FlatDistanceCal<'a>; +impl VectorStore for FlatFloatStorage { + type DistanceCalculator<'a> = FlatDistanceCal<'a, Float32Type>; fn try_from_batch(batch: RecordBatch, distance_type: DistanceType) -> Result { let row_ids = Arc::new( @@ -149,41 +150,144 @@ impl VectorStore for FlatStorage { } fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { - FlatDistanceCal::new(self.vectors.as_ref(), query, self.distance_type) + Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type) } fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { - FlatDistanceCal::new( + Self::DistanceCalculator::new( self.vectors.as_ref(), self.vectors.value(id as usize), self.distance_type, ) } +} - /// Distance between two vectors. - fn distance_between(&self, a: u32, b: u32) -> f32 { - match self.vectors.value_type() { - DataType::Float32 => { - let vector1 = self.vectors.value(a as usize); - let vector2 = self.vectors.value(b as usize); - self.distance_type.func()( - vector1.as_primitive::().values(), - vector2.as_primitive::().values(), - ) - } - _ => unimplemented!(), - } +/// All data are stored in memory +#[derive(Debug, Clone)] +pub struct FlatBinStorage { + batch: RecordBatch, + distance_type: DistanceType, + + // helper fields + pub(super) row_ids: Arc, + vectors: Arc, +} + +impl DeepSizeOf for FlatBinStorage { + fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { + self.batch.get_array_memory_size() + } +} + +#[async_trait::async_trait] +impl QuantizerStorage for FlatBinStorage { + type Metadata = FlatMetadata; + async fn load_partition( + _: &FileReader, + _: std::ops::Range, + _: DistanceType, + _: &Self::Metadata, + ) -> Result { + unimplemented!("Flat will be used in new index builder which doesn't require this") + } +} + +impl FlatBinStorage { + pub fn vector(&self, id: u32) -> ArrayRef { + self.vectors.value(id as usize) + } +} + +impl VectorStore for FlatBinStorage { + type DistanceCalculator<'a> = FlatDistanceCal<'a, UInt8Type>; + + fn try_from_batch(batch: RecordBatch, distance_type: DistanceType) -> Result { + let row_ids = Arc::new( + batch + .column_by_name(ROW_ID) + .ok_or(Error::Schema { + message: format!("column {} not found", ROW_ID), + location: location!(), + })? + .as_primitive::() + .clone(), + ); + let vectors = Arc::new( + batch + .column_by_name(FLAT_COLUMN) + .ok_or(Error::Schema { + message: "column flat not found".to_string(), + location: location!(), + })? + .as_fixed_size_list() + .clone(), + ); + Ok(Self { + batch, + distance_type, + row_ids, + vectors, + }) + } + + fn to_batches(&self) -> Result> { + Ok([self.batch.clone()].into_iter()) + } + + fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result { + // TODO: use chunked storage + let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?; + let mut storage = self.clone(); + storage.batch = new_batch; + Ok(storage) + } + + fn schema(&self) -> &SchemaRef { + self.batch.schema_ref() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn len(&self) -> usize { + self.vectors.len() + } + + fn distance_type(&self) -> DistanceType { + self.distance_type + } + + fn row_id(&self, id: u32) -> u64 { + self.row_ids.values()[id as usize] + } + + fn row_ids(&self) -> impl Iterator { + self.row_ids.values().iter() + } + + fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { + Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type) + } + + fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { + Self::DistanceCalculator::new( + self.vectors.as_ref(), + self.vectors.value(id as usize), + self.distance_type, + ) } } -pub struct FlatDistanceCal<'a> { - vectors: &'a [f32], - query: Vec, +pub struct FlatDistanceCal<'a, T: ArrowPrimitiveType> { + vectors: &'a [T::Native], + query: Vec, dimension: usize, - distance_fn: fn(&[f32], &[f32]) -> f32, + #[allow(clippy::type_complexity)] + distance_fn: fn(&[T::Native], &[T::Native]) -> f32, } -impl<'a> FlatDistanceCal<'a> { +impl<'a> FlatDistanceCal<'a, Float32Type> { fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self { // Gained significant performance improvement by using strong typed primitive slice. // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct. @@ -196,21 +300,38 @@ impl<'a> FlatDistanceCal<'a> { distance_fn: distance_type.func(), } } +} + +impl<'a> FlatDistanceCal<'a, UInt8Type> { + fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, _distance_type: DistanceType) -> Self { + // Gained significant performance improvement by using strong typed primitive slice. + // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct. + let flat_array = vectors.values().as_primitive::(); + let dimension = vectors.value_length() as usize; + Self { + vectors: flat_array.values(), + query: query.as_primitive::().values().to_vec(), + dimension, + distance_fn: hamming, + } + } +} +impl FlatDistanceCal<'_, T> { #[inline] - fn get_vector(&self, id: u32) -> &[f32] { + fn get_vector(&self, id: u32) -> &[T::Native] { &self.vectors[self.dimension * id as usize..self.dimension * (id + 1) as usize] } } -impl DistCalculator for FlatDistanceCal<'_> { +impl DistCalculator for FlatDistanceCal<'_, T> { #[inline] fn distance(&self, id: u32) -> f32 { let vector = self.get_vector(id); (self.distance_fn)(&self.query, vector) } - fn distance_all(&self) -> Vec { + fn distance_all(&self, _k_hint: usize) -> Vec { let query = &self.query; self.vectors .chunks_exact(self.dimension) diff --git a/rust/lance-index/src/vector/flat/transform.rs b/rust/lance-index/src/vector/flat/transform.rs new file mode 100644 index 00000000000..7afe5d4d9e3 --- /dev/null +++ b/rust/lance-index/src/vector/flat/transform.rs @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use arrow_array::RecordBatch; +use arrow_schema::Field; +use lance_arrow::RecordBatchExt; +use lance_core::Error; +use snafu::location; +use tracing::instrument; + +use crate::vector::transform::Transformer; + +use super::storage::FLAT_COLUMN; + +#[derive(Debug)] +pub struct FlatTransformer { + input_column: String, +} + +impl FlatTransformer { + pub fn new(input_column: impl AsRef) -> Self { + Self { + input_column: input_column.as_ref().to_owned(), + } + } +} + +impl Transformer for FlatTransformer { + #[instrument(name = "FlatTransformer::transform", level = "debug", skip_all)] + fn transform(&self, batch: &RecordBatch) -> crate::Result { + let input_arr = batch + .column_by_name(&self.input_column) + .ok_or(Error::Index { + message: format!( + "FlatTransform: column {} not found in batch", + self.input_column + ), + location: location!(), + })?; + let field = Field::new( + FLAT_COLUMN, + input_arr.data_type().clone(), + input_arr.is_nullable(), + ); + // rename the column to FLAT_COLUMN + let batch = batch + .drop_column(&self.input_column)? + .try_with_column(field, input_arr.clone())?; + Ok(batch) + } +} diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index 9e79dc231ac..e31ab4d3449 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -152,7 +152,7 @@ pub struct Visited<'a> { recently_visited: Vec, } -impl<'a> Visited<'a> { +impl Visited<'_> { pub fn insert(&mut self, node_id: u32) { let node_id_usize = node_id as usize; if !self.visited[node_id_usize] { @@ -171,7 +171,7 @@ impl<'a> Visited<'a> { } } -impl<'a> Drop for Visited<'a> { +impl Drop for Visited<'_> { fn drop(&mut self) { for node_id in self.recently_visited.iter() { self.visited.set(*node_id as usize, false); diff --git a/rust/lance-index/src/vector/hnsw.rs b/rust/lance-index/src/vector/hnsw.rs index f301762d98e..e4a5ed662d6 100644 --- a/rust/lance-index/src/vector/hnsw.rs +++ b/rust/lance-index/src/vector/hnsw.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use self::builder::HnswBuildParams; use super::graph::{OrderedFloat, OrderedNode}; -use super::storage::{DistCalculator, VectorStore}; +use super::storage::VectorStore; pub mod builder; pub mod index; @@ -73,12 +73,11 @@ fn select_neighbors_heuristic( if results.len() >= k { break; } - let dist_cal = storage.dist_calculator_from_id(u.id); if results.is_empty() || results .iter() - .all(|v| u.dist < OrderedFloat(dist_cal.distance(v.id))) + .all(|v| u.dist < OrderedFloat(storage.dist_between(u.id, v.id))) { results.push(u.clone()); } diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index fc5e43a1b86..f8d7b378c2f 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -14,7 +14,7 @@ use itertools::Itertools; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_linalg::distance::DistanceType; use rayon::prelude::*; -use snafu::{location, Location}; +use snafu::location; use std::cmp::min; use std::collections::{BinaryHeap, HashMap}; use std::fmt::Debug; @@ -30,8 +30,9 @@ use serde::{Deserialize, Serialize}; use super::super::graph::beam_search; use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD}; +use crate::metrics::MetricsCollector; use crate::prefilter::PreFilter; -use crate::vector::flat::storage::FlatStorage; +use crate::vector::flat::storage::FlatFloatStorage; use crate::vector::graph::builder::GraphBuilderNode; use crate::vector::graph::{greedy_search, Visited}; use crate::vector::graph::{ @@ -100,7 +101,7 @@ impl HnswBuildParams { /// - `data`: A FixedSizeList to build the HNSW. /// - `distance_type`: The distance type to use. pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result { - let vec_store = Arc::new(FlatStorage::new( + let vec_store = Arc::new(FlatFloatStorage::new( data.as_fixed_size_list().clone(), distance_type, )); @@ -393,9 +394,10 @@ impl HnswBuilder { ) { let nodes = &self.nodes; let target_level = nodes[node as usize].read().unwrap().level_neighbors.len() as u16 - 1; + let dist_calc = storage.dist_calculator_from_id(node); let mut ep = OrderedNode::new( self.entry_point, - storage.distance_between(node, self.entry_point).into(), + dist_calc.distance(self.entry_point).into(), ); // @@ -406,7 +408,6 @@ impl HnswBuilder { // ep = Select-Neighbors(W, 1) // } // ``` - let dist_calc = storage.dist_calculator_from_id(node); for level in (target_level + 1..self.params.max_level).rev() { let cur_level = HnswLevelView::new(level, nodes); ep = greedy_search(&cur_level, ep, &dist_calc, self.params.prefetch_distance); @@ -507,7 +508,7 @@ impl<'a> HnswLevelView<'a> { } } -impl<'a> Graph for HnswLevelView<'a> { +impl Graph for HnswLevelView<'_> { fn len(&self) -> usize { self.nodes.len() } @@ -528,7 +529,7 @@ impl<'a> HnswBottomView<'a> { } } -impl<'a> Graph for HnswBottomView<'a> { +impl Graph for HnswBottomView<'_> { fn len(&self) -> usize { self.nodes.len() } @@ -544,7 +545,7 @@ pub struct HnswQueryParams { pub ef: usize, } -impl<'a> From<&'a Query> for HnswQueryParams { +impl From<&Query> for HnswQueryParams { fn from(query: &Query) -> Self { let k = query.k * query.refine_factor.unwrap_or(1) as usize; Self { @@ -655,7 +656,7 @@ impl IvfSubIndex for HNSW { .into() } - #[instrument(level = "debug", skip(self, query, storage, prefilter))] + #[instrument(level = "debug", skip(self, query, storage, prefilter, _metrics))] fn search( &self, query: ArrayRef, @@ -663,6 +664,7 @@ impl IvfSubIndex for HNSW { params: Self::QueryParams, storage: &impl VectorStore, prefilter: Arc, + _metrics: &dyn MetricsCollector, ) -> Result { if params.ef < k { return Err(Error::Index { @@ -706,15 +708,16 @@ impl IvfSubIndex for HNSW { // if the queue is full, we just don't push it back, so ignore the error here let _ = self.inner.visited_generator_queue.push(prefilter_generator); - let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| storage.row_id(x.id))); - let distances = Arc::new(Float32Array::from_iter_values( - results.iter().map(|x| x.dist.0), - )); + // need to unique by row ids in case of searching multivector + let (row_ids, dists): (Vec<_>, Vec<_>) = results + .into_iter() + .map(|r| (storage.row_id(r.id), r.dist.0)) + .unique_by(|r| r.0) + .unzip(); + let row_ids = Arc::new(UInt64Array::from(row_ids)); + let distances = Arc::new(Float32Array::from(dists)); - Ok(RecordBatch::try_new( - schema, - vec![distances, Arc::new(row_ids)], - )?) + Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?) } /// Given a vector storage, containing all the data for the IVF partition, build the sub index. @@ -749,6 +752,10 @@ impl IvfSubIndex for HNSW { Ok(hnsw) } + fn remap(&self, _mapping: &HashMap>) -> Result { + unimplemented!("HNSW remap is not supported yet"); + } + /// Encode the sub index into a record batch fn to_batch(&self) -> Result { let mut vector_id_builder = UInt32Builder::with_capacity(self.len()); @@ -819,7 +826,7 @@ mod tests { use crate::scalar::IndexWriter; use crate::vector::v3::subindex::IvfSubIndex; use crate::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, graph::{DISTS_FIELD, NEIGHBORS_FIELD}, hnsw::{builder::HnswBuildParams, HNSW, VECTOR_ID_FIELD}, }; @@ -831,7 +838,7 @@ mod tests { const NUM_EDGES: usize = 20; let data = generate_random_array(TOTAL * DIM); let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap(); - let store = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let builder = HNSW::index_vectors( store.as_ref(), HnswBuildParams::default() diff --git a/rust/lance-index/src/vector/hnsw/index.rs b/rust/lance-index/src/vector/hnsw/index.rs index 783372b9f16..2722ef12076 100644 --- a/rust/lance-index/src/vector/hnsw/index.rs +++ b/rust/lance-index/src/vector/hnsw/index.rs @@ -10,7 +10,11 @@ use std::{ use arrow_array::{RecordBatch, UInt32Array}; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use deepsize::DeepSizeOf; +use lance_arrow::RecordBatchExt; +use lance_core::ROW_ID; use lance_core::{datatypes::Schema, Error, Result}; use lance_file::reader::FileReader; use lance_io::traits::Reader; @@ -18,13 +22,13 @@ use lance_linalg::distance::DistanceType; use lance_table::format::SelfDescribingFileReader; use roaring::RoaringBitmap; use serde_json::json; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; -use crate::prefilter::PreFilter; use crate::vector::ivf::storage::IvfModel; use crate::vector::quantizer::QuantizationType; use crate::vector::v3::subindex::{IvfSubIndex, SubIndexType}; +use crate::{metrics::MetricsCollector, prefilter::PreFilter}; use crate::{ vector::{ graph::NEIGHBORS_FIELD, @@ -131,6 +135,11 @@ impl Index for HNSWIndex { })) } + async fn prewarm(&self) -> Result<()> { + // TODO: HNSW can (and should) support pre-warming + Ok(()) + } + /// Get the type of the index fn index_type(&self) -> IndexType { IndexType::Vector @@ -148,7 +157,12 @@ impl Index for HNSWIndex { #[async_trait] impl VectorIndex for HNSWIndex { #[instrument(level = "debug", skip_all, name = "HNSWIndex::search")] - async fn search(&self, query: &Query, pre_filter: Arc) -> Result { + async fn search( + &self, + query: &Query, + pre_filter: Arc, + metrics: &dyn MetricsCollector, + ) -> Result { let hnsw = self.hnsw.as_ref().ok_or(Error::Index { message: "HNSW index not loaded".to_string(), location: location!(), @@ -168,6 +182,7 @@ impl VectorIndex for HNSWIndex { query.into(), storage.as_ref(), pre_filter, + metrics, ) } @@ -180,6 +195,7 @@ impl VectorIndex for HNSWIndex { _: usize, _: &Query, _: Arc, + _: &dyn MetricsCollector, ) -> Result { unimplemented!("only for IVF") } @@ -263,18 +279,50 @@ impl VectorIndex for HNSWIndex { })) } + async fn to_batch_stream(&self, with_vector: bool) -> Result { + let store = self.storage.as_ref().ok_or(Error::Index { + message: "vector storage not loaded".to_string(), + location: location!(), + })?; + + let schema = if with_vector { + store.schema().clone() + } else { + let schema = store.schema(); + let row_id_idx = schema.index_of(ROW_ID)?; + Arc::new(schema.project(&[row_id_idx])?) + }; + + let batches = store + .to_batches()? + .map(|b| { + let batch = b.project_by_schema(&schema)?; + Ok(batch) + }) + .collect::>(); + let stream = futures::stream::iter(batches); + let stream = RecordBatchStreamAdapter::new(schema, stream); + Ok(Box::pin(stream)) + } + + fn num_rows(&self) -> u64 { + self.hnsw + .as_ref() + .map_or(0, |hnsw| hnsw.num_nodes(0) as u64) + } + fn row_ids(&self) -> Box + '_> { Box::new(self.storage.as_ref().unwrap().row_ids()) } - fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { + async fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { Err(Error::Index { message: "Remapping HNSW in this way not supported".to_string(), location: location!(), }) } - fn ivf_model(&self) -> IvfModel { + fn ivf_model(&self) -> &IvfModel { unimplemented!("only for IVF") } diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 55bfc641732..93fd7c0be72 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -17,12 +17,17 @@ use lance_linalg::{ use tracing::instrument; use crate::vector::ivf::transform::PartitionTransformer; -use crate::vector::{pq::ProductQuantizer, residual::ResidualTransform, transform::Transformer}; +use crate::vector::{pq::ProductQuantizer, transform::Transformer}; +use super::flat::transform::FlatTransformer; use super::pq::transform::PQTransformer; use super::quantizer::Quantization; +use super::residual::ResidualTransform; +use super::sq::transform::SQTransformer; +use super::sq::ScalarQuantizer; +use super::transform::KeepFiniteVectors; use super::{quantizer::Quantizer, residual::compute_residual}; -use super::{PART_ID_COLUMN, PQ_CODE_COLUMN}; +use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, SQ_CODE_COLUMN}; pub mod builder; pub mod shuffler; @@ -54,7 +59,7 @@ pub fn new_ivf_transformer_with_quantizer( range: Option>, ) -> Result { match quantizer { - Quantizer::Flat(_) => Ok(IvfTransformer::new_flat( + Quantizer::Flat(_) | Quantizer::FlatBin(_) => Ok(IvfTransformer::new_flat( centroids, metric_type, vector_column, @@ -66,12 +71,12 @@ pub fn new_ivf_transformer_with_quantizer( vector_column, pq, range, - false, )), - Quantizer::Scalar(_) => Ok(IvfTransformer::with_sq( + Quantizer::Scalar(sq) => Ok(IvfTransformer::with_sq( centroids, metric_type, vector_column, + sq, range, )), } @@ -113,7 +118,8 @@ impl IvfTransformer { vector_column: &str, range: Option>, ) -> Self { - let mut transforms: Vec> = vec![]; + let mut transforms: Vec> = + vec![Arc::new(super::transform::Flatten::new(vector_column))]; let dt = if distance_type == DistanceType::Cosine { transforms.push(Arc::new(super::transform::NormalizeTransformer::new( @@ -123,6 +129,7 @@ impl IvfTransformer { } else { distance_type }; + transforms.push(Arc::new(KeepFiniteVectors::new(vector_column))); let ivf_transform = Arc::new(PartitionTransformer::new( centroids.clone(), @@ -138,11 +145,9 @@ impl IvfTransformer { ))); } - Self { - centroids, - distance_type, - transforms, - } + transforms.push(Arc::new(FlatTransformer::new(vector_column))); + + Self::new(centroids, distance_type, transforms) } /// Create a IVF_PQ struct. @@ -152,11 +157,11 @@ impl IvfTransformer { vector_column: &str, pq: ProductQuantizer, range: Option>, - with_pq_code: bool, // Pass true for v1 index format, otherwise false. ) -> Self { - let mut transforms: Vec> = vec![]; + let mut transforms: Vec> = + vec![Arc::new(super::transform::Flatten::new(vector_column))]; - let mt = if distance_type == MetricType::Cosine { + let distance_type = if distance_type == MetricType::Cosine { transforms.push(Arc::new(super::transform::NormalizeTransformer::new( vector_column, ))); @@ -164,10 +169,11 @@ impl IvfTransformer { } else { distance_type }; + transforms.push(Arc::new(KeepFiniteVectors::new(vector_column))); let partition_transform = Arc::new(PartitionTransformer::new( centroids.clone(), - mt, + distance_type, vector_column, )); transforms.push(partition_transform); @@ -186,29 +192,26 @@ impl IvfTransformer { vector_column, ))); } - if with_pq_code { - transforms.push(Arc::new(PQTransformer::new( - pq, - vector_column, - PQ_CODE_COLUMN, - ))); - } - Self { - centroids, - distance_type, - transforms, - } + transforms.push(Arc::new(PQTransformer::new( + pq, + vector_column, + PQ_CODE_COLUMN, + ))); + + Self::new(centroids, distance_type, transforms) } fn with_sq( centroids: FixedSizeListArray, metric_type: MetricType, vector_column: &str, + sq: ScalarQuantizer, range: Option>, ) -> Self { - let mut transforms: Vec> = vec![]; + let mut transforms: Vec> = + vec![Arc::new(super::transform::Flatten::new(vector_column))]; - let mt = if metric_type == MetricType::Cosine { + let distance_type = if metric_type == MetricType::Cosine { transforms.push(Arc::new(super::transform::NormalizeTransformer::new( vector_column, ))); @@ -216,10 +219,11 @@ impl IvfTransformer { } else { metric_type }; + transforms.push(Arc::new(KeepFiniteVectors::new(vector_column))); let partition_transformer = Arc::new(PartitionTransformer::new( centroids.clone(), - mt, + distance_type, vector_column, )); transforms.push(partition_transformer); @@ -231,11 +235,13 @@ impl IvfTransformer { ))); } - Self { - centroids, - distance_type: metric_type, - transforms, - } + transforms.push(Arc::new(SQTransformer::new( + sq, + vector_column.to_owned(), + SQ_CODE_COLUMN.to_owned(), + ))); + + Self::new(centroids, distance_type, transforms) } #[inline] @@ -245,7 +251,10 @@ impl IvfTransformer { #[inline] pub fn compute_partitions(&self, data: &FixedSizeListArray) -> Result { - Ok(compute_partitions_arrow_array(&self.centroids, data, self.distance_type)?.into()) + Ok( + compute_partitions_arrow_array(&self.centroids, data, self.distance_type) + .map(|(part_ids, _)| part_ids.into())?, + ) } pub fn find_partitions(&self, query: &dyn Array, nprobes: usize) -> Result { diff --git a/rust/lance-index/src/vector/ivf/builder.rs b/rust/lance-index/src/vector/ivf/builder.rs index 475435460e9..bbc46beaebb 100644 --- a/rust/lance-index/src/vector/ivf/builder.rs +++ b/rust/lance-index/src/vector/ivf/builder.rs @@ -10,7 +10,7 @@ use arrow_array::cast::AsArray; use arrow_array::{Array, FixedSizeListArray, UInt32Array, UInt64Array}; use futures::TryStreamExt; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use lance_core::error::{Error, Result}; use lance_io::stream::RecordBatchStream; @@ -28,6 +28,10 @@ pub struct IvfBuildParams { /// Use provided IVF centroids. pub centroids: Option>, + /// Retrain centroids. + /// If true, the centroids will be retrained based on provided `centroids`. + pub retrain: bool, + pub sample_rate: usize, /// Precomputed partitions file (row_id -> partition_id) @@ -45,9 +49,6 @@ pub struct IvfBuildParams { pub shuffle_partition_concurrency: usize, - /// Use residual vectors to build sub-vector. - pub use_residual: bool, - /// Storage options used to load precomputed partitions. pub storage_options: Option>, } @@ -58,12 +59,12 @@ impl Default for IvfBuildParams { num_partitions: 32, max_iters: 50, centroids: None, + retrain: false, sample_rate: 256, // See faiss precomputed_partitions_file: None, precomputed_shuffle_buffers: None, shuffle_partition_batches: 1024 * 10, shuffle_partition_concurrency: 2, - use_residual: true, storage_options: None, } } diff --git a/rust/lance-index/src/vector/ivf/shuffler.rs b/rust/lance-index/src/vector/ivf/shuffler.rs index 2f6d97ed7fc..3ac6a08203a 100644 --- a/rust/lance-index/src/vector/ivf/shuffler.rs +++ b/rust/lance-index/src/vector/ivf/shuffler.rs @@ -43,11 +43,11 @@ use lance_table::format::SelfDescribingFileReader; use lance_table::io::manifest::ManifestDescribing; use log::info; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tempfile::TempDir; use crate::vector::ivf::IvfTransformer; -use crate::vector::transform::{KeepFiniteVectors, Transformer}; +use crate::vector::transform::Transformer; use crate::vector::PART_ID_COLUMN; const UNSORTED_BUFFER: &str = "unsorted.lance"; @@ -243,7 +243,6 @@ impl PartitionListBuilder { #[allow(clippy::too_many_arguments)] pub async fn shuffle_dataset( data: impl RecordBatchStream + Unpin + 'static, - column: &str, ivf: Arc, precomputed_partitions: Option>, num_partitions: u32, @@ -268,7 +267,6 @@ pub async fn shuffle_dataset( ); let mut shuffler = IvfShuffler::try_new(num_partitions, None, true, None)?; - let column = column.to_owned(); let precomputed_partitions = precomputed_partitions.map(Arc::new); let stream = data .zip(repeat_with(move || ivf.clone())) @@ -279,7 +277,6 @@ pub async fn shuffle_dataset( .as_ref() .cloned() .unwrap_or(Arc::new(HashMap::new())); - let nan_filter = KeepFiniteVectors::new(&column); tokio::task::spawn(async move { let mut batch = b?; @@ -319,10 +316,6 @@ pub async fn shuffle_dataset( batch = batch.take(&indices)?; } } - - // Filter out NaNs/Infs - batch = nan_filter.transform(&batch)?; - ivf.transform(&batch) }) }) @@ -739,11 +732,7 @@ impl IvfShuffler { continue; } - let local_start = if start < cur_start { - 0 - } else { - start - cur_start - }; + let local_start = start.saturating_sub(cur_start); let local_end = std::cmp::min(end - cur_start, *partition_size); input.push(ShuffleInput { diff --git a/rust/lance-index/src/vector/ivf/storage.rs b/rust/lance-index/src/vector/ivf/storage.rs index 6a9d8d6d798..5f626943f8f 100644 --- a/rust/lance-index/src/vector/ivf/storage.rs +++ b/rust/lance-index/src/vector/ivf/storage.rs @@ -14,7 +14,7 @@ use lance_linalg::distance::DistanceType; use lance_table::io::manifest::ManifestDescribing; use log::debug; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use crate::pb::Ivf as PbIvf; @@ -34,6 +34,9 @@ pub struct IvfModel { /// Number of vectors in each partition. pub lengths: Vec, + + /// Kmeans loss + pub loss: Option, } impl DeepSizeOf for IvfModel { @@ -53,14 +56,16 @@ impl IvfModel { centroids: None, offsets: vec![], lengths: vec![], + loss: None, } } - pub fn new(centroids: FixedSizeListArray) -> Self { + pub fn new(centroids: FixedSizeListArray, loss: Option) -> Self { Self { centroids: Some(centroids), offsets: vec![], lengths: vec![], + loss, } } @@ -88,6 +93,14 @@ impl IvfModel { self.lengths[part] as usize } + pub fn num_rows(&self) -> u64 { + self.lengths.iter().map(|x| *x as u64).sum() + } + + pub fn loss(&self) -> Option { + self.loss + } + /// Use the query vector to find `nprobes` closest partitions. pub fn find_partitions( &self, @@ -167,6 +180,7 @@ impl TryFrom<&IvfModel> for PbIvf { lengths, offsets: ivf.offsets.iter().map(|x| *x as u64).collect(), centroids_tensor: ivf.centroids.as_ref().map(|c| c.try_into()).transpose()?, + loss: ivf.loss, }) } } @@ -215,6 +229,7 @@ impl TryFrom for IvfModel { centroids, offsets, lengths: proto.lengths, + loss: proto.loss, }) } } @@ -296,6 +311,7 @@ mod tests { lengths: vec![2, 2], offsets: vec![0, 2], centroids_tensor: None, + loss: None, }; let ivf = IvfModel::try_from(pb_ivf).unwrap(); diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index d7c1caec79b..d2f877cee15 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -10,7 +10,8 @@ use arrow_array::{ cast::AsArray, types::UInt32Type, Array, FixedSizeListArray, RecordBatch, UInt32Array, }; use arrow_schema::Field; -use snafu::{location, Location}; +use lance_table::utils::LanceIteratorExtension; +use snafu::location; use tracing::instrument; use lance_arrow::RecordBatchExt; @@ -19,12 +20,15 @@ use lance_linalg::distance::DistanceType; use lance_linalg::kmeans::compute_partitions_arrow_array; use crate::vector::transform::Transformer; +use crate::vector::LOSS_METADATA_KEY; use super::PART_ID_COLUMN; -/// Ivf Transformer +/// PartitionTransformer /// -/// It transforms a Vector column, specified by the input data, into a column of partition IDs. +/// It computes the partition ID for each row from the input batch, +/// and adds the partition ID as a new column to the batch, +/// and adds the loss as a metadata to the batch. /// /// If the partition ID ("__ivf_part_id") column is already present in the Batch, /// this transform is a Noop. @@ -57,6 +61,7 @@ impl PartitionTransformer { pub(super) fn compute_partitions(&self, data: &FixedSizeListArray) -> UInt32Array { compute_partitions_arrow_array(&self.centroids, data, self.distance_type) .expect("failed to compute partitions") + .0 .into() } } @@ -67,30 +72,36 @@ impl Transformer for PartitionTransformer { // If the partition ID column is already present, we don't need to compute it again. return Ok(batch.clone()); } + let arr = batch .column_by_name(&self.input_column) .ok_or_else(|| lance_core::Error::Index { message: format!( - "IvfTransformer: column {} not found in the RecordBatch", + "PartitionTransformer: column {} not found in the RecordBatch", self.input_column ), location: location!(), })?; + let fsl = arr .as_fixed_size_list_opt() .ok_or_else(|| lance_core::Error::Index { message: format!( - "IvfTransformer: column {} is not a FixedSizeListArray: {}", + "PartitionTransformer: column {} is not a FixedSizeListArray: {}", self.input_column, arr.data_type(), ), location: location!(), })?; - let part_ids = self.compute_partitions(fsl); + let (part_ids, loss) = + compute_partitions_arrow_array(&self.centroids, fsl, self.distance_type)?; + let part_ids = UInt32Array::from(part_ids); let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true); - Ok(batch.try_with_column(field, Arc::new(part_ids))?) + Ok(batch + .try_with_column(field, Arc::new(part_ids))? + .add_metadata(LOSS_METADATA_KEY.to_owned(), loss.to_string())?) } } @@ -121,6 +132,8 @@ impl PartitionFilter { None } }) + // in most cases, no partition will be filtered out. + .exact_size(partition_ids.len()) .collect() } } diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index c971a6942b5..57e0ed46b9b 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -1,11 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use arrow_array::{types::ArrowPrimitiveType, ArrayRef, FixedSizeListArray, PrimitiveArray}; +use std::sync::Arc; + +use arrow_array::{types::ArrowPrimitiveType, FixedSizeListArray, PrimitiveArray}; use lance_arrow::FixedSizeListArrayExt; use log::info; use rand::{seq::IteratorRandom, Rng}; -use snafu::{location, Location}; +use snafu::location; use lance_core::{Error, Result}; use lance_linalg::{ @@ -14,8 +16,21 @@ use lance_linalg::{ }; /// Train KMeans model and returns the centroids of each cluster. +/// +/// Parameters +/// ---------- +/// - *centroids*: initial centroids, use the random initialization if None +/// - *array*: a flatten floating number array of vectors +/// - *dimension*: dimension of the vector +/// - *k*: number of clusters +/// - *max_iterations*: maximum number of iterations +/// - *redos*: number of times to redo the k-means clustering +/// - *rng*: random number generator +/// - *distance_type*: distance type to compute pair-wise vector distance +/// - *sample_rate*: sample rate to select the data for training #[allow(clippy::too_many_arguments)] pub fn train_kmeans( + centroids: Option>, array: &[T::Native], dimension: usize, k: usize, @@ -24,7 +39,7 @@ pub fn train_kmeans( mut rng: impl Rng, distance_type: DistanceType, sample_rate: usize, -) -> Result +) -> Result where T::Native: Dot + L2 + Normalize, PrimitiveArray: From>, @@ -57,13 +72,8 @@ where PrimitiveArray::::from(array.to_vec()) }; - let params = KMeansParams { - max_iters: max_iterations, - distance_type, - redos, - ..Default::default() - }; + let params = KMeansParams::new(centroids, max_iterations, redos, distance_type); let data = FixedSizeListArray::try_new_from_values(data, dimension as i32)?; let model = KMeans::new_with_params(&data, k, ¶ms)?; - Ok(model.centroids.clone()) + Ok(model) } diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 467599157b3..49bc2b54c3b 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -11,14 +11,15 @@ use arrow_array::{cast::AsArray, Array, FixedSizeListArray, UInt8Array}; use arrow_array::{ArrayRef, Float32Array, PrimitiveArray}; use arrow_schema::DataType; use deepsize::DeepSizeOf; -use distance::{build_distance_table_dot, compute_dot_distance}; +use distance::build_distance_table_dot; use lance_arrow::*; use lance_core::{Error, Result}; use lance_linalg::distance::{DistanceType, Dot, L2}; use lance_linalg::kmeans::compute_partition; +use lance_table::utils::LanceIteratorExtension; use num_traits::Float; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADATA_KEY}; use tracing::instrument; @@ -28,7 +29,7 @@ pub mod storage; pub mod transform; pub(crate) mod utils; -use self::distance::{build_distance_table_l2, compute_l2_distance}; +use self::distance::{build_distance_table_l2, compute_pq_distance}; pub use self::utils::num_centroids; use super::quantizer::{ Quantization, QuantizationMetadata, QuantizationType, Quantizer, QuantizerBuildParams, @@ -143,6 +144,7 @@ impl ProductQuantizer { let flatten_data = fsl.values().as_primitive::(); let sub_dim = dim / num_sub_vectors; + let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize); let values = flatten_data .values() .chunks_exact(dim) @@ -169,6 +171,7 @@ impl ProductQuantizer { sub_vec_code } }) + .exact_size(total_code_length) .collect::>(); let num_sub_vectors_in_byte = if NUM_BITS == 4 { @@ -267,12 +270,16 @@ impl ProductQuantizer { key.values(), ); - let distances = compute_dot_distance( + let distances = compute_pq_distance( &distance_table, self.num_bits, self.num_sub_vectors, code.values(), + 0, ); + + let diff = self.num_sub_vectors as f32 - 1.0; + let distances = distances.into_iter().map(|d| d - diff).collect::>(); Ok(distances.into()) } @@ -327,11 +334,12 @@ impl ProductQuantizer { /// The squared L2 distance. #[inline] fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array { - Float32Array::from(compute_l2_distance( + Float32Array::from(compute_pq_distance( distance_table, self.num_bits, self.num_sub_vectors, code, + 100, )) } @@ -392,6 +400,18 @@ impl Quantization for ProductQuantizer { params.build(data, distance_type) } + fn retrain(&mut self, data: &dyn Array) -> Result<()> { + assert_eq!(data.null_count(), 0); + let params = PQBuildParams::with_codebook( + self.num_sub_vectors, + self.num_bits as usize, + Arc::new(self.codebook.clone()), + ); + + *self = params.build(data, self.distance_type)?; + Ok(()) + } + fn code_dim(&self) -> usize { self.num_sub_vectors } @@ -444,7 +464,7 @@ impl Quantization for ProductQuantizer { let tensor = pb::Tensor::try_from(&self.codebook)?; Ok(serde_json::to_value(ProductQuantizationMetadata { codebook_position, - num_bits: self.num_bits, + nbits: self.num_bits, num_sub_vectors: self.num_sub_vectors, dimension: self.dimension, codebook: None, @@ -454,6 +474,10 @@ impl Quantization for ProductQuantizer { } fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { + let distance_type = match distance_type { + DistanceType::Cosine => DistanceType::L2, + _ => distance_type, + }; let codebook = match metadata.codebook.as_ref() { Some(fsl) => fsl.clone(), None => { @@ -463,7 +487,7 @@ impl Quantization for ProductQuantizer { }; Ok(Quantizer::Product(Self::new( metadata.num_sub_vectors, - metadata.num_bits, + metadata.nbits, metadata.dimension, codebook, distance_type, @@ -503,7 +527,7 @@ impl TryFrom for ProductQuantizer { mod tests { use super::*; - use std::iter::repeat; + use std::iter::repeat_n; use approx::assert_relative_eq; use arrow::datatypes::UInt8Type; @@ -522,7 +546,7 @@ mod tests { 8, 16, FixedSizeListArray::try_new_from_values( - Float16Array::from_iter_values(repeat(f16::zero()).take(256 * 16)), + Float16Array::from_iter_values(repeat_n(f16::zero(), 256 * 16)), 16, ) .unwrap(), diff --git a/rust/lance-index/src/vector/pq/builder.rs b/rust/lance-index/src/vector/pq/builder.rs index d46c633936e..81827a27327 100644 --- a/rust/lance-index/src/vector/pq/builder.rs +++ b/rust/lance-index/src/vector/pq/builder.rs @@ -4,6 +4,8 @@ //! Product Quantizer Builder //! +use std::sync::Arc; + use crate::vector::quantizer::QuantizerBuildParams; use arrow::array::PrimitiveBuilder; use arrow_array::types::{Float16Type, Float64Type}; @@ -16,7 +18,7 @@ use lance_linalg::distance::DistanceType; use lance_linalg::distance::{Dot, Normalize, L2}; use rand::SeedableRng; use rayon::prelude::*; -use snafu::{location, Location}; +use snafu::location; use super::utils::divide_to_subvectors; use super::ProductQuantizer; @@ -108,9 +110,21 @@ impl PQBuildParams { let d = sub_vectors .into_par_iter() - .map(|sub_vec| { + .enumerate() + .map(|(sub_vec_idx, sub_vec)| { let rng = rand::rngs::SmallRng::from_entropy(); train_kmeans::( + self.codebook.as_ref().map(|cb| { + let sub_vec_centroids = FixedSizeListArray::try_new_from_values( + cb.as_fixed_size_list().values().as_primitive::().slice( + sub_vec_idx * num_centroids * sub_vector_dimension, + num_centroids * sub_vector_dimension, + ), + sub_vector_dimension as i32, + ) + .unwrap(); + Arc::new(sub_vec_centroids) + }), &sub_vec, sub_vector_dimension, num_centroids, @@ -120,6 +134,7 @@ impl PQBuildParams { distance_type, self.sample_rate, ) + .map(|kmeans| kmeans.centroids) }) .collect::>>()?; let mut codebook_builder = PrimitiveBuilder::::with_capacity(num_centroids * dimension); @@ -154,6 +169,19 @@ impl PQBuildParams { ), location: location!(), })?; + + let num_centroids = 2_usize.pow(self.num_bits as u32); + if data.len() < num_centroids { + return Err(Error::Index { + message: format!( + "Not enough rows to train PQ. Requires {:?} rows but only {:?} available", + num_centroids, + data.len() + ), + location: location!(), + }); + } + // TODO: support bf16 later. match fsl.value_type() { DataType::Float16 => self.build_from_fsl::(fsl, distance_type), diff --git a/rust/lance-index/src/vector/pq/distance.rs b/rust/lance-index/src/vector/pq/distance.rs index 8aad9cb3fa6..01f5d03f6bd 100644 --- a/rust/lance-index/src/vector/pq/distance.rs +++ b/rust/lance-index/src/vector/pq/distance.rs @@ -2,13 +2,22 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use core::panic; -use std::cmp::min; +use std::cmp::{max, min}; use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2}; +use lance_linalg::simd::u8::u8x16; +use lance_linalg::simd::{Shuffle, SIMD}; use lance_table::utils::LanceIteratorExtension; use super::{num_centroids, utils::get_sub_vector_centroids}; +// for quantizing the distance table, we need to know the max possible distance, +// so we perform a flat search on the first `FLAT_NUM_4BIT_PQ` rows. +// increasing this number will increase the accuracy of the quantization, +// but also increase the computation time. +// 200 is a good trade-off according to the original paper. +const FLAT_NUM_4BIT_PQ: usize = 200; + /// Build a Distance Table from the query to each PQ centroid /// using L2 distance. pub fn build_distance_table_l2( @@ -96,21 +105,25 @@ pub fn build_distance_table_dot_impl( /// The squared L2 distance. /// #[inline] -pub(super) fn compute_l2_distance( +pub(super) fn compute_pq_distance( distance_table: &[f32], num_bits: u32, num_sub_vectors: usize, code: &[u8], + k_hint: usize, ) -> Vec { + if code.is_empty() { + return Vec::new(); + } if num_bits == 4 { - return compute_l2_distance_4bit(distance_table, num_sub_vectors, code); + return compute_pq_distance_4bit(distance_table, num_sub_vectors, code, k_hint); } // here `code` has been transposed, // so code[i][j] is the code of i-th sub-vector of the j-th vector, // and `code` is a flatten array of [num_sub_vectors, num_vectors] u8, // so code[i * num_vectors + j] is the code of i-th sub-vector of the j-th vector. let num_vectors = code.len() / num_sub_vectors; - let mut distances = vec![0.0_f32; num_vectors]; + let mut distances = vec![0.0; num_vectors]; // it must be 8 const NUM_CENTROIDS: usize = 2_usize.pow(8); for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { @@ -129,35 +142,145 @@ pub(super) fn compute_l2_distance( } #[inline] -pub(super) fn compute_l2_distance_4bit( +pub(super) fn compute_pq_distance_4bit( distance_table: &[f32], num_sub_vectors: usize, code: &[u8], + k_hint: usize, ) -> Vec { let num_vectors = code.len() * 2 / num_sub_vectors; - let mut distances = vec![0.0_f32; num_vectors]; + let mut distances = vec![0.0f32; num_vectors]; + + // compute the distances for first k_hint rows + // then use the max distance as qmax to quantize the distance table + let k_hint = min(k_hint, num_vectors); + let flat_num = max(FLAT_NUM_4BIT_PQ, k_hint).min(num_vectors); + compute_pq_distance_4bit_flat( + distance_table, + num_vectors, + code, + 0, + flat_num, + &mut distances, + ); + let qmax = *distances + .iter() + .take(flat_num) + .max_by(|a, b| a.total_cmp(b)) + .unwrap(); + + let (qmin, quantized_dists_table) = quantize_distance_table(distance_table, qmax); const NUM_CENTROIDS: usize = 2_usize.pow(4); - for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { - let dist_table = - &distance_table[sub_vec_idx * 2 * NUM_CENTROIDS..(sub_vec_idx * 2 + 1) * NUM_CENTROIDS]; - let dist_table_next = &distance_table - [(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..(sub_vec_idx * 2 + 2) * NUM_CENTROIDS]; - debug_assert_eq!(vec_indices.len(), distances.len()); - vec_indices - .iter() - .zip(distances.iter_mut()) - .for_each(|(¢roid_idx, sum)| { - // for 4bit PQ, `centroid_idx` is 2 index, each index is 4bit. - let current_idx = centroid_idx & 0xF; - let next_idx = centroid_idx >> 4; - *sum += dist_table[current_idx as usize]; - *sum += dist_table_next[next_idx as usize]; - }); + let mut quantized_dists = vec![0_u8; num_vectors]; + + let remainder = num_vectors % NUM_CENTROIDS; + for i in (0..num_vectors - remainder).step_by(NUM_CENTROIDS) { + let mut block_distances = u8x16::zeros(); + + for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { + let origin_dist_table = unsafe { + u8x16::load_unaligned( + quantized_dists_table + .as_ptr() + .add(sub_vec_idx * 2 * NUM_CENTROIDS), + ) + }; + let origin_next_dist_table = unsafe { + u8x16::load_unaligned( + quantized_dists_table + .as_ptr() + .add((sub_vec_idx * 2 + 1) * NUM_CENTROIDS), + ) + }; + + let indices = unsafe { u8x16::load_unaligned(vec_indices.as_ptr().add(i)) }; + + // compute current distances + let current_indices = indices.bit_and(0x0F); + block_distances += origin_dist_table.shuffle(current_indices); + + // compute next distances + let next_indices = indices.right_shift::<4>(); + block_distances += origin_next_dist_table.shuffle(next_indices); + } + + unsafe { + block_distances.store_unaligned(quantized_dists.as_mut_ptr().add(i)); + } + } + if remainder > 0 { + let offset = max(num_vectors - remainder, flat_num); + compute_pq_distance_4bit_flat( + distance_table, + num_vectors, + code, + offset, + num_vectors - offset, + &mut distances, + ); } + // need to dequantize the distances + // to make the distances comparable to the others from the other partitions + let range = (qmax - qmin) / 255.0; + distances + .iter_mut() + .take(num_vectors - remainder) // don't overwrite the remainder + .skip(flat_num) // don't overwrite the first k_hint + .zip( + quantized_dists + .into_iter() + .take(num_vectors - remainder) + .skip(flat_num), + ) + .for_each(|(dist, q_dist)| { + *dist = (q_dist as f32) * range + qmin; + }); distances } +// compute the distance for 4bit PQ +// it only computes for the rows from offset to offset + length +fn compute_pq_distance_4bit_flat( + distance_table: &[f32], + num_vectors: usize, + code: &[u8], + offset: usize, + length: usize, + dists: &mut [f32], +) { + const NUM_CENTROIDS: usize = 2_usize.pow(4); + + for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { + let vec_indices = &vec_indices[offset..offset + length]; + let distances = &mut dists[offset..offset + length]; + let dist_table = &distance_table[sub_vec_idx * 2 * NUM_CENTROIDS..]; + let next_dist_table = &distance_table[(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..]; + for (i, ¢roid_idx) in vec_indices.iter().enumerate() { + let current_idx = centroid_idx & 0xF; + let next_idx = centroid_idx >> 4; + distances[i] += dist_table[current_idx as usize]; + distances[i] += next_dist_table[next_idx as usize]; + } + } +} + +// Quantize the distance table to u8, +// map distance `d` to `(d-qmin) * 255 / (qmax-qmin)`m +// used for only 4bit PQ so num_centroids must be 16 +// returns (qmin, quantized_distance_table) +#[inline] +fn quantize_distance_table(distance_table: &[f32], qmax: f32) -> (f32, Vec) { + let qmin = distance_table.iter().cloned().fold(f32::INFINITY, f32::min); + let factor = 255.0 / (qmax - qmin); + let quantized_dist_table = distance_table + .iter() + .map(|&d| ((d - qmin) * factor).round() as u8) + .collect(); + + (qmin, quantized_dist_table) +} + /// Compute L2 distance from the query to all code without transposing the code. /// for testing only /// @@ -201,62 +324,6 @@ fn compute_l2_distance_without_transposing( distances.chain(remainder).collect() } -#[inline] -pub fn compute_dot_distance( - distance_table: &[f32], - num_bits: u32, - num_sub_vectors: usize, - code: &[u8], -) -> Vec { - if num_bits == 4 { - return compute_dot_distance_4bit(distance_table, num_sub_vectors, code); - } - let num_vectors = code.len() / num_sub_vectors; - let mut distances = vec![0.0; num_vectors]; - let num_centroids = num_centroids(num_bits); - for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { - let dist_table = &distance_table[sub_vec_idx * num_centroids..]; - vec_indices - .iter() - .zip(distances.iter_mut()) - .for_each(|(¢roid_idx, sum)| { - *sum += dist_table[centroid_idx as usize]; - }); - } - - distances -} - -#[inline] -pub fn compute_dot_distance_4bit( - distance_table: &[f32], - num_sub_vectors: usize, - code: &[u8], -) -> Vec { - let num_vectors = code.len() * 2 / num_sub_vectors; - let mut distances = vec![0.0; num_vectors]; - const NUM_CENTROIDS: usize = 2_usize.pow(4); - for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() { - let dist_table = - &distance_table[sub_vec_idx * 2 * NUM_CENTROIDS..(sub_vec_idx * 2 + 1) * NUM_CENTROIDS]; - let dist_table_next = &distance_table - [(sub_vec_idx * 2 + 1) * NUM_CENTROIDS..(sub_vec_idx * 2 + 2) * NUM_CENTROIDS]; - debug_assert_eq!(vec_indices.len(), distances.len()); - vec_indices - .iter() - .zip(distances.iter_mut()) - .for_each(|(¢roid_idx, sum)| { - // for 4bit PQ, `centroid_idx` is 2 index, each index is 4bit. - let current_idx = centroid_idx & 0xF; - let next_idx = centroid_idx >> 4; - *sum += dist_table[current_idx as usize]; - *sum += dist_table_next[next_idx as usize]; - }); - } - - distances -} - #[cfg(test)] mod tests { use crate::vector::pq::storage::transpose; @@ -278,11 +345,12 @@ mod tests { let pq_codes = Vec::from_iter((0..num_vectors * num_sub_vectors).map(|v| v as u8)); let pq_codes = UInt8Array::from_iter_values(pq_codes); let transposed_codes = transpose(&pq_codes, num_vectors, num_sub_vectors); - let distances = compute_l2_distance( + let distances = compute_pq_distance( &distance_table, num_bits, num_sub_vectors, transposed_codes.values(), + 100, ); let expected = compute_l2_distance_without_transposing::<4, 1>( &distance_table, diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index ef3839aa3e1..67be90b5519 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -26,14 +26,14 @@ use lance_io::{ utils::read_message, }; use lance_linalg::distance::{DistanceType, Dot, L2}; +use lance_table::utils::LanceIteratorExtension; use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing}; use object_store::path::Path; use prost::Message; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; -use super::distance::{build_distance_table_dot, compute_l2_distance}; -use super::distance::{build_distance_table_l2, compute_dot_distance}; +use super::distance::{build_distance_table_dot, build_distance_table_l2, compute_pq_distance}; use super::ProductQuantizer; use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ @@ -53,7 +53,7 @@ pub const PQ_METADATA_KEY: &str = "lance:pq"; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProductQuantizationMetadata { pub codebook_position: usize, - pub num_bits: u32, + pub nbits: u32, pub num_sub_vectors: usize, pub dimension: usize, @@ -105,8 +105,6 @@ impl QuantizerMetadata for ProductQuantizationMetadata { /// It stores PQ code, as well as the row ID to the original vectors. /// /// It is possible to store additional metadata to accelerate filtering later. -/// -/// TODO: support f16/f64 later. #[derive(Clone, Debug)] pub struct ProductQuantizationStorage { codebook: FixedSizeListArray, @@ -154,6 +152,18 @@ impl ProductQuantizationStorage { distance_type: DistanceType, transposed: bool, ) -> Result { + if batch.num_columns() != 2 { + log::warn!( + "PQ storage should have 2 columns, but got {} columns: {}", + batch.num_columns(), + batch.schema(), + ); + batch = batch.project(&[ + batch.schema().index_of(ROW_ID)?, + batch.schema().index_of(PQ_CODE_COLUMN)?, + ])?; + } + let Some(row_ids) = batch.column_by_name(ROW_ID) else { return Err(Error::Index { message: "Row ID column not found from PQ storage".to_string(), @@ -248,6 +258,10 @@ impl ProductQuantizationStorage { ) } + pub fn codebook(&self) -> &FixedSizeListArray { + &self.codebook + } + /// Load full PQ storage from disk. /// /// Parameters @@ -328,7 +342,7 @@ impl ProductQuantizationStorage { let metadata = ProductQuantizationMetadata { codebook_position: pos, - num_bits: self.num_bits, + nbits: self.num_bits, num_sub_vectors: self.num_sub_vectors, dimension: self.dimension, codebook: None, @@ -412,7 +426,7 @@ impl QuantizerStorage for ProductQuantizationStorage { Self::new( codebook, batch, - metadata.num_bits, + metadata.nbits, metadata.num_sub_vectors, metadata.dimension, distance_type, @@ -428,8 +442,11 @@ impl VectorStore for ProductQuantizationStorage { where Self: Sized, { + let distance_type = match distance_type { + DistanceType::Cosine => DistanceType::L2, + _ => distance_type, + }; let metadata_json = batch - .schema_ref() .metadata() .get(STORAGE_METADATA_KEY) .ok_or(Error::Index { @@ -445,7 +462,7 @@ impl VectorStore for ProductQuantizationStorage { Self::new( codebook, batch, - metadata.num_bits, + metadata.nbits, metadata.num_sub_vectors, metadata.dimension, distance_type, @@ -457,7 +474,7 @@ impl VectorStore for ProductQuantizationStorage { let codebook = pb::Tensor::try_from(&self.codebook)?.encode_to_vec(); let metadata = ProductQuantizationMetadata { codebook_position: 0, // deprecated in new format - num_bits: self.num_bits, + nbits: self.num_bits, num_sub_vectors: self.num_sub_vectors, dimension: self.dimension, codebook: None, @@ -470,6 +487,69 @@ impl VectorStore for ProductQuantizationStorage { Ok([self.batch.with_metadata(metadata)?].into_iter()) } + // we can't use the default implementation of remap, + // because PQ Storage transposed the PQ codes + fn remap(&self, mapping: &HashMap>) -> Result { + let transposed_codes = self.pq_code.values(); + let mut new_row_ids = Vec::with_capacity(self.len()); + let mut new_codes = Vec::with_capacity(self.len() * self.num_sub_vectors); + + let row_ids = self.row_ids.values(); + for (i, row_id) in row_ids.iter().enumerate() { + match mapping.get(row_id) { + Some(Some(new_id)) => { + new_row_ids.push(*new_id); + new_codes.extend(get_pq_code( + transposed_codes, + self.num_bits, + self.num_sub_vectors, + i as u32, + )); + } + Some(None) => {} + None => { + new_row_ids.push(*row_id); + new_codes.extend(get_pq_code( + transposed_codes, + self.num_bits, + self.num_sub_vectors, + i as u32, + )); + } + } + } + + let new_row_ids = Arc::new(UInt64Array::from(new_row_ids)); + let new_codes = UInt8Array::from(new_codes); + let batch = if new_row_ids.is_empty() { + RecordBatch::new_empty(self.schema()) + } else { + let num_bytes_in_code = new_codes.len() / new_row_ids.len(); + let new_transposed_codes = transpose(&new_codes, new_row_ids.len(), num_bytes_in_code); + let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values( + new_transposed_codes, + num_bytes_in_code as i32, + )?); + RecordBatch::try_new(self.schema(), vec![new_row_ids.clone(), codes_fsl])? + }; + let transposed_codes = batch[PQ_CODE_COLUMN] + .as_fixed_size_list() + .values() + .as_primitive::() + .clone(); + + Ok(Self { + codebook: self.codebook.clone(), + batch, + pq_code: Arc::new(transposed_codes), + row_ids: new_row_ids, + num_sub_vectors: self.num_sub_vectors, + num_bits: self.num_bits, + dimension: self.dimension, + distance_type: self.distance_type, + }) + } + fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result { unimplemented!() } @@ -537,12 +617,163 @@ impl VectorStore for ProductQuantizationStorage { } } - fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> { - todo!("distance_between not implemented for PQ storage") + fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { + let codes = get_pq_code( + self.pq_code.values(), + self.num_bits, + self.num_sub_vectors, + id, + ); + match self.codebook.value_type() { + DataType::Float16 => { + let codebook = self + .codebook + .values() + .as_primitive::() + .values(); + let query = get_centroids( + codebook, + self.num_bits, + self.num_sub_vectors, + self.dimension, + codes, + ); + PQDistCalculator::new( + codebook, + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + &query, + self.distance_type, + ) + } + DataType::Float32 => { + let codebook = self + .codebook + .values() + .as_primitive::() + .values(); + let query = get_centroids( + codebook, + self.num_bits, + self.num_sub_vectors, + self.dimension, + codes, + ); + PQDistCalculator::new( + codebook, + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + &query, + self.distance_type, + ) + } + DataType::Float64 => { + let codebook = self + .codebook + .values() + .as_primitive::() + .values(); + let query = get_centroids( + codebook, + self.num_bits, + self.num_sub_vectors, + self.dimension, + codes, + ); + PQDistCalculator::new( + codebook, + self.num_bits, + self.num_sub_vectors, + self.pq_code.clone(), + &query, + self.distance_type, + ) + } + _ => unimplemented!("Unsupported data type: {:?}", self.codebook.value_type()), + } } - fn distance_between(&self, _: u32, _: u32) -> f32 { - todo!("distance_between not implemented for PQ storage") + fn dist_between(&self, u: u32, v: u32) -> f32 { + // this is a fast way to compute distance between two vectors in the same storage. + // it doesn't construct the distance table. + let pq_codes = self.pq_code.values(); + let u_codes = get_pq_code(pq_codes, self.num_bits, self.num_sub_vectors, u); + let v_codes = get_pq_code(pq_codes, self.num_bits, self.num_sub_vectors, v); + + match self.codebook.value_type() { + DataType::Float16 => { + let qu = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + u_codes, + ); + let qv = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + v_codes, + ); + self.distance_type.func()(&qu, &qv) + } + DataType::Float32 => { + let qu = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + u_codes, + ); + let qv = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + v_codes, + ); + self.distance_type.func()(&qu, &qv) + } + DataType::Float64 => { + let qu = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + u_codes, + ); + let qv = get_centroids( + self.codebook + .values() + .as_primitive::() + .values(), + self.num_bits, + self.num_sub_vectors, + self.dimension, + v_codes, + ); + self.distance_type.func()(&qu, &qv) + } + _ => unimplemented!("Unsupported data type: {:?}", self.codebook.value_type()), + } } } @@ -582,20 +813,14 @@ impl PQDistCalculator { } } - fn get_pq_code(&self, id: u32) -> Vec { - let num_sub_vectors_in_byte = if self.num_bits == 4 { - self.num_sub_vectors / 2 - } else { - self.num_sub_vectors - }; - let num_vectors = self.pq_code.len() / num_sub_vectors_in_byte; - self.pq_code - .values() - .iter() - .skip(id as usize) - .step_by(num_vectors) - .map(|&c| c as usize) - .collect() + fn get_pq_code(&self, id: u32) -> impl Iterator + '_ { + get_pq_code( + self.pq_code.values(), + self.num_bits, + self.num_sub_vectors, + id, + ) + .map(|v| v as usize) } } @@ -603,34 +828,40 @@ impl DistCalculator for PQDistCalculator { fn distance(&self, id: u32) -> f32 { let num_centroids = 2_usize.pow(self.num_bits); let pq_code = self.get_pq_code(id); - - if self.num_bits == 4 { + let diff = self.num_sub_vectors as f32 - 1.0; + let dist = if self.num_bits == 4 { pq_code - .into_iter() .enumerate() .map(|(i, c)| { let current_idx = c & 0x0F; let next_idx = c >> 4; + self.distance_table[2 * i * num_centroids + current_idx] + self.distance_table[(2 * i + 1) * num_centroids + next_idx] }) .sum() } else { pq_code - .into_iter() .enumerate() .map(|(i, c)| self.distance_table[i * num_centroids + c]) .sum() + }; + + if self.distance_type == DistanceType::Dot { + dist - diff + } else { + dist } } - fn distance_all(&self) -> Vec { + fn distance_all(&self, k_hint: usize) -> Vec { match self.distance_type { - DistanceType::L2 => compute_l2_distance( + DistanceType::L2 => compute_pq_distance( &self.distance_table, self.num_bits, self.num_sub_vectors, self.pq_code.values(), + k_hint, ), DistanceType::Cosine => { // it seems we implemented cosine distance at some version, @@ -642,49 +873,130 @@ impl DistCalculator for PQDistCalculator { // L2 over normalized vectors: ||x - y|| = x^2 + y^2 - 2 * xy = 1 + 1 - 2 * xy = 2 * (1 - xy) // Cosine distance: 1 - |xy| / (||x|| * ||y||) = 1 - xy / (x^2 * y^2) = 1 - xy / (1 * 1) = 1 - xy // Therefore, Cosine = L2 / 2 - let l2_dists = compute_l2_distance( + let l2_dists = compute_pq_distance( &self.distance_table, self.num_bits, self.num_sub_vectors, self.pq_code.values(), + k_hint, ); l2_dists.into_iter().map(|v| v / 2.0).collect() } - DistanceType::Dot => compute_dot_distance( - &self.distance_table, - self.num_bits, - self.num_sub_vectors, - self.pq_code.values(), - ), + DistanceType::Dot => { + let dot_dists = compute_pq_distance( + &self.distance_table, + self.num_bits, + self.num_sub_vectors, + self.pq_code.values(), + k_hint, + ); + let diff = self.num_sub_vectors as f32 - 1.0; + dot_dists.into_iter().map(|v| v - diff).collect() + } _ => unimplemented!("distance type is not supported: {:?}", self.distance_type), } } } +fn get_pq_code( + pq_code: &[u8], + num_bits: u32, + num_sub_vectors: usize, + id: u32, +) -> impl Iterator + '_ { + let num_bytes = if num_bits == 4 { + num_sub_vectors / 2 + } else { + num_sub_vectors + }; + + let num_vectors = pq_code.len() / num_bytes; + pq_code + .iter() + .skip(id as usize) + .step_by(num_vectors) + .copied() + .exact_size(num_bytes) +} + +fn get_centroids( + codebook: &[T], + num_bits: u32, + num_sub_vectors: usize, + dimension: usize, + codes: impl Iterator, +) -> Vec { + // codebook[i][j] is the j-th centroid of the i-th sub-vector. + // the codebook is stored as a flat array, codebook[i * num_centroids + j] = codebook[i][j] + + if num_bits == 4 { + return get_centroids_4bit(codebook, num_sub_vectors, dimension, codes); + } + + let num_centroids: usize = 2_usize.pow(8); + let sub_vector_width = dimension / num_sub_vectors; + let mut centroids = Vec::with_capacity(dimension); + for (sub_vec_idx, centroid_idx) in codes.enumerate() { + let centroid_idx = centroid_idx as usize; + let centroid = &codebook[sub_vec_idx * num_centroids * sub_vector_width + + centroid_idx * sub_vector_width + ..sub_vec_idx * num_centroids * sub_vector_width + + (centroid_idx + 1) * sub_vector_width]; + centroids.extend_from_slice(centroid); + } + centroids +} + +fn get_centroids_4bit( + codebook: &[T], + num_sub_vectors: usize, + dimension: usize, + codes: impl Iterator, +) -> Vec { + let num_centroids: usize = 16; + let sub_vector_width = dimension / num_sub_vectors; + let mut centroids = Vec::with_capacity(dimension); + for (sub_vec_idx, centroid_idx) in codes.into_iter().enumerate() { + let current_idx = (centroid_idx & 0x0F) as usize; + let offset = 2 * sub_vec_idx * num_centroids * sub_vector_width; + let current_centroid = &codebook[offset + current_idx * sub_vector_width + ..offset + (current_idx + 1) * sub_vector_width]; + centroids.extend_from_slice(current_centroid); + + let next_idx = (centroid_idx >> 4) as usize; + let offset = (2 * sub_vec_idx + 1) * num_centroids * sub_vector_width; + let next_centroid = &codebook + [offset + next_idx * sub_vector_width..offset + (next_idx + 1) * sub_vector_width]; + centroids.extend_from_slice(next_centroid); + } + centroids +} + #[cfg(test)] mod tests { use crate::vector::storage::StorageBuilder; use super::*; - use arrow_array::Float32Array; + use arrow_array::{Float32Array, UInt32Array}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::datatypes::Schema; use lance_core::ROW_ID_FIELD; + use rand::Rng; const DIM: usize = 32; const TOTAL: usize = 512; const NUM_SUB_VECTORS: usize = 16; async fn create_pq_storage() -> ProductQuantizationStorage { - let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|v| v as f32)); + let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random())); let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(); - let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::L2); + let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot); let schema = ArrowSchema::new(vec![ Field::new( - "vectors", + "vec", DataType::FixedSizeList( Field::new_list_field(DataType::Float32, true).into(), DIM as i32, @@ -693,14 +1005,49 @@ mod tests { ), ROW_ID_FIELD.clone(), ]); - let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|v| v as f32)); + let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random())); let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64)); let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl), Arc::new(row_ids)]).unwrap(); - StorageBuilder::new("vectors".to_owned(), pq.distance_type, pq) - .build(&batch) + StorageBuilder::new("vec".to_owned(), pq.distance_type, pq) + .unwrap() + .build(vec![batch]) + .unwrap() + } + + async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage { + let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random())); + let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(); + let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot); + + let schema = ArrowSchema::new(vec![ + Field::new( + "vec", + DataType::FixedSizeList( + Field::new_list_field(DataType::Float32, true).into(), + DIM as i32, + ), + true, + ), + ROW_ID_FIELD.clone(), + Field::new("extra", DataType::UInt32, true), + ]); + let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random())); + let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64)); + let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32)); + let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + let batch = RecordBatch::try_new( + schema.into(), + vec![Arc::new(fsl), Arc::new(row_ids), Arc::new(extra_column)], + ) + .unwrap(); + + StorageBuilder::new("vec".to_owned(), pq.distance_type, pq) + .unwrap() + .assert_num_columns(false) + .build(vec![batch]) .unwrap() } @@ -747,7 +1094,39 @@ mod tests { let expected = (0..storage.len()) .map(|id| dist_calc.distance(id as u32)) .collect::>(); - let distances = dist_calc.distance_all(); + let distances = dist_calc.distance_all(100); assert_eq!(distances, expected); } + + #[tokio::test] + async fn test_dist_between() { + let mut rng = rand::thread_rng(); + let storage = create_pq_storage().await; + let u = rng.gen_range(0..storage.len() as u32); + let v = rng.gen_range(0..storage.len() as u32); + let dist1 = storage.dist_between(u, v); + let dist2 = storage.dist_between(v, u); + assert_eq!(dist1, dist2); + } + + #[tokio::test] + async fn test_remap_with_extra_column() { + let storage = create_pq_storage_with_extra_column().await; + let mut mapping = HashMap::new(); + for i in 0..TOTAL / 2 { + mapping.insert(i as u64, Some((TOTAL + i) as u64)); + } + for i in TOTAL / 2..TOTAL { + mapping.insert(i as u64, None); + } + let new_storage = storage.remap(&mapping).unwrap(); + assert_eq!(new_storage.len(), TOTAL / 2); + assert_eq!(new_storage.row_ids.len(), TOTAL / 2); + for (i, row_id) in new_storage.row_ids().enumerate() { + assert_eq!(*row_id, (TOTAL + i) as u64); + } + assert_eq!(new_storage.batch.num_columns(), 2); + assert!(new_storage.batch.column_by_name(ROW_ID).is_some()); + assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some()); + } } diff --git a/rust/lance-index/src/vector/pq/transform.rs b/rust/lance-index/src/vector/pq/transform.rs index dfd28f9454d..ce537144245 100644 --- a/rust/lance-index/src/vector/pq/transform.rs +++ b/rust/lance-index/src/vector/pq/transform.rs @@ -8,7 +8,7 @@ use arrow_array::{cast::AsArray, Array, RecordBatch}; use arrow_schema::Field; use lance_arrow::RecordBatchExt; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use super::ProductQuantizer; diff --git a/rust/lance-index/src/vector/pq/utils.rs b/rust/lance-index/src/vector/pq/utils.rs index 8766eb80057..1e505955868 100644 --- a/rust/lance-index/src/vector/pq/utils.rs +++ b/rust/lance-index/src/vector/pq/utils.rs @@ -3,7 +3,7 @@ use arrow_array::{cast::AsArray, types::ArrowPrimitiveType, Array, FixedSizeListArray}; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; /// Divide a 2D vector in [`T::Array`] to `m` sub-vectors. /// diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 1290a0f07b2..7c1f1a37200 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -15,16 +15,24 @@ use lance_io::traits::Reader; use lance_linalg::distance::DistanceType; use lance_table::format::SelfDescribingFileReader; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; -use super::flat::index::FlatQuantizer; +use super::flat::index::{FlatBinQuantizer, FlatQuantizer}; use super::pq::ProductQuantizer; use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore}; -pub trait Quantization: Send + Sync + Debug + DeepSizeOf + Into { - type BuildParams: QuantizerBuildParams; +pub trait Quantization: + Send + + Sync + + Clone + + Debug + + DeepSizeOf + + Into + + TryFrom +{ + type BuildParams: QuantizerBuildParams + Send + Sync; type Metadata: QuantizerMetadata + Send + Sync; type Storage: QuantizerStorage + VectorStore + Debug; @@ -33,6 +41,7 @@ pub trait Quantization: Send + Sync + Debug + DeepSizeOf + Into { distance_type: DistanceType, params: &Self::BuildParams, ) -> Result; + fn retrain(&mut self, data: &dyn Array) -> Result<()>; fn code_dim(&self) -> usize; fn column(&self) -> &'static str; fn use_residual(_: DistanceType) -> bool { @@ -98,6 +107,7 @@ impl QuantizerBuildParams for () { #[derive(Debug, Clone, DeepSizeOf)] pub enum Quantizer { Flat(FlatQuantizer), + FlatBin(FlatBinQuantizer), Product(ProductQuantizer), Scalar(ScalarQuantizer), } @@ -106,6 +116,7 @@ impl Quantizer { pub fn code_dim(&self) -> usize { match self { Self::Flat(fq) => fq.code_dim(), + Self::FlatBin(fq) => fq.code_dim(), Self::Product(pq) => pq.code_dim(), Self::Scalar(sq) => sq.code_dim(), } @@ -114,6 +125,7 @@ impl Quantizer { pub fn column(&self) -> &'static str { match self { Self::Flat(fq) => fq.column(), + Self::FlatBin(fq) => fq.column(), Self::Product(pq) => pq.column(), Self::Scalar(sq) => sq.column(), } @@ -122,6 +134,7 @@ impl Quantizer { pub fn metadata_key(&self) -> &'static str { match self { Self::Flat(_) => FlatQuantizer::metadata_key(), + Self::FlatBin(_) => FlatBinQuantizer::metadata_key(), Self::Product(_) => ProductQuantizer::metadata_key(), Self::Scalar(_) => ScalarQuantizer::metadata_key(), } @@ -130,6 +143,7 @@ impl Quantizer { pub fn quantization_type(&self) -> QuantizationType { match self { Self::Flat(_) => QuantizationType::Flat, + Self::FlatBin(_) => QuantizationType::Flat, Self::Product(_) => QuantizationType::Product, Self::Scalar(_) => QuantizationType::Scalar, } @@ -138,6 +152,7 @@ impl Quantizer { pub fn metadata(&self, args: Option) -> Result { match self { Self::Flat(fq) => fq.metadata(args), + Self::FlatBin(fq) => fq.metadata(args), Self::Product(pq) => pq.metadata(args), Self::Scalar(sq) => sq.metadata(args), } diff --git a/rust/lance-index/src/vector/residual.rs b/rust/lance-index/src/vector/residual.rs index b094e43d114..39678ced349 100644 --- a/rust/lance-index/src/vector/residual.rs +++ b/rust/lance-index/src/vector/residual.rs @@ -1,20 +1,24 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::iter; +use std::ops::{AddAssign, DivAssign}; use std::sync::Arc; +use arrow_array::ArrowNumericType; use arrow_array::{ cast::AsArray, - types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, UInt32Type}, + types::{Float16Type, Float32Type, Float64Type, UInt32Type}, Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array, }; use arrow_schema::DataType; use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::{Error, Result}; use lance_linalg::distance::{DistanceType, Dot, L2}; -use lance_linalg::kmeans::compute_partitions; -use num_traits::Float; -use snafu::{location, Location}; +use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat}; +use lance_table::utils::LanceIteratorExtension; +use num_traits::{Float, FromPrimitive, Num}; +use snafu::location; use tracing::instrument; use super::transform::Transformer; @@ -53,39 +57,45 @@ impl ResidualTransform { } } -fn do_compute_residual( +fn do_compute_residual( centroids: &FixedSizeListArray, vectors: &FixedSizeListArray, distance_type: Option, partitions: Option<&UInt32Array>, ) -> Result where - T::Native: Float + L2 + Dot, + T::Native: Num + Float + L2 + Dot + DivAssign + AddAssign + FromPrimitive, { let dimension = centroids.value_length() as usize; - let centroids_slice = centroids.values().as_primitive::().values(); - let vectors_slice = vectors.values().as_primitive::().values(); + let centroids = centroids.values().as_primitive::(); + let vectors = vectors.values().as_primitive::(); let part_ids = partitions.cloned().unwrap_or_else(|| { - compute_partitions( - centroids_slice, - vectors_slice, + compute_partitions::>( + centroids, + vectors, dimension, distance_type.expect("provide either partitions or distance type"), ) + .0 .into() }); + let part_ids = part_ids.values(); + let vectors_slice = vectors.values(); + let centroids_slice = centroids.values(); let residuals = vectors_slice .chunks_exact(dimension) .enumerate() .flat_map(|(idx, vector)| { - let part_id = part_ids.value(idx) as usize; + let part_id = part_ids[idx] as usize; let c = ¢roids_slice[part_id * dimension..(part_id + 1) * dimension]; - vector.iter().zip(c.iter()).map(|(v, cent)| *v - *cent) + iter::zip(vector, c).map(|(v, cent)| *v - *cent) }) + .exact_size(vectors.len()) .collect::>(); let residual_arr = PrimitiveArray::::from_iter_values(residuals); + debug_assert_eq!(residual_arr.len(), vectors.len()); Ok(FixedSizeListArray::try_new_from_values( residual_arr, dimension as i32, @@ -126,6 +136,13 @@ pub(crate) fn compute_residual( (DataType::Float64, DataType::Float64) => { do_compute_residual::(centroids, vectors, distance_type, partitions) } + (DataType::Float32, DataType::Int8) => { + do_compute_residual::( + centroids, + &vectors.convert_to_floating_point()?, + distance_type, + partitions) + } _ => Err(Error::Index { message: format!( "Compute residual vector: centroids and vector type mismatch: centroid: {}, vector: {}", @@ -171,7 +188,16 @@ impl Transformer for ResidualTransform { compute_residual(&self.centroids, original_vectors, None, Some(part_ids_ref))?; // Replace original column with residual column. - let batch = batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))?; + let batch = if residual_arr.data_type() != original.data_type() { + batch.replace_column_schema_by_name( + &self.vec_col, + residual_arr.data_type().clone(), + Arc::new(residual_arr), + )? + } else { + batch.replace_column_by_name(&self.vec_col, Arc::new(residual_arr))? + }; + Ok(batch) } } diff --git a/rust/lance-index/src/vector/sq.rs b/rust/lance-index/src/vector/sq.rs index 269f637695a..5829d27f9b4 100644 --- a/rust/lance-index/src/vector/sq.rs +++ b/rust/lance-index/src/vector/sq.rs @@ -15,7 +15,7 @@ use lance_arrow::*; use lance_core::{Error, Result}; use lance_linalg::distance::DistanceType; use num_traits::*; -use snafu::{location, Location}; +use snafu::location; use storage::{ScalarQuantizationMetadata, ScalarQuantizationStorage, SQ_METADATA_KEY}; use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer}; @@ -88,7 +88,7 @@ impl ScalarQuantizer { .as_slice(); self.bounds = data.iter().fold(self.bounds.clone(), |f, v| { - f.start.min(v.to_f64().unwrap())..f.end.max(v.to_f64().unwrap()) + f.start.min(v.as_())..f.end.max(v.as_()) }); Ok(self.bounds.clone()) @@ -119,7 +119,7 @@ impl ScalarQuantizer { .as_slice(); // TODO: support SQ4 - let builder: Vec = scale_to_u8::(data, self.bounds.clone()); + let builder: Vec = scale_to_u8::(data, &self.bounds); Ok(Arc::new(FixedSizeListArray::try_new_from_values( UInt8Array::from(builder), @@ -187,6 +187,35 @@ impl Quantization for ScalarQuantizer { Ok(quantizer) } + fn retrain(&mut self, data: &dyn Array) -> Result<()> { + let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { + message: format!( + "SQ retrain: input is not a FixedSizeList: {}", + data.data_type() + ), + location: location!(), + })?; + + match fsl.value_type() { + DataType::Float16 => { + self.update_bounds::(fsl)?; + } + DataType::Float32 => { + self.update_bounds::(fsl)?; + } + DataType::Float64 => { + self.update_bounds::(fsl)?; + } + value_type => { + return Err(Error::invalid_input( + format!("unsupported data type {} for scalar quantizer", value_type), + location!(), + )) + } + } + Ok(()) + } + fn code_dim(&self) -> usize { self.dim } @@ -232,23 +261,31 @@ impl Quantization for ScalarQuantizer { } } -pub(crate) fn scale_to_u8(values: &[T::Native], bounds: Range) -> Vec { +pub(crate) fn scale_to_u8(values: &[T::Native], bounds: &Range) -> Vec { + if bounds.start == bounds.end { + return vec![0; values.len()]; + } + let range = bounds.end - bounds.start; values .iter() .map(|&v| { let v = v.to_f64().unwrap(); - match v { - v if v < bounds.start => 0, - v if v > bounds.end => 255, - _ => ((v - bounds.start) * f64::from_u32(255).unwrap() / range) - .round() - .to_u8() - .unwrap(), - } + let v = ((v - bounds.start) * 255.0 / range).round(); + v as u8 // rust `as` performs saturating cast when casting float to int, so it's safe and expected here }) .collect_vec() } + +pub(crate) fn inverse_scalar_dist( + values: impl Iterator, + bounds: &Range, +) -> Vec { + let range = (bounds.end - bounds.start) as f32; + values + .map(|v| v * range.powi(2) / 255.0.powi(2)) + .collect_vec() +} #[cfg(test)] mod tests { use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; @@ -340,4 +377,15 @@ mod tests { assert_eq!(*v, (i * 17) as u8,); }); } + + #[tokio::test] + async fn test_scale_to_u8_with_nan() { + let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN]; + let bounds = Range:: { + start: 0.0, + end: 3.0, + }; + let u8_values = scale_to_u8::(&values, &bounds); + assert_eq!(u8_values, vec![0, 85, 170, 255, 0]); + } } diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index 4428cfcd766..5b180eaf047 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -3,13 +3,13 @@ use std::ops::Range; -use arrow::compute::concat_batches; +use arrow::{compute::concat_batches, datatypes::Float16Type}; use arrow_array::{ cast::AsArray, types::{Float32Type, UInt64Type, UInt8Type}, ArrayRef, RecordBatch, UInt64Array, UInt8Array, }; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_core::{Error, Result, ROW_ID}; @@ -19,7 +19,7 @@ use lance_linalg::distance::{dot_distance, l2_distance_uint_scalar, DistanceType use lance_table::format::SelfDescribingFileReader; use object_store::path::Path; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ @@ -32,7 +32,7 @@ use crate::{ IndexMetadata, INDEX_METADATA_SCHEMA_KEY, }; -use super::{scale_to_u8, ScalarQuantizer}; +use super::{inverse_scalar_dist, scale_to_u8, ScalarQuantizer}; pub const SQ_METADATA_KEY: &str = "lance:sq"; @@ -357,7 +357,7 @@ impl VectorStore for ScalarQuantizationStorage { /// Using dist calculator can be more efficient as it can pre-compute some /// values. fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { - SQDistCalculator::new(query, self, self.quantizer.bounds.clone()) + SQDistCalculator::new(query, self, self.quantizer.bounds()) } fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { @@ -365,76 +365,77 @@ impl VectorStore for ScalarQuantizationStorage { let query_sq_code = chunk.sq_code_slice(id - offset).to_vec(); SQDistCalculator { query_sq_code, + bounds: self.quantizer.bounds(), storage: self, } } - - fn distance_between(&self, a: u32, b: u32) -> f32 { - let (offset_a, chunk_a) = self.chunk(a); - let (offset_b, chunk_b) = self.chunk(b); - let a_slice = chunk_a.sq_code_slice(a - offset_a); - let b_slice = chunk_b.sq_code_slice(b - offset_b); - match self.distance_type { - DistanceType::L2 | DistanceType::Cosine => l2_distance_uint_scalar(a_slice, b_slice), - DistanceType::Dot => dot_distance(a_slice, b_slice), - _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), - } - } } pub struct SQDistCalculator<'a> { query_sq_code: Vec, + bounds: Range, storage: &'a ScalarQuantizationStorage, } impl<'a> SQDistCalculator<'a> { fn new(query: ArrayRef, storage: &'a ScalarQuantizationStorage, bounds: Range) -> Self { - let query_sq_code = - scale_to_u8::(query.as_primitive::().values(), bounds); + // This is okay-ish to use hand-rolled dynamic dispatch here + // since we search 10s-100s of partitions, we can afford the overhead + // this could be annoying at indexing time for HNSW, which requires constructing the + // dist calculator frequently. However, HNSW isn't first-class citizen in Lance yet. so be it. + let query_sq_code = match query.data_type() { + DataType::Float16 => { + scale_to_u8::(query.as_primitive::().values(), &bounds) + } + DataType::Float32 => { + scale_to_u8::(query.as_primitive::().values(), &bounds) + } + _ => { + panic!("Unsupported data type for ScalarQuantizationStorage"); + } + }; Self { query_sq_code, + bounds, storage, } } } -impl<'a> DistCalculator for SQDistCalculator<'a> { +impl DistCalculator for SQDistCalculator<'_> { fn distance(&self, id: u32) -> f32 { let (offset, chunk) = self.storage.chunk(id); let sq_code = chunk.sq_code_slice(id - offset); - match self.storage.distance_type { + let dist = match self.storage.distance_type { DistanceType::L2 | DistanceType::Cosine => { l2_distance_uint_scalar(sq_code, &self.query_sq_code) } DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), - } + }; + inverse_scalar_dist(std::iter::once(dist), &self.bounds)[0] } - fn distance_all(&self) -> Vec { + fn distance_all(&self, _k_hint: usize) -> Vec { match self.storage.distance_type { - DistanceType::L2 | DistanceType::Cosine => self - .storage - .chunks - .iter() - .flat_map(|c| { + DistanceType::L2 | DistanceType::Cosine => inverse_scalar_dist( + self.storage.chunks.iter().flat_map(|c| { c.sq_codes .values() .chunks_exact(c.dim()) .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code)) - }) - .collect(), - DistanceType::Dot => self - .storage - .chunks - .iter() - .flat_map(|c| { + }), + &self.bounds, + ), + DistanceType::Dot => inverse_scalar_dist( + self.storage.chunks.iter().flat_map(|c| { c.sq_codes .values() .chunks_exact(c.dim()) .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code)) - }) - .collect(), + }), + &self.bounds, + ), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), } } diff --git a/rust/lance-index/src/vector/sq/transform.rs b/rust/lance-index/src/vector/sq/transform.rs index 1f7ff391be0..0e45fb661d3 100644 --- a/rust/lance-index/src/vector/sq/transform.rs +++ b/rust/lance-index/src/vector/sq/transform.rs @@ -12,7 +12,7 @@ use arrow_array::{ RecordBatch, }; use arrow_schema::{DataType, Field}; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use crate::vector::transform::Transformer; @@ -60,8 +60,6 @@ impl Transformer for SQTransformer { ), location: location!(), })?; - let batch = batch.drop_column(&self.input_column)?; - let fsl = input.as_fixed_size_list_opt().ok_or(Error::Index { message: "input column is not vector type".to_string(), location: location!(), @@ -79,7 +77,9 @@ impl Transformer for SQTransformer { }; let sq_field = Field::new(&self.output_column, sq_code.data_type().clone(), false); - let batch = batch.try_with_column(sq_field, Arc::new(sq_code))?; + let batch = batch + .try_with_column(sq_field, Arc::new(sq_code))? + .drop_column(&self.input_column)?; Ok(batch) } } diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index f35eda0ef83..285994ac756 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -3,21 +3,24 @@ //! Vector Storage, holding (quantized) vectors and providing distance calculation. +use std::collections::HashMap; use std::{any::Any, sync::Arc}; +use arrow::array::AsArray; use arrow::compute::concat_batches; -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_schema::{Field, SchemaRef}; +use arrow::datatypes::UInt64Type; +use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; +use arrow_schema::SchemaRef; use deepsize::DeepSizeOf; use futures::prelude::stream::TryStreamExt; use lance_arrow::RecordBatchExt; -use lance_core::{Error, Result}; +use lance_core::{Error, Result, ROW_ID}; use lance_encoding::decoder::FilterExpression; use lance_file::v2::reader::FileReader; use lance_io::ReadBatchParams; use lance_linalg::distance::DistanceType; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use crate::{ pb, @@ -25,7 +28,6 @@ use crate::{ ivf::storage::{IvfModel, IVF_METADATA_KEY}, quantizer::Quantization, }, - INDEX_METADATA_SCHEMA_KEY, }; use super::quantizer::Quantizer; @@ -38,7 +40,11 @@ use super::DISTANCE_TYPE_KEY; /// pub trait DistCalculator { fn distance(&self, id: u32) -> f32; - fn distance_all(&self) -> Vec; + + // return the distances of all rows + // k_hint is a hint that can be used for optimization + fn distance_all(&self, k_hint: usize) -> Vec; + fn prefetch(&self, _id: u32) {} } @@ -70,7 +76,44 @@ pub trait VectorStore: Send + Sync + Sized + Clone { fn schema(&self) -> &SchemaRef; - fn to_batches(&self) -> Result>; + fn to_batches(&self) -> Result + Send>; + + fn remap(&self, mapping: &HashMap>) -> Result { + let batches = self + .to_batches()? + .map(|b| { + let mut indices = Vec::with_capacity(b.num_rows()); + let mut new_row_ids = Vec::with_capacity(b.num_rows()); + + let row_ids = b.column(0).as_primitive::().values(); + for (i, row_id) in row_ids.iter().enumerate() { + match mapping.get(row_id) { + Some(Some(new_id)) => { + indices.push(i as u32); + new_row_ids.push(*new_id); + } + Some(None) => {} + None => { + indices.push(i as u32); + new_row_ids.push(*row_id); + } + } + } + + let indices = UInt32Array::from(indices); + let new_row_ids = Arc::new(UInt64Array::from(new_row_ids)); + let new_vectors = arrow::compute::take(b.column(1), &indices, None)?; + + Ok(RecordBatch::try_new( + self.schema().clone(), + vec![new_row_ids, new_vectors], + )?) + }) + .collect::>>()?; + + let batch = concat_batches(self.schema(), batches.iter())?; + Self::try_from_batch(batch, self.distance_type()) + } fn len(&self) -> usize; @@ -99,44 +142,60 @@ pub trait VectorStore: Send + Sync + Sized + Clone { fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_>; - fn distance_between(&self, a: u32, b: u32) -> f32; - - fn dist_calculator_from_native(&self, _query: ArrayRef) -> Self::DistanceCalculator<'_> { - todo!("Implement this") + fn dist_between(&self, u: u32, v: u32) -> f32 { + let dist_cal_u = self.dist_calculator_from_id(u); + dist_cal_u.distance(v) } } pub struct StorageBuilder { - column: String, + vector_column: String, distance_type: DistanceType, quantizer: Q, + + // this is for testing purpose + assert_num_columns: bool, } impl StorageBuilder { - pub fn new(column: String, distance_type: DistanceType, quantizer: Q) -> Self { - Self { - column, + pub fn new(vector_column: String, distance_type: DistanceType, quantizer: Q) -> Result { + Ok(Self { + vector_column, distance_type, quantizer, - } + assert_num_columns: true, + }) + } + + // this is for testing purpose + pub fn assert_num_columns(mut self, assert_num_columns: bool) -> Self { + self.assert_num_columns = assert_num_columns; + self } - pub fn build(&self, batch: &RecordBatch) -> Result { - let vectors = batch.column_by_name(&self.column).ok_or(Error::Schema { - message: format!("column {} not found", self.column), - location: location!(), - })?; - let code_array = self.quantizer.quantize(vectors.as_ref())?; - let batch = batch - .try_with_column( - Field::new( - self.quantizer.column(), - code_array.data_type().clone(), - true, - ), - code_array, - )? - .drop_column(&self.column)?; + pub fn build(&self, batches: Vec) -> Result { + let mut batch = concat_batches(batches[0].schema_ref(), batches.iter())?; + + if batch.column_by_name(self.quantizer.column()).is_none() { + let vectors = batch + .column_by_name(&self.vector_column) + .ok_or(Error::Index { + message: format!("Vector column {} not found in batch", self.vector_column), + location: location!(), + })?; + let codes = self.quantizer.quantize(vectors)?; + batch = batch.drop_column(&self.vector_column)?.try_with_column( + arrow_schema::Field::new(self.quantizer.column(), codes.data_type().clone(), true), + codes, + )?; + } + + if self.assert_num_columns { + debug_assert_eq!(batch.num_columns(), 2, "{}", batch.schema()); + } + debug_assert!(batch.column_by_name(ROW_ID).is_some()); + debug_assert!(batch.column_by_name(self.quantizer.column()).is_some()); + let batch = batch.add_metadata( STORAGE_METADATA_KEY.to_owned(), self.quantizer.metadata(None)?.to_string(), @@ -175,7 +234,7 @@ impl IvfQuantizationStorage { .metadata .get(DISTANCE_TYPE_KEY) .ok_or(Error::Index { - message: format!("{} not found", INDEX_METADATA_SCHEMA_KEY), + message: format!("{} not found", DISTANCE_TYPE_KEY), location: location!(), })? .as_str(), @@ -214,11 +273,27 @@ impl IvfQuantizationStorage { }) } + pub fn num_rows(&self) -> u64 { + self.reader.num_rows() + } + pub fn quantizer(&self) -> Result { - let metadata = serde_json::from_str(&self.metadata[0])?; + let metadata = self.metadata::()?; Q::from_metadata(&metadata, self.distance_type) } + pub fn metadata(&self) -> Result { + Ok(serde_json::from_str(&self.metadata[0])?) + } + + pub fn distance_type(&self) -> DistanceType { + self.distance_type + } + + pub fn schema(&self) -> SchemaRef { + Arc::new(self.reader.schema().as_ref().into()) + } + /// Get the number of partitions in the storage. pub fn num_partitions(&self) -> usize { self.ivf.num_partitions() diff --git a/rust/lance-index/src/vector/transform.rs b/rust/lance-index/src/vector/transform.rs index 21ab74cd9f1..0c53833f7ef 100644 --- a/rust/lance-index/src/vector/transform.rs +++ b/rust/lance-index/src/vector/transform.rs @@ -7,14 +7,16 @@ use std::fmt::Debug; use std::sync::Arc; +use arrow::datatypes::UInt64Type; use arrow_array::types::{Float16Type, Float32Type, Float64Type}; +use arrow_array::UInt64Array; use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array}; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, Schema}; use lance_arrow::RecordBatchExt; use num_traits::Float; -use snafu::{location, Location}; +use snafu::location; -use lance_core::{Error, Result}; +use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use lance_linalg::kernels::normalize_fsl; use tracing::instrument; @@ -66,20 +68,16 @@ impl Transformer for NormalizeTransformer { ), location: location!(), })?; - let data = arr.as_fixed_size_list_opt().ok_or(Error::Index { - message: format!( - "Normalize Transform: column {} is not a fixed size list: {}", - self.input_column, - arr.data_type() - ), - location: location!(), - })?; + + let data = arr.as_fixed_size_list(); let norm = normalize_fsl(data)?; + let transformed = Arc::new(norm); + if let Some(output_column) = &self.output_column { - let field = Field::new(output_column, norm.data_type().clone(), true); - Ok(batch.try_with_column(field, Arc::new(norm))?) + let field = Field::new(output_column, transformed.data_type().clone(), true); + Ok(batch.try_with_column(field, transformed)?) } else { - Ok(batch.replace_column_by_name(&self.input_column, Arc::new(norm))?) + Ok(batch.replace_column_by_name(&self.input_column, transformed)?) } } } @@ -102,10 +100,12 @@ fn is_all_finite(arr: &dyn Array) -> bool where T::Native: Float, { - !arr.as_primitive::() - .values() - .iter() - .any(|&v| !v.is_finite()) + arr.null_count() == 0 + && !arr + .as_primitive::() + .values() + .iter() + .any(|&v| !v.is_finite()) } impl Transformer for KeepFiniteVectors { @@ -118,34 +118,38 @@ impl Transformer for KeepFiniteVectors { ), location: location!(), })?; - let data = arr.as_fixed_size_list_opt().ok_or(Error::Index { - message: format!( - "KeepFiniteVectors: column {} is not a fixed size list: {}", - self.column, - arr.data_type() - ), - location: location!(), - })?; - let valid = data - .iter() - .enumerate() - .filter_map(|(idx, arr)| { - arr.and_then(|data| { - let is_valid = match data.data_type() { - DataType::Float16 => is_all_finite::(&data), - DataType::Float32 => is_all_finite::(&data), - DataType::Float64 => is_all_finite::(&data), - _ => false, - }; - if is_valid { - Some(idx as u32) - } else { - None - } + let data = match arr.data_type() { + DataType::FixedSizeList(_, _) => arr.as_fixed_size_list(), + DataType::List(_) => arr.as_list::().values().as_fixed_size_list(), + _ => { + return Err(Error::Index { + message: format!( + "KeepFiniteVectors: column {} is not a fixed size list: {}", + self.column, + arr.data_type() + ), + location: location!(), }) - }) - .collect::>(); + } + }; + + let mut valid = Vec::with_capacity(batch.num_rows()); + data.iter().enumerate().for_each(|(idx, arr)| { + if let Some(data) = arr { + let is_valid = match data.data_type() { + DataType::Float16 => is_all_finite::(&data), + DataType::Float32 => is_all_finite::(&data), + DataType::Float64 => is_all_finite::(&data), + DataType::UInt8 => data.null_count() == 0, + DataType::Int8 => data.null_count() == 0, + _ => false, + }; + if is_valid { + valid.push(idx as u32); + } + }; + }); if valid.len() < batch.num_rows() { let indices = UInt32Array::from(valid); Ok(batch.take(&indices)?) @@ -174,6 +178,64 @@ impl Transformer for DropColumn { } } +#[derive(Debug)] +pub struct Flatten { + column: String, +} + +impl Flatten { + pub fn new(column: &str) -> Self { + Self { + column: column.to_owned(), + } + } +} + +impl Transformer for Flatten { + fn transform(&self, batch: &RecordBatch) -> Result { + let arr = batch.column_by_name(&self.column).ok_or(Error::Index { + message: format!("Flatten: column {} not found in RecordBatch", self.column), + location: location!(), + })?; + match arr.data_type() { + DataType::FixedSizeList(_, _) => { + // do nothing + Ok(batch.clone()) + } + DataType::List(_) => { + let row_ids = batch[ROW_ID].as_primitive::(); + let vectors = arr.as_list::(); + + let row_ids = row_ids.values().iter().zip(vectors.iter()).flat_map( + |(row_id, multivector)| { + std::iter::repeat_n( + *row_id, + multivector.map(|multivec| multivec.len()).unwrap_or(0), + ) + }, + ); + let row_ids = UInt64Array::from_iter_values(row_ids); + let vectors = vectors.values().as_fixed_size_list().clone(); + let schema = Arc::new(Schema::new(vec![ + ROW_ID_FIELD.clone(), + Field::new(self.column.as_str(), vectors.data_type().clone(), true), + ])); + let batch = + RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(vectors)])?; + Ok(batch) + } + _ => Err(Error::Index { + message: format!( + "Flatten: column {} is not a vector: {}", + self.column, + arr.data_type() + ), + location: location!(), + }), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance-index/src/vector/utils.rs b/rust/lance-index/src/vector/utils.rs index 1f5b3adca47..2a2003c691a 100644 --- a/rust/lance-index/src/vector/utils.rs +++ b/rust/lance-index/src/vector/utils.rs @@ -5,12 +5,12 @@ use arrow::{ array::AsArray, datatypes::{Float16Type, Float32Type, Float64Type}, }; -use arrow_array::{Array, FixedSizeListArray}; +use arrow_array::{Array, BooleanArray, FixedSizeListArray}; use arrow_schema::{DataType, Field}; use lance_core::{Error, Result}; use lance_io::encodings::plain::bytes_to_array; use prost::bytes; -use snafu::{location, Location}; +use snafu::location; use std::{ops::Range, sync::Arc}; use super::pb; @@ -164,6 +164,36 @@ impl TryFrom<&pb::Tensor> for FixedSizeListArray { } } +/// Check if all vectors in the FixedSizeListArray are finite +/// null values are considered as not finite +/// returns a BooleanArray +/// with the same length as the FixedSizeListArray +/// with true for finite values and false for non-finite values +pub fn is_finite(fsl: &FixedSizeListArray) -> BooleanArray { + let is_finite = fsl + .iter() + .map(|v| match v { + Some(v) => match v.data_type() { + DataType::Float16 => { + let v = v.as_primitive::(); + v.null_count() == 0 && v.values().iter().all(|v| v.is_finite()) + } + DataType::Float32 => { + let v = v.as_primitive::(); + v.null_count() == 0 && v.values().iter().all(|v| v.is_finite()) + } + DataType::Float64 => { + let v = v.as_primitive::(); + v.null_count() == 0 && v.values().iter().all(|v| v.is_finite()) + } + _ => v.null_count() == 0, + }, + None => false, + }) + .collect::>(); + BooleanArray::from(is_finite) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 421fc014fa5..e0888d2649f 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -8,7 +8,8 @@ use std::sync::Arc; use arrow::{array::AsArray, compute::sort_to_indices}; use arrow_array::{RecordBatch, UInt32Array}; -use future::join_all; +use arrow_schema::Schema; +use future::try_join_all; use futures::prelude::*; use lance_arrow::RecordBatchExt; use lance_core::{ @@ -27,8 +28,10 @@ use lance_io::{ stream::{RecordBatchStream, RecordBatchStreamAdapter}, }; use object_store::path::Path; +use snafu::location; +use tokio::sync::Mutex; -use crate::vector::PART_ID_COLUMN; +use crate::vector::{LOSS_METADATA_KEY, PART_ID_COLUMN}; #[async_trait::async_trait] /// A reader that can read the shuffled partitions. @@ -43,6 +46,12 @@ pub trait ShuffleReader: Send + Sync { /// Get the size of the partition by partition_id fn partition_size(&self, partition_id: usize) -> Result; + + /// Get the total loss, + /// if the loss is not available, return None, + /// in such case, the caller should sum up the losses from each batch's metadata. + /// Must be called after all partitions are read. + fn total_loss(&self) -> Option; } #[async_trait::async_trait] @@ -88,6 +97,10 @@ impl Shuffler for IvfShuffler { &self, data: Box, ) -> Result> { + if self.num_partitions == 1 { + return Ok(Box::new(SinglePartitionReader::new(data))); + } + let mut writers: Vec = vec![]; let mut partition_sizes = vec![0; self.num_partitions]; let mut first_pass = true; @@ -98,6 +111,12 @@ impl Shuffler for IvfShuffler { spawn_cpu(move || { let batch = batch?; + let loss = batch + .metadata() + .get(LOSS_METADATA_KEY) + .map(|s| s.parse::().unwrap_or_default()) + .unwrap_or_default(); + let part_ids: &UInt32Array = batch .column_by_name(PART_ID_COLUMN) .expect("Partition ID column not found") @@ -110,6 +129,7 @@ impl Shuffler for IvfShuffler { .column_by_name(PART_ID_COLUMN) .expect("Partition ID column not found") .as_primitive(); + let batch = batch.drop_column(PART_ID_COLUMN)?; let mut partition_buffers = (0..num_partitions).map(|_| Vec::new()).collect::>(); @@ -127,7 +147,7 @@ impl Shuffler for IvfShuffler { start = end; } - Ok::>, Error>(partition_buffers) + Ok::<(Vec>, f64), Error>((partition_buffers, loss)) }) }) .buffered(get_num_compute_intensive_cpus()); @@ -139,8 +159,10 @@ impl Shuffler for IvfShuffler { .collect::>(); let mut counter = 0; + let mut total_loss = 0.0; while let Some(shuffled) = parallel_sort_stream.next().await { - let shuffled = shuffled?; + let (shuffled, loss) = shuffled?; + total_loss += loss; for (part_id, batches) in shuffled.into_iter().enumerate() { let part_batches = &mut partition_buffers[part_id]; @@ -171,7 +193,7 @@ impl Shuffler for IvfShuffler { ) } }) - .buffered(10) + .buffered(self.object_store.io_parallelism()) .try_collect::>() .await?; @@ -187,10 +209,7 @@ impl Shuffler for IvfShuffler { partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::(); futs.push(writer.write_batches(batches.iter())); } - join_all(futs) - .await - .into_iter() - .collect::>>()?; + try_join_all(futs).await?; partition_buffers.iter_mut().for_each(|b| b.clear()); } @@ -214,6 +233,7 @@ impl Shuffler for IvfShuffler { self.object_store.clone(), self.output_dir.clone(), partition_sizes, + total_loss, ))) } } @@ -222,6 +242,7 @@ pub struct IvfShufflerReader { scheduler: Arc, output_dir: Path, partition_sizes: Vec, + loss: f64, } impl IvfShufflerReader { @@ -229,6 +250,7 @@ impl IvfShufflerReader { object_store: Arc, output_dir: Path, partition_sizes: Vec, + loss: f64, ) -> Self { let scheduler_config = SchedulerConfig::max_bandwidth(&object_store); let scheduler = ScanScheduler::new(object_store, scheduler_config); @@ -236,6 +258,7 @@ impl IvfShufflerReader { scheduler, output_dir, partition_sizes, + loss, } } } @@ -256,13 +279,12 @@ impl ShuffleReader for IvfShufflerReader { FileReaderOptions::default(), ) .await?; - let schema = reader.schema().as_ref().into(); - + let schema: Schema = reader.schema().as_ref().into(); Ok(Some(Box::new(RecordBatchStreamAdapter::new( Arc::new(schema), reader.read_stream( lance_io::ReadBatchParams::RangeFull, - 4096, + u32::MAX, 16, FilterExpression::no_filter(), )?, @@ -272,4 +294,48 @@ impl ShuffleReader for IvfShufflerReader { fn partition_size(&self, partition_id: usize) -> Result { Ok(self.partition_sizes[partition_id]) } + + fn total_loss(&self) -> Option { + Some(self.loss) + } +} + +pub struct SinglePartitionReader { + data: Mutex>>, +} + +impl SinglePartitionReader { + pub fn new(data: Box) -> Self { + Self { + data: Mutex::new(Some(data)), + } + } +} + +#[async_trait::async_trait] +impl ShuffleReader for SinglePartitionReader { + async fn read_partition( + &self, + _partition_id: usize, + ) -> Result>> { + let mut data = self.data.lock().await; + match data.as_mut() { + Some(_) => Ok(data.take()), + None => Err(Error::Internal { + message: "the partition has been read and consumed".to_string(), + location: location!(), + }), + } + } + + fn partition_size(&self, _partition_id: usize) -> Result { + // we don't really care about the partition size + // it's used for determining the order of building the index and skipping empty partitions + // so we just return 1 here + Ok(1) + } + + fn total_loss(&self) -> Option { + None + } } diff --git a/rust/lance-index/src/vector/v3/subindex.rs b/rust/lance-index/src/vector/v3/subindex.rs index 8e8e96dbd33..a04ddd9db9f 100644 --- a/rust/lance-index/src/vector/v3/subindex.rs +++ b/rust/lance-index/src/vector/v3/subindex.rs @@ -1,21 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use deepsize::DeepSizeOf; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; +use crate::metrics::MetricsCollector; use crate::vector::storage::VectorStore; use crate::vector::{flat, hnsw}; use crate::{prefilter::PreFilter, vector::Query}; /// A sub index for IVF index pub trait IvfSubIndex: Send + Sync + Debug + DeepSizeOf { type QueryParams: Send + Sync + for<'a> From<&'a Query>; - type BuildParams: Clone; + type BuildParams: Clone + Send + Sync; /// Load the sub index from a record batch with a single row fn load(data: RecordBatch) -> Result @@ -42,6 +44,7 @@ pub trait IvfSubIndex: Send + Sync + Debug + DeepSizeOf { params: Self::QueryParams, storage: &impl VectorStore, prefilter: Arc, + metrics: &dyn MetricsCollector, ) -> Result; /// Given a vector storage, containing all the data for the IVF partition, build the sub index. @@ -49,6 +52,10 @@ pub trait IvfSubIndex: Send + Sync + Debug + DeepSizeOf { where Self: Sized; + fn remap(&self, mapping: &HashMap>) -> Result + where + Self: Sized; + /// Encode the sub index into a record batch fn to_batch(&self) -> Result; } diff --git a/rust/lance-io/Cargo.toml b/rust/lance-io/Cargo.toml index c416f7556d0..bac880c2edb 100644 --- a/rust/lance-io/Cargo.toml +++ b/rust/lance-io/Cargo.toml @@ -13,7 +13,7 @@ rust-version.workspace = true [dependencies] -object_store = { workspace = true, features = ["aws", "gcp", "azure"] } +object_store = { workspace = true } lance-arrow.workspace = true lance-core.workspace = true arrow = { workspace = true, features = ["ffi"] } @@ -26,8 +26,8 @@ arrow-schema.workspace = true arrow-select.workspace = true async-recursion.workspace = true async-trait.workspace = true -aws-config.workspace = true -aws-credential-types.workspace = true +aws-config = { workspace = true, optional = true } +aws-credential-types = { workspace = true, optional = true } byteorder.workspace = true bytes.workspace = true chrono.workspace = true @@ -52,9 +52,7 @@ parquet.workspace = true tempfile.workspace = true test-log.workspace = true mockall.workspace = true - -[build-dependencies] -prost-build.workspace = true +rstest.workspace = true [target.'cfg(target_os = "linux")'.dev-dependencies] pprof.workspace = true @@ -64,7 +62,11 @@ name = "scheduler" harness = false [features] +default = ["aws", "azure", "gcp"] gcs-test = [] +gcp = ["object_store/gcp"] +aws = ["object_store/aws", "aws-config", "aws-credential-types"] +azure = ["object_store/azure"] [lints] workspace = true diff --git a/rust/lance-io/benches/scheduler.rs b/rust/lance-io/benches/scheduler.rs index bcb73a2695a..b536781ba4b 100644 --- a/rust/lance-io/benches/scheduler.rs +++ b/rust/lance-io/benches/scheduler.rs @@ -46,7 +46,7 @@ async fn create_data(num_bytes: u64) -> (Arc, Path) { rand::thread_rng().fill_bytes(&mut some_data); obj_store.put(&tmp_file, &some_data).await.unwrap(); - (Arc::new(obj_store), tmp_file) + (obj_store, tmp_file) } const DATA_SIZE: u64 = 128 * 1024 * 1024; diff --git a/rust/lance-io/src/encodings/binary.rs b/rust/lance-io/src/encodings/binary.rs index f8187a37174..cdfea77b416 100644 --- a/rust/lance-io/src/encodings/binary.rs +++ b/rust/lance-io/src/encodings/binary.rs @@ -26,7 +26,8 @@ use arrow_schema::DataType; use async_trait::async_trait; use bytes::Bytes; use futures::{StreamExt, TryStreamExt}; -use snafu::{location, Location}; +use lance_arrow::BufferExt; +use snafu::location; use tokio::io::AsyncWriteExt; use super::ReadBatchParams; @@ -88,7 +89,7 @@ impl<'a> BinaryEncoder<'a> { } #[async_trait] -impl<'a> Encoder for BinaryEncoder<'a> { +impl Encoder for BinaryEncoder<'_> { async fn encode(&mut self, arrs: &[&dyn Array]) -> Result { assert!(!arrs.is_empty()); let data_type = arrs[0].data_type(); @@ -224,7 +225,7 @@ impl<'a, T: ByteArrayType> BinaryDecoder<'a, T> { .null_bit_buffer(null_buf); } - let buf = bytes.into(); + let buf = Buffer::from_bytes_bytes(bytes, /*bytes_per_value=*/ 1); let array_data = data_builder .add_buffer(offset_data.buffers()[0].clone()) .add_buffer(buf) @@ -286,7 +287,7 @@ fn plan_take_chunks( } #[async_trait] -impl<'a, T: ByteArrayType> Decoder for BinaryDecoder<'a, T> { +impl Decoder for BinaryDecoder<'_, T> { async fn decode(&self) -> Result { self.get(..).await } @@ -394,7 +395,7 @@ impl<'a, T: ByteArrayType> Decoder for BinaryDecoder<'a, T> { } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> { +impl AsyncIndex for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, index: usize) -> Self::Output { @@ -403,7 +404,7 @@ impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> { } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex> for BinaryDecoder<'a, T> { +impl AsyncIndex> for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, index: RangeFrom) -> Self::Output { @@ -412,7 +413,7 @@ impl<'a, T: ByteArrayType> AsyncIndex> for BinaryDecoder<'a, T> } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex> for BinaryDecoder<'a, T> { +impl AsyncIndex> for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, index: RangeTo) -> Self::Output { @@ -421,7 +422,7 @@ impl<'a, T: ByteArrayType> AsyncIndex> for BinaryDecoder<'a, T> { } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> { +impl AsyncIndex for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, _: RangeFull) -> Self::Output { @@ -430,7 +431,7 @@ impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> { } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> { +impl AsyncIndex for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, params: ReadBatchParams) -> Self::Output { @@ -445,7 +446,7 @@ impl<'a, T: ByteArrayType> AsyncIndex for BinaryDecoder<'a, T> } #[async_trait] -impl<'a, T: ByteArrayType> AsyncIndex> for BinaryDecoder<'a, T> { +impl AsyncIndex> for BinaryDecoder<'_, T> { type Output = Result; async fn get(&self, index: Range) -> Self::Output { diff --git a/rust/lance-io/src/encodings/dictionary.rs b/rust/lance-io/src/encodings/dictionary.rs index 72b150e023e..a0652a6f0e1 100644 --- a/rust/lance-io/src/encodings/dictionary.rs +++ b/rust/lance-io/src/encodings/dictionary.rs @@ -15,7 +15,7 @@ use arrow_array::types::{ use arrow_array::{Array, ArrayRef, DictionaryArray, PrimitiveArray, UInt32Array}; use arrow_schema::DataType; use async_trait::async_trait; -use snafu::{location, Location}; +use snafu::location; use crate::{ traits::{Reader, Writer}, @@ -62,7 +62,7 @@ impl<'a> DictionaryEncoder<'a> { } #[async_trait] -impl<'a> Encoder for DictionaryEncoder<'a> { +impl Encoder for DictionaryEncoder<'_> { async fn encode(&mut self, array: &[&dyn Array]) -> Result { use DataType::*; @@ -171,7 +171,7 @@ impl<'a> DictionaryDecoder<'a> { } #[async_trait] -impl<'a> Decoder for DictionaryDecoder<'a> { +impl Decoder for DictionaryDecoder<'_> { async fn decode(&self) -> Result { self.decode_impl(..).await } @@ -182,7 +182,7 @@ impl<'a> Decoder for DictionaryDecoder<'a> { } #[async_trait] -impl<'a> AsyncIndex for DictionaryDecoder<'a> { +impl AsyncIndex for DictionaryDecoder<'_> { type Output = Result; async fn get(&self, _index: usize) -> Self::Output { @@ -196,7 +196,7 @@ impl<'a> AsyncIndex for DictionaryDecoder<'a> { } #[async_trait] -impl<'a> AsyncIndex for DictionaryDecoder<'a> { +impl AsyncIndex for DictionaryDecoder<'_> { type Output = Result; async fn get(&self, params: ReadBatchParams) -> Self::Output { diff --git a/rust/lance-io/src/encodings/plain.rs b/rust/lance-io/src/encodings/plain.rs index 4f77fde5c7c..35985005439 100644 --- a/rust/lance-io/src/encodings/plain.rs +++ b/rust/lance-io/src/encodings/plain.rs @@ -7,7 +7,6 @@ //! it stores the array directly in the file. It offers O(1) read access. use std::ops::{Range, RangeFrom, RangeFull, RangeTo}; -use std::ptr::NonNull; use std::slice::from_raw_parts; use std::sync::Arc; @@ -30,7 +29,7 @@ use bytes::Bytes; use futures::stream::{self, StreamExt, TryStreamExt}; use lance_arrow::*; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use tokio::io::AsyncWriteExt; use crate::encodings::{AsyncIndex, Decoder}; @@ -200,25 +199,18 @@ pub fn bytes_to_array( { // this code is taken from // https://github.com/apache/arrow-rs/blob/master/arrow-data/src/data.rs#L748-L768 - let len_plus_offset = bytes.len() + offset; + let len_plus_offset = len + offset; let min_buffer_size = len_plus_offset.saturating_mul(*byte_width); // alignment or size isn't right -- just make a copy - if (bytes.len() < min_buffer_size) || (bytes.as_ptr().align_offset(*alignment) != 0) { - bytes.into() + if bytes.len() < min_buffer_size { + Buffer::copy_bytes_bytes(bytes, min_buffer_size) } else { - // SAFETY: the alignment is correct we can make this conversion - unsafe { - Buffer::from_custom_allocation( - NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"), - bytes.len(), - Arc::new(bytes), - ) - } + Buffer::from_bytes_bytes(bytes, *alignment as u64) } } else { // cases we don't handle, just copy - bytes.into() + Buffer::from_slice_ref(bytes) }; let array_data = ArrayDataBuilder::new(data_type.clone()) @@ -401,7 +393,7 @@ fn make_chunked_requests( } #[async_trait] -impl<'a> Decoder for PlainDecoder<'a> { +impl Decoder for PlainDecoder<'_> { async fn decode(&self) -> Result { self.get(0..self.length).await } @@ -642,6 +634,25 @@ mod tests { test_round_trip(arrs.as_slice(), t).await; } + #[tokio::test] + async fn test_bytes_to_array_padding() { + let bytes = Bytes::from_static(&[0x01, 0x00, 0x02, 0x00, 0x03]); + let arr = bytes_to_array(&DataType::UInt16, bytes, 3, 0).unwrap(); + + let expected = UInt16Array::from(vec![1, 2, 3]); + assert_eq!(arr.as_ref(), &expected); + + // Underlying data is padded to the nearest multiple of two bytes (for u16). + let data = arr.to_data(); + let buf = &data.buffers()[0]; + let repr = format!("{:?}", buf); + assert!( + repr.contains("[1, 0, 2, 0, 3, 0]"), + "Underlying buffer contains unexpected data: {}", + repr + ); + } + #[tokio::test] async fn test_encode_decode_nested_fixed_size_list() { // FixedSizeList of FixedSizeList diff --git a/rust/lance-io/src/lib.rs b/rust/lance-io/src/lib.rs index 1fdea717821..b7fa7482bed 100644 --- a/rust/lance-io/src/lib.rs +++ b/rust/lance-io/src/lib.rs @@ -4,7 +4,7 @@ use std::ops::{Range, RangeFrom, RangeFull, RangeTo}; use arrow::datatypes::UInt32Type; use arrow_array::{PrimitiveArray, UInt32Array}; -use snafu::{location, Location}; +use snafu::location; use lance_core::{Error, Result}; @@ -21,6 +21,8 @@ pub mod testing; pub mod traits; pub mod utils; +pub use scheduler::{bytes_read_counter, iops_counter}; + /// Defines a selection of rows to read from a file/batch #[derive(Debug, Clone)] pub enum ReadBatchParams { @@ -198,6 +200,16 @@ impl ReadBatchParams { )), } } + + pub fn to_offsets_total(&self, total: u32) -> PrimitiveArray { + match self { + Self::Indices(indices) => indices.clone(), + Self::Range(r) => UInt32Array::from_iter_values(r.start as u32..r.end as u32), + Self::RangeFull => UInt32Array::from_iter_values(0_u32..total), + Self::RangeTo(r) => UInt32Array::from_iter_values(0..r.end as u32), + Self::RangeFrom(r) => UInt32Array::from_iter_values(r.start as u32..total), + } + } } #[cfg(test)] diff --git a/rust/lance-io/src/local.rs b/rust/lance-io/src/local.rs index 53dbd928298..f6110be5a40 100644 --- a/rust/lance-io/src/local.rs +++ b/rust/lance-io/src/local.rs @@ -19,7 +19,7 @@ use bytes::{Bytes, BytesMut}; use deepsize::DeepSizeOf; use lance_core::{Error, Result}; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tokio::io::AsyncSeekExt; use tokio::sync::OnceCell; use tracing::instrument; diff --git a/rust/lance-io/src/object_store.rs b/rust/lance-io/src/object_store.rs index f668cdfaae4..e281b3d953d 100644 --- a/rust/lance-io/src/object_store.rs +++ b/rust/lance-io/src/object_store.rs @@ -5,38 +5,37 @@ use std::collections::HashMap; use std::ops::Range; -use std::path::PathBuf; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use async_trait::async_trait; -use aws_config::default_provider::credentials::DefaultCredentialsChain; -use aws_credential_types::provider::ProvideCredentials; use bytes::Bytes; use chrono::{DateTime, Utc}; use deepsize::DeepSizeOf; use futures::{future, stream::BoxStream, StreamExt, TryStreamExt}; -use lance_core::utils::tokio::get_num_compute_intensive_cpus; -use object_store::aws::{ - AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, AwsCredentialProvider, -}; -use object_store::gcp::GoogleCloudStorageBuilder; -use object_store::{ - aws::AmazonS3Builder, azure::AzureConfigKey, gcp::GoogleConfigKey, local::LocalFileSystem, - memory::InMemory, CredentialProvider, Error as ObjectStoreError, Result as ObjectStoreResult, -}; -use object_store::{parse_url_opts, ClientOptions, DynObjectStore, StaticCredentialProvider}; +use futures::{FutureExt, Stream}; +use lance_core::error::LanceOptionExt; +use lance_core::utils::parse::str_is_truthy; +use list_retry::ListRetryStream; +#[cfg(feature = "aws")] +use object_store::aws::AwsCredentialProvider; +use object_store::DynObjectStore; +use object_store::Error as ObjectStoreError; use object_store::{path::Path, ObjectMeta, ObjectStore as OSObjectStore}; +use providers::local::FileStoreProvider; +use providers::memory::MemoryStoreProvider; use shellexpand::tilde; -use snafu::{location, Location}; +use snafu::location; use tokio::io::AsyncWriteExt; -use tokio::sync::RwLock; use url::Url; use super::local::LocalObjectReader; +mod list_retry; +pub mod providers; mod tracing; -use self::tracing::ObjectStoreTracingExt; +use crate::object_writer::WriteResult; use crate::{object_reader::CloudObjectReader, object_writer::ObjectWriter, traits::Reader}; use lance_core::{Error, Result}; @@ -48,8 +47,14 @@ pub const DEFAULT_LOCAL_IO_PARALLELISM: usize = 8; // Cloud disks often need many many threads to saturate the network pub const DEFAULT_CLOUD_IO_PARALLELISM: usize = 64; +const DEFAULT_LOCAL_BLOCK_SIZE: usize = 4 * 1024; // 4KB block size +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +const DEFAULT_CLOUD_BLOCK_SIZE: usize = 64 * 1024; // 64KB block size + pub const DEFAULT_DOWNLOAD_RETRY_COUNT: usize = 3; +pub use providers::{ObjectStoreProvider, ObjectStoreRegistry}; + #[async_trait] pub trait ObjectStoreExt { /// Returns true if the file exists. @@ -58,20 +63,20 @@ pub trait ObjectStoreExt { /// Read all files (start from base directory) recursively /// /// unmodified_since can be specified to only return files that have not been modified since the given time. - async fn read_dir_all( - &self, + async fn read_dir_all<'a>( + &'a self, dir_path: impl Into<&Path> + Send, unmodified_since: Option>, - ) -> Result>>; + ) -> Result>>; } #[async_trait] impl ObjectStoreExt for O { - async fn read_dir_all( - &self, + async fn read_dir_all<'a>( + &'a self, dir_path: impl Into<&Path> + Send, unmodified_since: Option>, - ) -> Result>> { + ) -> Result>> { let mut output = self.list(Some(dir_path.into())); if let Some(unmodified_since_val) = unmodified_since { output = output @@ -124,212 +129,6 @@ impl std::fmt::Display for ObjectStore { } } -pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send { - fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result; -} - -#[derive(Default, Debug)] -pub struct ObjectStoreRegistry { - providers: HashMap>, -} - -impl ObjectStoreRegistry { - pub fn insert(&mut self, scheme: &str, provider: Arc) { - self.providers.insert(scheme.into(), provider); - } -} - -const AWS_CREDS_CACHE_KEY: &str = "aws_credentials"; - -/// Adapt an AWS SDK cred into object_store credentials -#[derive(Debug)] -pub struct AwsCredentialAdapter { - pub inner: Arc, - - // RefCell can't be shared across threads, so we use HashMap - cache: Arc>>>, - - // The amount of time before expiry to refresh credentials - credentials_refresh_offset: Duration, -} - -impl AwsCredentialAdapter { - pub fn new( - provider: Arc, - credentials_refresh_offset: Duration, - ) -> Self { - Self { - inner: provider, - cache: Arc::new(RwLock::new(HashMap::new())), - credentials_refresh_offset, - } - } -} - -#[async_trait] -impl CredentialProvider for AwsCredentialAdapter { - type Credential = ObjectStoreAwsCredential; - - async fn get_credential(&self) -> ObjectStoreResult> { - let cached_creds = { - let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned(); - let expired = cache_value - .clone() - .map(|cred| { - cred.expiry() - .map(|exp| { - exp.checked_sub(self.credentials_refresh_offset) - .expect("this time should always be valid") - < SystemTime::now() - }) - // no expiry is never expire - .unwrap_or(false) - }) - .unwrap_or(true); // no cred is the same as expired; - if expired { - None - } else { - cache_value.clone() - } - }; - - if let Some(creds) = cached_creds { - Ok(Arc::new(Self::Credential { - key_id: creds.access_key_id().to_string(), - secret_key: creds.secret_access_key().to_string(), - token: creds.session_token().map(|s| s.to_string()), - })) - } else { - let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err( - |e| Error::Internal { - message: format!("Failed to get AWS credentials: {}", e), - location: location!(), - }, - )?); - - self.cache - .write() - .await - .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone()); - - Ok(Arc::new(Self::Credential { - key_id: refreshed_creds.access_key_id().to_string(), - secret_key: refreshed_creds.secret_access_key().to_string(), - token: refreshed_creds.session_token().map(|s| s.to_string()), - })) - } - } -} - -/// Figure out the S3 region of the bucket. -/// -/// This resolves in order of precedence: -/// 1. The region provided in the storage options -/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket -/// -/// It can return None if no region is provided and the endpoint is set. -async fn resolve_s3_region( - url: &Url, - storage_options: &HashMap, -) -> Result> { - if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) { - Ok(Some(region.clone())) - } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() { - // If no endpoint is set, we can assume this is AWS S3 and the region - // can be resolved from the bucket. - let bucket = url.host_str().ok_or_else(|| { - Error::invalid_input( - format!("Could not parse bucket from url: {}", url), - location!(), - ) - })?; - - let mut client_options = ClientOptions::default(); - for (key, value) in storage_options { - if let AmazonS3ConfigKey::Client(client_key) = key { - client_options = client_options.with_config(*client_key, value.clone()); - } - } - - let bucket_region = - object_store::aws::resolve_bucket_region(bucket, &client_options).await?; - Ok(Some(bucket_region)) - } else { - Ok(None) - } -} - -/// Build AWS credentials -/// -/// This resolves credentials from the following sources in order: -/// 1. An explicit `credentials` provider -/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`, -/// `aws_secret_access_key`, `aws_session_token`) -/// 3. The default credential provider chain from AWS SDK. -/// -/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials. -pub async fn build_aws_credential( - credentials_refresh_offset: Duration, - credentials: Option, - storage_options: Option<&HashMap>, - region: Option, -) -> Result<(AwsCredentialProvider, String)> { - // TODO: make this return no credential provider not using AWS - use aws_config::meta::region::RegionProviderChain; - const DEFAULT_REGION: &str = "us-west-2"; - - let region = if let Some(region) = region { - region - } else { - RegionProviderChain::default_provider() - .or_else(DEFAULT_REGION) - .region() - .await - .map(|r| r.as_ref().to_string()) - .unwrap_or(DEFAULT_REGION.to_string()) - }; - - if let Some(creds) = credentials { - Ok((creds, region)) - } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) { - Ok((Arc::new(creds), region)) - } else { - let credentials_provider = DefaultCredentialsChain::builder().build().await; - - Ok(( - Arc::new(AwsCredentialAdapter::new( - Arc::new(credentials_provider), - credentials_refresh_offset, - )), - region, - )) - } -} - -fn extract_static_s3_credentials( - options: &HashMap, -) -> Option> { - let key_id = options - .get(&AmazonS3ConfigKey::AccessKeyId) - .map(|s| s.to_string()); - let secret_key = options - .get(&AmazonS3ConfigKey::SecretAccessKey) - .map(|s| s.to_string()); - let token = options - .get(&AmazonS3ConfigKey::Token) - .map(|s| s.to_string()); - match (key_id, secret_key, token) { - (Some(key_id), Some(secret_key), token) => { - Some(StaticCredentialProvider::new(ObjectStoreAwsCredential { - key_id, - secret_key, - token, - })) - } - _ => None, - } -} - pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync { fn wrap(&self, original: Arc) -> Arc; } @@ -339,8 +138,10 @@ pub trait WrappingObjectStore: std::fmt::Debug + Send + Sync { #[derive(Debug, Clone)] pub struct ObjectStoreParams { pub block_size: Option, + #[deprecated(note = "Implement an ObjectStoreProvider instead")] pub object_store: Option<(Arc, Url)>, pub s3_credentials_refresh_offset: Duration, + #[cfg(feature = "aws")] pub aws_credentials: Option, pub object_store_wrapper: Option>, pub storage_options: Option>, @@ -354,10 +155,12 @@ pub struct ObjectStoreParams { impl Default for ObjectStoreParams { fn default() -> Self { + #[allow(deprecated)] Self { object_store: None, block_size: None, s3_credentials_refresh_offset: Duration::from_secs(60), + #[cfg(feature = "aws")] aws_credentials: None, object_store_wrapper: None, storage_options: None, @@ -367,26 +170,107 @@ impl Default for ObjectStoreParams { } } -impl ObjectStoreParams { - /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials. - pub fn with_aws_credentials( - aws_credentials: Option, - region: Option, - ) -> Self { - Self { - aws_credentials, - storage_options: region - .map(|region| [("region".into(), region)].iter().cloned().collect()), - ..Default::default() +// We implement hash for caching +impl std::hash::Hash for ObjectStoreParams { + #[allow(deprecated)] + fn hash(&self, state: &mut H) { + // For hashing, we use pointer values for ObjectStore, S3 credentials, and wrapper + self.block_size.hash(state); + if let Some((store, url)) = &self.object_store { + Arc::as_ptr(store).hash(state); + url.hash(state); + } + self.s3_credentials_refresh_offset.hash(state); + #[cfg(feature = "aws")] + if let Some(aws_credentials) = &self.aws_credentials { + Arc::as_ptr(aws_credentials).hash(state); + } + if let Some(wrapper) = &self.object_store_wrapper { + Arc::as_ptr(wrapper).hash(state); + } + if let Some(storage_options) = &self.storage_options { + for (key, value) in storage_options { + key.hash(state); + value.hash(state); + } + } + self.use_constant_size_upload_parts.hash(state); + self.list_is_lexically_ordered.hash(state); + } +} + +// We implement eq for caching +impl Eq for ObjectStoreParams {} +impl PartialEq for ObjectStoreParams { + #[allow(deprecated)] + fn eq(&self, other: &Self) -> bool { + // For equality, we use pointer comparison for ObjectStore, S3 credentials, and wrapper + self.block_size == other.block_size + && self + .object_store + .as_ref() + .map(|(store, url)| (Arc::as_ptr(store), url)) + == other + .object_store + .as_ref() + .map(|(store, url)| (Arc::as_ptr(store), url)) + && self.s3_credentials_refresh_offset == other.s3_credentials_refresh_offset + && self.aws_credentials.as_ref().map(Arc::as_ptr) + == other.aws_credentials.as_ref().map(Arc::as_ptr) + && self.object_store_wrapper.as_ref().map(Arc::as_ptr) + == other.object_store_wrapper.as_ref().map(Arc::as_ptr) + && self.storage_options == other.storage_options + && self.use_constant_size_upload_parts == other.use_constant_size_upload_parts + && self.list_is_lexically_ordered == other.list_is_lexically_ordered + } +} + +fn uri_to_url(uri: &str) -> Result { + match Url::parse(uri) { + Ok(url) if url.scheme().len() == 1 && cfg!(windows) => { + // On Windows, the drive is parsed as a scheme + local_path_to_url(uri) } + Ok(url) => Ok(url), + Err(_) => local_path_to_url(uri), } } +fn expand_path(str_path: impl AsRef) -> Result { + let expanded = tilde(str_path.as_ref()).to_string(); + + let mut expanded_path = path_abs::PathAbs::new(expanded) + .unwrap() + .as_path() + .to_path_buf(); + // path_abs::PathAbs::new(".") returns an empty string. + if let Some(s) = expanded_path.as_path().to_str() { + if s.is_empty() { + expanded_path = std::env::current_dir()?; + } + } + + Ok(expanded_path) +} + +fn local_path_to_url(str_path: &str) -> Result { + let expanded_path = expand_path(str_path)?; + + Url::from_directory_path(expanded_path).map_err(|_| Error::InvalidInput { + source: format!("Invalid table location: '{}'", str_path).into(), + location: location!(), + }) +} + impl ObjectStore { /// Parse from a string URI. /// /// Returns the ObjectStore instance and the absolute path to the object. - pub async fn from_uri(uri: &str) -> Result<(Self, Path)> { + /// + /// This uses the default [ObjectStoreRegistry] to find the object store. To + /// allow for potential re-use of object store instances, it's recommended to + /// create a shared [ObjectStoreRegistry] and pass that to [Self::from_uri_and_params]. + pub async fn from_uri(uri: &str) -> Result<(Arc, Path)> { let registry = Arc::new(ObjectStoreRegistry::default()); Self::from_uri_and_params(registry, uri, &ObjectStoreParams::default()).await @@ -399,7 +283,8 @@ impl ObjectStore { registry: Arc, uri: &str, params: &ObjectStoreParams, - ) -> Result<(Self, Path)> { + ) -> Result<(Arc, Path)> { + #[allow(deprecated)] if let Some((store, path)) = params.object_store.as_ref() { let mut inner = store.clone(); if let Some(wrapper) = params.object_store_wrapper.as_ref() { @@ -415,96 +300,46 @@ impl ObjectStore { download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, }; let path = Path::from(path.path()); - return Ok((store, path)); + return Ok((Arc::new(store), path)); } - let (object_store, path) = match Url::parse(uri) { - Ok(url) if url.scheme().len() == 1 && cfg!(windows) => { - // On Windows, the drive is parsed as a scheme - Self::from_path(uri) - } - Ok(url) => { - let store = Self::new_from_url(registry, url.clone(), params.clone()).await?; - Ok((store, Path::from(url.path()))) - } - Err(_) => Self::from_path(uri), - }?; + let url = uri_to_url(uri)?; + let store = registry.get_store(url.clone(), params).await?; + // We know the scheme is valid if we got a store back. + let provider = registry.get_provider(url.scheme()).expect_ok()?; + let path = provider.extract_path(&url); - Ok(( - Self { - inner: params - .object_store_wrapper - .as_ref() - .map(|w| w.wrap(object_store.inner.clone())) - .unwrap_or(object_store.inner), - ..object_store - }, - path, - )) + Ok((store, path)) } - pub fn from_path_with_scheme(str_path: &str, scheme: &str) -> Result<(Self, Path)> { - let expanded = tilde(str_path).to_string(); - - let mut expanded_path = path_abs::PathAbs::new(expanded) - .unwrap() - .as_path() - .to_path_buf(); - // path_abs::PathAbs::new(".") returns an empty string. - if let Some(s) = expanded_path.as_path().to_str() { - if s.is_empty() { - expanded_path = std::env::current_dir()?; - } - } - Ok(( - Self { - inner: Arc::new(LocalFileSystem::new()).traced(), - scheme: String::from(scheme), - block_size: 4 * 1024, // 4KB block size - use_constant_size_upload_parts: false, - list_is_lexically_ordered: false, - io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM, - download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, - }, - Path::from_absolute_path(expanded_path.as_path())?, - )) - } - - pub fn from_path(str_path: &str) -> Result<(Self, Path)> { - Self::from_path_with_scheme(str_path, "file") - } - - async fn new_from_url( - registry: Arc, - url: Url, - params: ObjectStoreParams, - ) -> Result { - configure_store(registry, url.as_str(), params).await + #[deprecated(note = "Use `from_uri` instead")] + pub fn from_path(str_path: &str) -> Result<(Arc, Path)> { + Self::from_uri_and_params( + Arc::new(ObjectStoreRegistry::default()), + str_path, + &Default::default(), + ) + .now_or_never() + .unwrap() } /// Local object store. pub fn local() -> Self { - Self { - inner: Arc::new(LocalFileSystem::new()).traced(), - scheme: String::from("file"), - block_size: 4 * 1024, // 4KB block size - use_constant_size_upload_parts: false, - list_is_lexically_ordered: false, - io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM, - download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, - } + let provider = FileStoreProvider; + provider + .new_store(Url::parse("file:///").unwrap(), &Default::default()) + .now_or_never() + .unwrap() + .unwrap() } /// Create a in-memory object store directly for testing. pub fn memory() -> Self { - Self { - inner: Arc::new(InMemory::new()).traced(), - scheme: String::from("memory"), - block_size: 64 * 1024, - use_constant_size_upload_parts: false, - list_is_lexically_ordered: true, - io_parallelism: get_num_compute_intensive_cpus(), - download_retry_count: DEFAULT_DOWNLOAD_RETRY_COUNT, - } + let provider = MemoryStoreProvider; + provider + .new_store(Url::parse("memory:///").unwrap(), &Default::default()) + .now_or_never() + .unwrap() + .unwrap() } /// Returns true if the object store pointed to a local file system. @@ -520,14 +355,6 @@ impl ObjectStore { self.block_size } - pub fn set_block_size(&mut self, new_size: usize) { - self.block_size = new_size; - } - - pub fn set_io_parallelism(&mut self, io_parallelism: usize) { - self.io_parallelism = io_parallelism; - } - pub fn io_parallelism(&self) -> usize { std::env::var("LANCE_IO_THREADS") .map(|val| val.parse::().unwrap()) @@ -572,14 +399,16 @@ impl ObjectStore { /// Create an [ObjectWriter] from local [std::path::Path] pub async fn create_local_writer(path: &std::path::Path) -> Result { let object_store = Self::local(); - let os_path = Path::from(path.to_str().unwrap()); + let absolute_path = expand_path(path.to_string_lossy())?; + let os_path = Path::from_absolute_path(absolute_path)?; object_store.create(&os_path).await } /// Open an [Reader] from local [std::path::Path] pub async fn open_local(path: &std::path::Path) -> Result> { let object_store = Self::local(); - let os_path = Path::from(path.to_str().unwrap()); + let absolute_path = expand_path(path.to_string_lossy())?; + let os_path = Path::from_absolute_path(absolute_path)?; object_store.open(&os_path).await } @@ -589,7 +418,7 @@ impl ObjectStore { } /// A helper function to create a file and write content to it. - pub async fn put(&self, path: &Path, content: &[u8]) -> Result<()> { + pub async fn put(&self, path: &Path, content: &[u8]) -> Result { let mut writer = self.create(path).await?; writer.write_all(content).await?; writer.shutdown().await @@ -617,6 +446,13 @@ impl ObjectStore { .collect()) } + pub fn list( + &self, + path: Option, + ) -> Pin> + Send>> { + Box::pin(ListRetryStream::new(self.inner.clone(), path, 5).map(|m| m.map_err(|e| e.into()))) + } + /// Read all files (start from base directory) recursively /// /// unmodified_since can be specified to only return files that have not been modified since the given time. @@ -652,7 +488,7 @@ impl ObjectStore { pub fn remove_stream<'a>( &'a self, locations: BoxStream<'a, Result>, - ) -> BoxStream> { + ) -> BoxStream<'a, Result> { self.inner .delete_stream(locations.err_into::().boxed()) .err_into::() @@ -726,49 +562,13 @@ impl StorageOptions { if let Ok(value) = std::env::var("AWS_ALLOW_HTTP") { options.insert("allow_http".into(), value); } - Self(options) - } - - /// Add values from the environment to storage options - pub fn with_env_azure(&mut self) { - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) { - if !self.0.contains_key(config_key.as_ref()) { - self.0 - .insert(config_key.as_ref().to_string(), value.to_string()); - } - } - } - } - } - - /// Add values from the environment to storage options - pub fn with_env_gcs(&mut self) { - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if let Ok(config_key) = GoogleConfigKey::from_str(&key.to_ascii_lowercase()) { - if !self.0.contains_key(config_key.as_ref()) { - self.0 - .insert(config_key.as_ref().to_string(), value.to_string()); - } - } - } + if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_MAX_RETRIES") { + options.insert("client_max_retries".into(), value); } - } - - /// Add values from the environment to storage options - pub fn with_env_s3(&mut self) { - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) { - if !self.0.contains_key(config_key.as_ref()) { - self.0 - .insert(config_key.as_ref().to_string(), value.to_string()); - } - } - } + if let Ok(value) = std::env::var("OBJECT_STORE_CLIENT_RETRY_TIMEOUT") { + options.insert("client_retry_timeout".into(), value); } + Self(options) } /// Denotes if unsecure connections via http are allowed @@ -782,42 +582,31 @@ impl StorageOptions { pub fn download_retry_count(&self) -> usize { self.0 .iter() - .find(|(key, _)| key.to_ascii_lowercase() == "download_retry_count") + .find(|(key, _)| key.eq_ignore_ascii_case("download_retry_count")) .map(|(_, value)| value.parse::().unwrap_or(3)) .unwrap_or(3) } - /// Subset of options relevant for azure storage - pub fn as_azure_options(&self) -> HashMap { + /// Max retry times to set in RetryConfig for object store client + pub fn client_max_retries(&self) -> usize { self.0 .iter() - .filter_map(|(key, value)| { - let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; - Some((az_key, value.clone())) - }) - .collect() + .find(|(key, _)| key.eq_ignore_ascii_case("client_max_retries")) + .and_then(|(_, value)| value.parse::().ok()) + .unwrap_or(10) } - /// Subset of options relevant for s3 storage - pub fn as_s3_options(&self) -> HashMap { + /// Seconds of timeout to set in RetryConfig for object store client + pub fn client_retry_timeout(&self) -> u64 { self.0 .iter() - .filter_map(|(key, value)| { - let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; - Some((s3_key, value.clone())) - }) - .collect() + .find(|(key, _)| key.eq_ignore_ascii_case("client_retry_timeout")) + .and_then(|(_, value)| value.parse::().ok()) + .unwrap_or(180) } - /// Subset of options relevant for gcs storage - pub fn as_gcs_options(&self) -> HashMap { - self.0 - .iter() - .filter_map(|(key, value)| { - let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; - Some((gcs_key, value.clone())) - }) - .collect() + pub fn get(&self, key: &str) -> Option<&String> { + self.0.get(key) } } @@ -827,139 +616,6 @@ impl From> for StorageOptions { } } -async fn configure_store( - registry: Arc, - url: &str, - options: ObjectStoreParams, -) -> Result { - let mut storage_options = StorageOptions(options.storage_options.clone().unwrap_or_default()); - let download_retry_count = storage_options.download_retry_count(); - let mut url = ensure_table_uri(url)?; - // Block size: On local file systems, we use 4KB block size. On cloud - // object stores, we use 64KB block size. This is generally the largest - // block size where we don't see a latency penalty. - match url.scheme() { - "s3" | "s3+ddb" => { - storage_options.with_env_s3(); - - // if url.scheme() == "s3+ddb" && options.commit_handler.is_some() { - // return Err(Error::InvalidInput { - // source: "`s3+ddb://` scheme and custom commit handler are mutually exclusive" - // .into(), - // location: location!(), - // }); - // } - - let storage_options = storage_options.as_s3_options(); - let region = resolve_s3_region(&url, &storage_options).await?; - let (aws_creds, region) = build_aws_credential( - options.s3_credentials_refresh_offset, - options.aws_credentials.clone(), - Some(&storage_options), - region, - ) - .await?; - - // Cloudflare does not support varying part sizes. - let use_constant_size_upload_parts = storage_options - .get(&AmazonS3ConfigKey::Endpoint) - .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com")) - .unwrap_or(false); - - // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts - url.set_scheme("s3").map_err(|()| Error::Internal { - message: "could not set scheme".into(), - location: location!(), - })?; - - url.set_query(None); - - // we can't use parse_url_opts here because we need to manually set the credentials provider - let mut builder = AmazonS3Builder::new(); - for (key, value) in storage_options { - builder = builder.with_config(key, value); - } - builder = builder - .with_url(url.as_ref()) - .with_credentials(aws_creds) - .with_region(region); - let store = builder.build()?; - - Ok(ObjectStore { - inner: Arc::new(store), - scheme: String::from(url.scheme()), - block_size: 64 * 1024, - use_constant_size_upload_parts, - list_is_lexically_ordered: true, - io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, - download_retry_count, - }) - } - "gs" => { - storage_options.with_env_gcs(); - let mut builder = GoogleCloudStorageBuilder::new().with_url(url.as_ref()); - for (key, value) in storage_options.as_gcs_options() { - builder = builder.with_config(key, value); - } - let store = builder.build()?; - let store = Arc::new(store); - - Ok(ObjectStore { - inner: store, - scheme: String::from("gs"), - block_size: 64 * 1024, - use_constant_size_upload_parts: false, - list_is_lexically_ordered: true, - io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, - download_retry_count, - }) - } - "az" => { - storage_options.with_env_azure(); - let (store, _) = parse_url_opts(&url, storage_options.as_azure_options())?; - let store = Arc::new(store); - - Ok(ObjectStore { - inner: store, - scheme: String::from("az"), - block_size: 64 * 1024, - use_constant_size_upload_parts: false, - list_is_lexically_ordered: true, - io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, - download_retry_count, - }) - } - // we have a bypass logic to use `tokio::fs` directly to lower overhead - // however this makes testing harder as we can't use the same code path - // "file-object-store" forces local file system dataset to use the same - // code path as cloud object stores - "file" => Ok(ObjectStore::from_path(url.path())?.0), - "file-object-store" => { - Ok(ObjectStore::from_path_with_scheme(url.path(), "file-object-store")?.0) - } - "memory" => Ok(ObjectStore { - inner: Arc::new(InMemory::new()).traced(), - scheme: String::from("memory"), - block_size: 64 * 1024, - use_constant_size_upload_parts: false, - list_is_lexically_ordered: true, - io_parallelism: get_num_compute_intensive_cpus(), - download_retry_count, - }), - unknown_scheme => { - if let Some(provider) = registry.providers.get(unknown_scheme) { - provider.new_store(url, &options) - } else { - let err = lance_core::Error::from(object_store::Error::NotSupported { - source: format!("Unsupported URI scheme: {} in url {}", unknown_scheme, url) - .into(), - }); - Err(err) - } - } - } -} - impl ObjectStore { #[allow(clippy::too_many_arguments)] pub fn new( @@ -1002,88 +658,12 @@ fn infer_block_size(scheme: &str) -> usize { } } -fn str_is_truthy(val: &str) -> bool { - val.eq_ignore_ascii_case("1") - | val.eq_ignore_ascii_case("true") - | val.eq_ignore_ascii_case("on") - | val.eq_ignore_ascii_case("yes") - | val.eq_ignore_ascii_case("y") -} - -/// Attempt to create a Url from given table location. -/// -/// The location could be: -/// * A valid URL, which will be parsed and returned -/// * A path to a directory, which will be created and then converted to a URL. -/// -/// If it is a local path, it will be created if it doesn't exist. -/// -/// Extra slashes will be removed from the end path as well. -/// -/// Will return an error if the location is not valid. For example, -pub fn ensure_table_uri(table_uri: impl AsRef) -> Result { - let table_uri = table_uri.as_ref(); - - enum UriType { - LocalPath(PathBuf), - Url(Url), - } - let uri_type: UriType = if let Ok(url) = Url::parse(table_uri) { - if url.scheme() == "file" { - UriType::LocalPath(url.to_file_path().map_err(|err| { - let msg = format!("Invalid table location: {}\nError: {:?}", table_uri, err); - Error::InvalidTableLocation { message: msg } - })?) - // NOTE this check is required to support absolute windows paths which may properly parse as url - } else { - UriType::Url(url) - } - } else { - UriType::LocalPath(PathBuf::from(table_uri)) - }; - - // If it is a local path, we need to create it if it does not exist. - let mut url = match uri_type { - UriType::LocalPath(path) => { - let path = std::fs::canonicalize(path).map_err(|err| Error::DatasetNotFound { - path: table_uri.to_string(), - source: Box::new(err), - location: location!(), - })?; - Url::from_directory_path(path).map_err(|_| { - let msg = format!( - "Could not construct a URL from canonicalized path: {}.\n\ - Something must be very wrong with the table path.", - table_uri - ); - Error::InvalidTableLocation { message: msg } - })? - } - UriType::Url(url) => url, - }; - - let trimmed_path = url.path().trim_end_matches('/').to_owned(); - url.set_path(&trimmed_path); - Ok(url) -} - -lazy_static::lazy_static! { - static ref KNOWN_SCHEMES: Vec<&'static str> = - Vec::from([ - "s3", - "s3+ddb", - "gs", - "az", - "file", - "file-object-store", - "memory" - ]); -} - #[cfg(test)] mod tests { use super::*; + use object_store::memory::InMemory; use parquet::data_type::AsBytes; + use rstest::rstest; use std::env::set_current_dir; use std::fs::{create_dir_all, write}; use std::path::Path as StdPath; @@ -1097,7 +677,7 @@ mod tests { write(path, contents) } - async fn read_from_store(store: ObjectStore, path: &Path) -> Result { + async fn read_from_store(store: &ObjectStore, path: &Path) -> Result { let test_file_store = store.open(path).await.unwrap(); let size = test_file_store.size().await.unwrap(); let bytes = test_file_store.get_range(0..size).await.unwrap(); @@ -1122,7 +702,7 @@ mod tests { format!("{tmp_path}/bar/foo.lance/../foo.lance"), ] { let (store, path) = ObjectStore::from_uri(uri).await.unwrap(); - let contents = read_from_store(store, &path.child("test_file")) + let contents = read_from_store(store.as_ref(), &path.child("test_file")) .await .unwrap(); assert_eq!(contents, "TEST_CONTENT"); @@ -1149,6 +729,65 @@ mod tests { assert_eq!(path.to_string(), "foo.lance"); } + async fn test_block_size_used_test_helper( + uri: &str, + storage_options: Option>, + default_expected_block_size: usize, + ) { + // Test the default + let registry = Arc::new(ObjectStoreRegistry::default()); + let params = ObjectStoreParams { + storage_options: storage_options.clone(), + ..ObjectStoreParams::default() + }; + let (store, _) = ObjectStore::from_uri_and_params(registry, uri, ¶ms) + .await + .unwrap(); + assert_eq!(store.block_size, default_expected_block_size); + + // Ensure param is used + let registry = Arc::new(ObjectStoreRegistry::default()); + let params = ObjectStoreParams { + block_size: Some(1024), + storage_options: storage_options.clone(), + ..ObjectStoreParams::default() + }; + let (store, _) = ObjectStore::from_uri_and_params(registry, uri, ¶ms) + .await + .unwrap(); + assert_eq!(store.block_size, 1024); + } + + #[rstest] + #[case("s3://bucket/foo.lance", None)] + #[case("gs://bucket/foo.lance", None)] + #[case("az://account/bucket/foo.lance", + Some(HashMap::from([ + (String::from("account_name"), String::from("account")), + (String::from("container_name"), String::from("container")) + ])))] + #[tokio::test] + async fn test_block_size_used_cloud( + #[case] uri: &str, + #[case] storage_options: Option>, + ) { + test_block_size_used_test_helper(uri, storage_options, 64 * 1024).await; + } + + #[rstest] + #[case("file")] + #[case("file-object-store")] + #[case("memory:///bucket/foo.lance")] + #[tokio::test] + async fn test_block_size_used_file(#[case] prefix: &str) { + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = tmp_dir.path().to_str().unwrap().to_owned(); + let path = format!("{tmp_path}/bar/foo.lance/test_file"); + write_to_file(&path, "URL").unwrap(); + let uri = format!("{prefix}:///{path}"); + test_block_size_used_test_helper(&uri, None, 4 * 1024).await; + } + #[tokio::test] async fn test_relative_paths() { let tmp_dir = tempfile::tempdir().unwrap(); @@ -1162,7 +801,7 @@ mod tests { set_current_dir(StdPath::new(&tmp_path)).expect("Error changing current dir"); let (store, path) = ObjectStore::from_uri("./bar/foo.lance").await.unwrap(); - let contents = read_from_store(store, &path.child("test_file")) + let contents = read_from_store(store.as_ref(), &path.child("test_file")) .await .unwrap(); assert_eq!(contents, "RELATIVE_URL"); @@ -1173,7 +812,7 @@ mod tests { let uri = "~/foo.lance"; write_to_file(&format!("{uri}/test_file"), "TILDE").unwrap(); let (store, path) = ObjectStore::from_uri(uri).await.unwrap(); - let contents = read_from_store(store, &path.child("test_file")) + let contents = read_from_store(store.as_ref(), &path.child("test_file")) .await .unwrap(); assert_eq!(contents, "TILDE"); @@ -1275,54 +914,6 @@ mod tests { assert_eq!(Arc::strong_count(&mock_inner_store), 2); } - #[derive(Debug, Default)] - struct MockAwsCredentialsProvider { - called: AtomicBool, - } - - #[async_trait] - impl CredentialProvider for MockAwsCredentialsProvider { - type Credential = ObjectStoreAwsCredential; - - async fn get_credential(&self) -> ObjectStoreResult> { - self.called.store(true, Ordering::Relaxed); - Ok(Arc::new(Self::Credential { - key_id: "".to_string(), - secret_key: "".to_string(), - token: None, - })) - } - } - - #[tokio::test] - async fn test_injected_aws_creds_option_is_used() { - let mock_provider = Arc::new(MockAwsCredentialsProvider::default()); - let registry = Arc::new(ObjectStoreRegistry::default()); - - let params = ObjectStoreParams { - aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider), - ..ObjectStoreParams::default() - }; - - // Not called yet - assert!(!mock_provider.called.load(Ordering::Relaxed)); - - let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", ¶ms) - .await - .unwrap(); - - // fails, but we don't care - let _ = store - .open(&Path::parse("/").unwrap()) - .await - .unwrap() - .get_range(0..1) - .await; - - // Not called yet - assert!(mock_provider.called.load(Ordering::Relaxed)); - } - #[tokio::test] async fn test_local_paths() { let temp_dir = tempfile::tempdir().unwrap(); @@ -1396,7 +987,7 @@ mod tests { format!("{drive_letter}:\\test_folder\\test.lance"), ] { let (store, base) = ObjectStore::from_uri(uri).await.unwrap(); - let contents = read_from_store(store, &base.child("test_file")) + let contents = read_from_store(store.as_ref(), &base.child("test_file")) .await .unwrap(); assert_eq!(contents, "WINDOWS"); diff --git a/rust/lance-io/src/object_store/list_retry.rs b/rust/lance-io/src/object_store/list_retry.rs new file mode 100644 index 00000000000..64e6eb1e5d2 --- /dev/null +++ b/rust/lance-io/src/object_store/list_retry.rs @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{pin::Pin, sync::Arc, task::Poll}; + +use futures::{Stream, StreamExt}; +use object_store::{path::Path, ObjectMeta, ObjectStore}; +use tokio::task::JoinHandle; + +/// ObjectStore::list() and ObjectStore::list_with_offset() return a stream +/// where the lifetime is tied to the object store. This makes it hard to wrap. +/// So here we put it inside a tokio task and return a channel receiver. +struct StaticListStream { + rx: tokio::sync::mpsc::Receiver>, + handle: JoinHandle<()>, +} + +impl StaticListStream { + fn new(object_store: Arc, prefix: Option, offset: Option) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(100); + let handle = tokio::spawn(async move { + let mut stream = if let Some(offset) = offset { + object_store.list_with_offset(prefix.as_ref(), &offset) + } else { + object_store.list(prefix.as_ref()) + }; + while let Some(item) = stream.next().await { + if tx.send(item).await.is_err() { + break; + } + } + }); + Self { rx, handle } + } + + fn abort(&self) { + self.handle.abort(); + } +} + +impl Stream for StaticListStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match this.rx.poll_recv(cx) { + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending if this.handle.is_finished() => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +/// A stream that does outer retries on list operations. +/// +/// This is to handle request responses that ObjectStore doesn't handle, such as +/// the error `error decoding response body` from queries to GCS. +pub struct ListRetryStream { + object_store: Arc, + current_stream: StaticListStream, + prefix: Option, + last_successful_key: Option, + max_retries: usize, + current_retries: usize, +} + +impl ListRetryStream { + pub fn new( + object_store: Arc, + prefix: Option, + max_retries: usize, + ) -> Self { + let current_stream = StaticListStream::new(object_store.clone(), prefix.clone(), None); + Self { + object_store, + current_stream, + prefix, + last_successful_key: None, + max_retries, + current_retries: 0, + } + } + + fn is_retryable(error: &object_store::Error) -> bool { + !matches!( + error, + object_store::Error::NotFound { .. } + | object_store::Error::InvalidPath { .. } + | object_store::Error::NotSupported { .. } + | object_store::Error::NotImplemented + ) + } +} + +impl Stream for ListRetryStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + loop { + match this.current_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(meta))) => { + this.last_successful_key = Some(meta.location.clone()); + return Poll::Ready(Some(Ok(meta))); + } + Poll::Ready(None) => { + // If the stream is done, return None + return Poll::Ready(None); + } + Poll::Ready(Some(Err(error))) if Self::is_retryable(&error) => { + if this.current_retries < this.max_retries { + this.current_retries += 1; + + this.current_stream.abort(); + this.current_stream = StaticListStream::new( + this.object_store.clone(), + this.prefix.clone(), + this.last_successful_key.clone(), + ); + + continue; + } else { + return Poll::Ready(Some(Err(error))); + } + } + Poll::Ready(Some(Err(error))) => { + return Poll::Ready(Some(Err(error))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_send() {} + + #[test] + fn test_list_retry_stream_send() { + // Ensure that ListRetryStream is Send + assert_send::(); + } +} diff --git a/rust/lance-io/src/object_store/providers.rs b/rust/lance-io/src/object_store/providers.rs new file mode 100644 index 00000000000..d07e4719ba2 --- /dev/null +++ b/rust/lance-io/src/object_store/providers.rs @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + collections::HashMap, + sync::{Arc, RwLock, Weak}, +}; + +use object_store::path::Path; +use snafu::location; +use url::Url; + +use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams}; +use lance_core::error::{Error, LanceOptionExt, Result}; + +#[cfg(feature = "aws")] +pub mod aws; +#[cfg(feature = "azure")] +pub mod azure; +#[cfg(feature = "gcp")] +pub mod gcp; +pub mod local; +pub mod memory; + +#[async_trait::async_trait] +pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send { + async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result; + + /// Extract the path relative to the base of the store. + /// + /// For example, in S3 the path is relative to the bucket. So a URL of + /// `s3://bucket/path/to/file` would return `path/to/file`. + /// + /// Meanwhile, for a file store, the path is relative to the filesystem root. + /// So a URL of `file:///path/to/file` would return `/path/to/file`. + fn extract_path(&self, url: &Url) -> Path { + Path::from(url.path()) + } +} + +/// A registry of object store providers. +/// +/// Use [`Self::default()`] to create one with the available default providers. +/// This includes (depending on features enabled): +/// - `memory`: An in-memory object store. +/// - `file`: A local file object store, with optimized code paths. +/// - `file-object-store`: A local file object store that uses the ObjectStore API, +/// for all operations. Used for testing with ObjectStore wrappers. +/// - `s3`: An S3 object store. +/// - `s3+ddb`: An S3 object store with DynamoDB for metadata. +/// - `az`: An Azure Blob Storage object store. +/// - `gs`: A Google Cloud Storage object store. +/// +/// Use [`Self::empty()`] to create an empty registry, with no providers registered. +/// +/// The registry also caches object stores that are currently in use. It holds +/// weak references to the object stores, so they are not held onto. If an object +/// store is no longer in use, it will be removed from the cache on the next +/// call to either [`Self::active_stores()`] or [`Self::get_store()`]. +#[derive(Debug)] +pub struct ObjectStoreRegistry { + providers: RwLock>>, + // Cache of object stores currently in use. We use a weak reference so the + // cache itself doesn't keep them alive if no object store is actually using + // it. + active_stores: RwLock>>, +} + +/// Convert a URL to a cache key. +/// +/// We truncate to the first path segment. This should capture +/// buckets and prefixes. We keep URL params since those might be +/// important. +/// +/// * s3://bucket/path?param=value -> s3://bucket/path?param=value +/// * file:///path/to/file -> file:/// +fn cache_url(url: &Url) -> String { + if ["file", "file-object-store", "memory"].contains(&url.scheme()) { + // For file URLs, we want to cache the URL without the path. + // This is because the path can be different for different + // object stores, but we want to cache the object store itself. + format!("{}://", url.scheme()) + } else { + // Bucket is parsed as domain, so we just drop the path. + let mut url = url.clone(); + url.set_path(""); + url.to_string() + } +} + +impl ObjectStoreRegistry { + /// Create a new registry with no providers registered. + /// + /// Typically, you want to use [`Self::default()`] instead, so you get the + /// default providers. + pub fn empty() -> Self { + Self { + providers: RwLock::new(HashMap::new()), + active_stores: RwLock::new(HashMap::new()), + } + } + + /// Get the object store provider for a given scheme. + pub fn get_provider(&self, scheme: &str) -> Option> { + self.providers + .read() + .expect("ObjectStoreRegistry lock poisoned") + .get(scheme) + .cloned() + } + + /// Get a list of all active object stores. + /// + /// Calling this will also clean up any weak references to object stores that + /// are no longer valid. + pub fn active_stores(&self) -> Vec> { + let mut found_inactive = false; + let output = self + .active_stores + .read() + .expect("ObjectStoreRegistry lock poisoned") + .values() + .filter_map(|weak| match weak.upgrade() { + Some(store) => Some(store), + None => { + found_inactive = true; + None + } + }) + .collect(); + + if found_inactive { + // Clean up the cache by removing any weak references that are no longer valid + let mut cache_lock = self + .active_stores + .write() + .expect("ObjectStoreRegistry lock poisoned"); + cache_lock.retain(|_, weak| weak.upgrade().is_some()); + } + output + } + + /// Get an object store for a given base path and parameters. + /// + /// If the object store is already in use, it will return a strong reference + /// to the object store. If the object store is not in use, it will create a + /// new object store and return a strong reference to it. + pub async fn get_store( + &self, + base_path: Url, + params: &ObjectStoreParams, + ) -> Result> { + let cache_path = cache_url(&base_path); + let cache_key = (cache_path, params.clone()); + + // Check if we have a cached store for this base path and params + { + let maybe_store = self + .active_stores + .read() + .ok() + .expect_ok()? + .get(&cache_key) + .cloned(); + if let Some(store) = maybe_store { + if let Some(store) = store.upgrade() { + return Ok(store); + } else { + // Remove the weak reference if it is no longer valid + let mut cache_lock = self + .active_stores + .write() + .expect("ObjectStoreRegistry lock poisoned"); + if let Some(store) = cache_lock.get(&cache_key) { + if store.upgrade().is_none() { + // Remove the weak reference if it is no longer valid + cache_lock.remove(&cache_key); + } + } + } + } + } + + let scheme = base_path.scheme(); + let Some(provider) = self.get_provider(scheme) else { + let mut message = format!("No object store provider found for scheme: '{}'", scheme); + if let Ok(providers) = self.providers.read() { + let valid_schemes = providers.keys().cloned().collect::>().join(", "); + message.push_str(&format!("\nValid schemes: {}", valid_schemes)); + } + + return Err(Error::invalid_input(message, location!())); + }; + let mut store = provider.new_store(base_path, params).await?; + + store.inner = store.inner.traced(); + + if let Some(wrapper) = ¶ms.object_store_wrapper { + store.inner = wrapper.wrap(store.inner); + } + + let store = Arc::new(store); + + { + // Insert the store into the cache + let mut cache_lock = self.active_stores.write().ok().expect_ok()?; + cache_lock.insert(cache_key, Arc::downgrade(&store)); + } + + Ok(store) + } +} + +impl Default for ObjectStoreRegistry { + fn default() -> Self { + let mut providers: HashMap> = HashMap::new(); + + providers.insert("memory".into(), Arc::new(memory::MemoryStoreProvider)); + providers.insert("file".into(), Arc::new(local::FileStoreProvider)); + // The "file" scheme has special optimized code paths that bypass + // the ObjectStore API for better performance. However, this can make it + // hard to test when using ObjectStore wrappers, such as IOTrackingStore. + // So we provide a "file-object-store" scheme that uses the ObjectStore API. + // The specialized code paths are differentiated by the scheme name. + providers.insert( + "file-object-store".into(), + Arc::new(local::FileStoreProvider), + ); + + #[cfg(feature = "aws")] + { + let aws = Arc::new(aws::AwsStoreProvider); + providers.insert("s3".into(), aws.clone()); + providers.insert("s3+ddb".into(), aws); + } + #[cfg(feature = "azure")] + providers.insert("az".into(), Arc::new(azure::AzureBlobStoreProvider)); + #[cfg(feature = "gcp")] + providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider)); + Self { + providers: RwLock::new(providers), + active_stores: RwLock::new(HashMap::new()), + } + } +} + +impl ObjectStoreRegistry { + /// Add a new object store provider to the registry. The provider will be used + /// in [`Self::get_store()`] when a URL is passed with a matching scheme. + pub fn insert(&self, scheme: &str, provider: Arc) { + self.providers + .write() + .expect("ObjectStoreRegistry lock poisoned") + .insert(scheme.into(), provider); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cache_url() { + let cases = [ + ("s3://bucket/path?param=value", "s3://bucket?param=value"), + ("file:///path/to/file", "file://"), + ("file-object-store:///path/to/file", "file-object-store://"), + ("memory:///", "memory://"), + ( + "http://example.com/path?param=value", + "http://example.com/?param=value", + ), + ]; + + for (url, expected_cache_url) in cases { + let url = Url::parse(url).unwrap(); + let cache_url = cache_url(&url); + assert_eq!(cache_url, expected_cache_url); + } + } +} diff --git a/rust/lance-io/src/object_store/providers/aws.rs b/rust/lance-io/src/object_store/providers/aws.rs new file mode 100644 index 00000000000..d1dbbc7ac58 --- /dev/null +++ b/rust/lance-io/src/object_store/providers/aws.rs @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + collections::HashMap, + str::FromStr, + sync::Arc, + time::{Duration, SystemTime}, +}; + +use aws_config::default_provider::credentials::DefaultCredentialsChain; +use aws_credential_types::provider::ProvideCredentials; +use object_store::{ + aws::{ + AmazonS3Builder, AmazonS3ConfigKey, AwsCredential as ObjectStoreAwsCredential, + AwsCredentialProvider, + }, + ClientOptions, CredentialProvider, Result as ObjectStoreResult, RetryConfig, + StaticCredentialProvider, +}; +use snafu::location; +use tokio::sync::RwLock; +use url::Url; + +use crate::object_store::{ + ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE, + DEFAULT_CLOUD_IO_PARALLELISM, +}; +use lance_core::error::{Error, Result}; + +#[derive(Default, Debug)] +pub struct AwsStoreProvider; + +#[async_trait::async_trait] +impl ObjectStoreProvider for AwsStoreProvider { + async fn new_store( + &self, + mut base_path: Url, + params: &ObjectStoreParams, + ) -> Result { + let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE); + let mut storage_options = + StorageOptions(params.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); + + let max_retries = storage_options.client_max_retries(); + let retry_timeout = storage_options.client_retry_timeout(); + let retry_config = RetryConfig { + backoff: Default::default(), + max_retries, + retry_timeout: Duration::from_secs(retry_timeout), + }; + + storage_options.with_env_s3(); + + let mut storage_options = storage_options.as_s3_options(); + let region = resolve_s3_region(&base_path, &storage_options).await?; + let (aws_creds, region) = build_aws_credential( + params.s3_credentials_refresh_offset, + params.aws_credentials.clone(), + Some(&storage_options), + region, + ) + .await?; + + // This will be default in next version of object store. + // https://github.com/apache/arrow-rs/pull/7181 + // We can do this when we upgrade to 0.12. + storage_options + .entry(AmazonS3ConfigKey::ConditionalPut) + .or_insert_with(|| "etag".to_string()); + + // Cloudflare does not support varying part sizes. + let use_constant_size_upload_parts = storage_options + .get(&AmazonS3ConfigKey::Endpoint) + .map(|endpoint| endpoint.contains("r2.cloudflarestorage.com")) + .unwrap_or(false); + + // before creating the OSObjectStore we need to rewrite the url to drop ddb related parts + base_path.set_scheme("s3").unwrap(); + base_path.set_query(None); + + // we can't use parse_url_opts here because we need to manually set the credentials provider + let mut builder = AmazonS3Builder::new(); + for (key, value) in storage_options { + builder = builder.with_config(key, value); + } + builder = builder + .with_url(base_path.as_ref()) + .with_credentials(aws_creds) + .with_retry(retry_config) + .with_region(region); + let inner = Arc::new(builder.build()?); + + Ok(ObjectStore { + inner, + scheme: String::from(base_path.scheme()), + block_size, + use_constant_size_upload_parts, + list_is_lexically_ordered: true, + io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, + }) + } +} + +/// Figure out the S3 region of the bucket. +/// +/// This resolves in order of precedence: +/// 1. The region provided in the storage options +/// 2. (If endpoint is not set), the region returned by the S3 API for the bucket +/// +/// It can return None if no region is provided and the endpoint is set. +async fn resolve_s3_region( + url: &Url, + storage_options: &HashMap, +) -> Result> { + if let Some(region) = storage_options.get(&AmazonS3ConfigKey::Region) { + Ok(Some(region.clone())) + } else if storage_options.get(&AmazonS3ConfigKey::Endpoint).is_none() { + // If no endpoint is set, we can assume this is AWS S3 and the region + // can be resolved from the bucket. + let bucket = url.host_str().ok_or_else(|| { + Error::invalid_input( + format!("Could not parse bucket from url: {}", url), + location!(), + ) + })?; + + let mut client_options = ClientOptions::default(); + for (key, value) in storage_options { + if let AmazonS3ConfigKey::Client(client_key) = key { + client_options = client_options.with_config(*client_key, value.clone()); + } + } + + let bucket_region = + object_store::aws::resolve_bucket_region(bucket, &client_options).await?; + Ok(Some(bucket_region)) + } else { + Ok(None) + } +} + +/// Build AWS credentials +/// +/// This resolves credentials from the following sources in order: +/// 1. An explicit `credentials` provider +/// 2. Explicit credentials in storage_options (as in `aws_access_key_id`, +/// `aws_secret_access_key`, `aws_session_token`) +/// 3. The default credential provider chain from AWS SDK. +/// +/// `credentials_refresh_offset` is the amount of time before expiry to refresh credentials. +pub async fn build_aws_credential( + credentials_refresh_offset: Duration, + credentials: Option, + storage_options: Option<&HashMap>, + region: Option, +) -> Result<(AwsCredentialProvider, String)> { + // TODO: make this return no credential provider not using AWS + use aws_config::meta::region::RegionProviderChain; + const DEFAULT_REGION: &str = "us-west-2"; + + let region = if let Some(region) = region { + region + } else { + RegionProviderChain::default_provider() + .or_else(DEFAULT_REGION) + .region() + .await + .map(|r| r.as_ref().to_string()) + .unwrap_or(DEFAULT_REGION.to_string()) + }; + + if let Some(creds) = credentials { + Ok((creds, region)) + } else if let Some(creds) = storage_options.and_then(extract_static_s3_credentials) { + Ok((Arc::new(creds), region)) + } else { + let credentials_provider = DefaultCredentialsChain::builder().build().await; + + Ok(( + Arc::new(AwsCredentialAdapter::new( + Arc::new(credentials_provider), + credentials_refresh_offset, + )), + region, + )) + } +} + +fn extract_static_s3_credentials( + options: &HashMap, +) -> Option> { + let key_id = options + .get(&AmazonS3ConfigKey::AccessKeyId) + .map(|s| s.to_string()); + let secret_key = options + .get(&AmazonS3ConfigKey::SecretAccessKey) + .map(|s| s.to_string()); + let token = options + .get(&AmazonS3ConfigKey::Token) + .map(|s| s.to_string()); + match (key_id, secret_key, token) { + (Some(key_id), Some(secret_key), token) => { + Some(StaticCredentialProvider::new(ObjectStoreAwsCredential { + key_id, + secret_key, + token, + })) + } + _ => None, + } +} + +/// Adapt an AWS SDK cred into object_store credentials +#[derive(Debug)] +pub struct AwsCredentialAdapter { + pub inner: Arc, + + // RefCell can't be shared across threads, so we use HashMap + cache: Arc>>>, + + // The amount of time before expiry to refresh credentials + credentials_refresh_offset: Duration, +} + +impl AwsCredentialAdapter { + pub fn new( + provider: Arc, + credentials_refresh_offset: Duration, + ) -> Self { + Self { + inner: provider, + cache: Arc::new(RwLock::new(HashMap::new())), + credentials_refresh_offset, + } + } +} + +const AWS_CREDS_CACHE_KEY: &str = "aws_credentials"; + +#[async_trait::async_trait] +impl CredentialProvider for AwsCredentialAdapter { + type Credential = ObjectStoreAwsCredential; + + async fn get_credential(&self) -> ObjectStoreResult> { + let cached_creds = { + let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned(); + let expired = cache_value + .clone() + .map(|cred| { + cred.expiry() + .map(|exp| { + exp.checked_sub(self.credentials_refresh_offset) + .expect("this time should always be valid") + < SystemTime::now() + }) + // no expiry is never expire + .unwrap_or(false) + }) + .unwrap_or(true); // no cred is the same as expired; + if expired { + None + } else { + cache_value.clone() + } + }; + + if let Some(creds) = cached_creds { + Ok(Arc::new(Self::Credential { + key_id: creds.access_key_id().to_string(), + secret_key: creds.secret_access_key().to_string(), + token: creds.session_token().map(|s| s.to_string()), + })) + } else { + let refreshed_creds = Arc::new(self.inner.provide_credentials().await.map_err( + |e| Error::Internal { + message: format!("Failed to get AWS credentials: {}", e), + location: location!(), + }, + )?); + + self.cache + .write() + .await + .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone()); + + Ok(Arc::new(Self::Credential { + key_id: refreshed_creds.access_key_id().to_string(), + secret_key: refreshed_creds.secret_access_key().to_string(), + token: refreshed_creds.session_token().map(|s| s.to_string()), + })) + } + } +} + +impl StorageOptions { + /// Add values from the environment to storage options + pub fn with_env_s3(&mut self) { + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) { + if !self.0.contains_key(config_key.as_ref()) { + self.0 + .insert(config_key.as_ref().to_string(), value.to_string()); + } + } + } + } + } + + /// Subset of options relevant for s3 storage + pub fn as_s3_options(&self) -> HashMap { + self.0 + .iter() + .filter_map(|(key, value)| { + let s3_key = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; + Some((s3_key, value.clone())) + }) + .collect() + } +} + +impl ObjectStoreParams { + /// Create a new instance of [`ObjectStoreParams`] based on the AWS credentials. + pub fn with_aws_credentials( + aws_credentials: Option, + region: Option, + ) -> Self { + Self { + aws_credentials, + storage_options: region + .map(|region| [("region".into(), region)].iter().cloned().collect()), + ..Default::default() + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicBool, Ordering}; + + use object_store::path::Path; + + use crate::object_store::ObjectStoreRegistry; + + use super::*; + + #[derive(Debug, Default)] + struct MockAwsCredentialsProvider { + called: AtomicBool, + } + + #[async_trait::async_trait] + impl CredentialProvider for MockAwsCredentialsProvider { + type Credential = ObjectStoreAwsCredential; + + async fn get_credential(&self) -> ObjectStoreResult> { + self.called.store(true, Ordering::Relaxed); + Ok(Arc::new(Self::Credential { + key_id: "".to_string(), + secret_key: "".to_string(), + token: None, + })) + } + } + + #[tokio::test] + async fn test_injected_aws_creds_option_is_used() { + let mock_provider = Arc::new(MockAwsCredentialsProvider::default()); + let registry = Arc::new(ObjectStoreRegistry::default()); + + let params = ObjectStoreParams { + aws_credentials: Some(mock_provider.clone() as AwsCredentialProvider), + ..ObjectStoreParams::default() + }; + + // Not called yet + assert!(!mock_provider.called.load(Ordering::Relaxed)); + + let (store, _) = ObjectStore::from_uri_and_params(registry, "s3://not-a-bucket", ¶ms) + .await + .unwrap(); + + // fails, but we don't care + let _ = store + .open(&Path::parse("/").unwrap()) + .await + .unwrap() + .get_range(0..1) + .await; + + // Not called yet + assert!(mock_provider.called.load(Ordering::Relaxed)); + } + + #[test] + fn test_s3_path_parsing() { + let provider = AwsStoreProvider; + + let cases = [ + ("s3://bucket/path/to/file", "path/to/file"), + ( + "s3+ddb://bucket/path/to/file?ddbTableName=test", + "path/to/file", + ), + ]; + + for (uri, expected_path) in cases { + let url = Url::parse(uri).unwrap(); + let path = provider.extract_path(&url); + let expected_path = Path::from(expected_path); + assert_eq!(path, expected_path); + } + } +} diff --git a/rust/lance-io/src/object_store/providers/azure.rs b/rust/lance-io/src/object_store/providers/azure.rs new file mode 100644 index 00000000000..16412919431 --- /dev/null +++ b/rust/lance-io/src/object_store/providers/azure.rs @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; + +use object_store::{ + azure::{AzureConfigKey, MicrosoftAzureBuilder}, + RetryConfig, +}; +use url::Url; + +use crate::object_store::{ + ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE, + DEFAULT_CLOUD_IO_PARALLELISM, +}; +use lance_core::error::Result; + +#[derive(Default, Debug)] +pub struct AzureBlobStoreProvider; + +#[async_trait::async_trait] +impl ObjectStoreProvider for AzureBlobStoreProvider { + async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result { + let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE); + let mut storage_options = + StorageOptions(params.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); + + let max_retries = storage_options.client_max_retries(); + let retry_timeout = storage_options.client_retry_timeout(); + let retry_config = RetryConfig { + backoff: Default::default(), + max_retries, + retry_timeout: Duration::from_secs(retry_timeout), + }; + + storage_options.with_env_azure(); + let mut builder = MicrosoftAzureBuilder::new() + .with_url(base_path.as_ref()) + .with_retry(retry_config); + for (key, value) in storage_options.as_azure_options() { + builder = builder.with_config(key, value); + } + let inner = Arc::new(builder.build()?); + + Ok(ObjectStore { + inner, + scheme: String::from("az"), + block_size, + use_constant_size_upload_parts: false, + list_is_lexically_ordered: true, + io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, + }) + } +} + +impl StorageOptions { + /// Add values from the environment to storage options + pub fn with_env_azure(&mut self) { + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if let Ok(config_key) = AzureConfigKey::from_str(&key.to_ascii_lowercase()) { + if !self.0.contains_key(config_key.as_ref()) { + self.0 + .insert(config_key.as_ref().to_string(), value.to_string()); + } + } + } + } + } + + /// Subset of options relevant for azure storage + pub fn as_azure_options(&self) -> HashMap { + self.0 + .iter() + .filter_map(|(key, value)| { + let az_key = AzureConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; + Some((az_key, value.clone())) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_azure_store_path() { + let provider = AzureBlobStoreProvider; + + let url = Url::parse("az://bucket/path/to/file").unwrap(); + let path = provider.extract_path(&url); + let expected_path = object_store::path::Path::from("path/to/file"); + assert_eq!(path, expected_path); + } +} diff --git a/rust/lance-io/src/object_store/providers/gcp.rs b/rust/lance-io/src/object_store/providers/gcp.rs new file mode 100644 index 00000000000..21f4ffc955f --- /dev/null +++ b/rust/lance-io/src/object_store/providers/gcp.rs @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; + +use object_store::{ + gcp::{GcpCredential, GoogleCloudStorageBuilder, GoogleConfigKey}, + RetryConfig, StaticCredentialProvider, +}; +use url::Url; + +use crate::object_store::{ + ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_CLOUD_BLOCK_SIZE, + DEFAULT_CLOUD_IO_PARALLELISM, +}; +use lance_core::error::Result; + +#[derive(Default, Debug)] +pub struct GcsStoreProvider; + +#[async_trait::async_trait] +impl ObjectStoreProvider for GcsStoreProvider { + async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result { + let block_size = params.block_size.unwrap_or(DEFAULT_CLOUD_BLOCK_SIZE); + let mut storage_options = + StorageOptions(params.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); + + let max_retries = storage_options.client_max_retries(); + let retry_timeout = storage_options.client_retry_timeout(); + let retry_config = RetryConfig { + backoff: Default::default(), + max_retries, + retry_timeout: Duration::from_secs(retry_timeout), + }; + + storage_options.with_env_gcs(); + let mut builder = GoogleCloudStorageBuilder::new() + .with_url(base_path.as_ref()) + .with_retry(retry_config); + for (key, value) in storage_options.as_gcs_options() { + builder = builder.with_config(key, value); + } + let token_key = "google_storage_token"; + if let Some(storage_token) = storage_options.get(token_key) { + let credential = GcpCredential { + bearer: storage_token.to_string(), + }; + let credential_provider = Arc::new(StaticCredentialProvider::new(credential)) as _; + builder = builder.with_credentials(credential_provider); + } + let inner = Arc::new(builder.build()?); + + Ok(ObjectStore { + inner, + scheme: String::from("gs"), + block_size, + use_constant_size_upload_parts: false, + list_is_lexically_ordered: true, + io_parallelism: DEFAULT_CLOUD_IO_PARALLELISM, + download_retry_count, + }) + } +} + +impl StorageOptions { + /// Add values from the environment to storage options + pub fn with_env_gcs(&mut self) { + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + let lowercase_key = key.to_ascii_lowercase(); + let token_key = "google_storage_token"; + + if let Ok(config_key) = GoogleConfigKey::from_str(&lowercase_key) { + if !self.0.contains_key(config_key.as_ref()) { + self.0 + .insert(config_key.as_ref().to_string(), value.to_string()); + } + } + // Check for GOOGLE_STORAGE_TOKEN until GoogleConfigKey supports storage token + else if lowercase_key == token_key && !self.0.contains_key(token_key) { + self.0.insert(token_key.to_string(), value.to_string()); + } + } + } + } + + /// Subset of options relevant for gcs storage + pub fn as_gcs_options(&self) -> HashMap { + self.0 + .iter() + .filter_map(|(key, value)| { + let gcs_key = GoogleConfigKey::from_str(&key.to_ascii_lowercase()).ok()?; + Some((gcs_key, value.clone())) + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gcs_store_path() { + let provider = GcsStoreProvider; + + let url = Url::parse("gs://bucket/path/to/file").unwrap(); + let path = provider.extract_path(&url); + let expected_path = object_store::path::Path::from("path/to/file"); + assert_eq!(path, expected_path); + } +} diff --git a/rust/lance-io/src/object_store/providers/local.rs b/rust/lance-io/src/object_store/providers/local.rs new file mode 100644 index 00000000000..fa2b4474ffb --- /dev/null +++ b/rust/lance-io/src/object_store/providers/local.rs @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use object_store::{local::LocalFileSystem, path::Path}; +use url::Url; + +use crate::object_store::{ + ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_LOCAL_BLOCK_SIZE, + DEFAULT_LOCAL_IO_PARALLELISM, +}; +use lance_core::error::Result; + +#[derive(Default, Debug)] +pub struct FileStoreProvider; + +#[async_trait::async_trait] +impl ObjectStoreProvider for FileStoreProvider { + async fn new_store(&self, base_path: Url, params: &ObjectStoreParams) -> Result { + let block_size = params.block_size.unwrap_or(DEFAULT_LOCAL_BLOCK_SIZE); + let storage_options = StorageOptions(params.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); + Ok(ObjectStore { + inner: Arc::new(LocalFileSystem::new()), + scheme: base_path.scheme().to_owned(), + block_size, + use_constant_size_upload_parts: false, + list_is_lexically_ordered: false, + io_parallelism: DEFAULT_LOCAL_IO_PARALLELISM, + download_retry_count, + }) + } + + fn extract_path(&self, url: &Url) -> object_store::path::Path { + url.to_file_path() + .ok() + .and_then(|p| Path::from_absolute_path(p).ok()) + .unwrap_or_else(|| Path::from(url.path())) + } +} + +#[cfg(test)] +mod tests { + use crate::object_store::uri_to_url; + + use super::*; + + #[test] + fn test_file_store_path() { + let provider = FileStoreProvider; + + let cases = [ + ("file:///", ""), + ("file:///usr/local/bin", "usr/local/bin"), + ("file-object-store:///path/to/file", "path/to/file"), + ("file:///path/to/foo/../bar", "path/to/bar"), + ]; + + for (uri, expected_path) in cases { + let url = uri_to_url(uri).unwrap(); + let path = provider.extract_path(&url); + assert_eq!(path.as_ref(), expected_path, "uri: '{}'", uri); + } + } + + #[test] + #[cfg(windows)] + fn test_file_store_path_windows() { + let provider = FileStoreProvider; + + let cases = [ + ( + "C:\\Users\\ADMINI~1\\AppData\\Local\\", + "C:/Users/ADMINI~1/AppData/Local", + ), + ( + "C:\\Users\\ADMINI~1\\AppData\\Local\\..\\", + "C:/Users/ADMINI~1/AppData", + ), + ]; + + for (uri, expected_path) in cases { + let url = uri_to_url(uri).unwrap(); + let path = provider.extract_path(&url); + assert_eq!(path.as_ref(), expected_path); + } + } +} diff --git a/rust/lance-io/src/object_store/providers/memory.rs b/rust/lance-io/src/object_store/providers/memory.rs new file mode 100644 index 00000000000..9c300e878a8 --- /dev/null +++ b/rust/lance-io/src/object_store/providers/memory.rs @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use object_store::{memory::InMemory, path::Path}; +use url::Url; + +use crate::object_store::{ + ObjectStore, ObjectStoreParams, ObjectStoreProvider, StorageOptions, DEFAULT_LOCAL_BLOCK_SIZE, +}; +use lance_core::{error::Result, utils::tokio::get_num_compute_intensive_cpus}; + +/// Provides a fresh in-memory object store for each call to `new_store`. +#[derive(Default, Debug)] +pub struct MemoryStoreProvider; + +#[async_trait::async_trait] +impl ObjectStoreProvider for MemoryStoreProvider { + async fn new_store(&self, _base_path: Url, params: &ObjectStoreParams) -> Result { + let block_size = params.block_size.unwrap_or(DEFAULT_LOCAL_BLOCK_SIZE); + let storage_options = StorageOptions(params.storage_options.clone().unwrap_or_default()); + let download_retry_count = storage_options.download_retry_count(); + Ok(ObjectStore { + inner: Arc::new(InMemory::new()), + scheme: String::from("memory"), + block_size, + use_constant_size_upload_parts: false, + list_is_lexically_ordered: true, + io_parallelism: get_num_compute_intensive_cpus(), + download_retry_count, + }) + } + + fn extract_path(&self, url: &Url) -> Path { + let mut output = String::new(); + if let Some(domain) = url.domain() { + output.push_str(domain); + } + output.push_str(url.path()); + Path::from(output) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memory_store_path() { + let provider = MemoryStoreProvider; + + let url = Url::parse("memory://path/to/file").unwrap(); + let path = provider.extract_path(&url); + let expected_path = Path::from("path/to/file"); + assert_eq!(path, expected_path); + } +} diff --git a/rust/lance-io/src/object_store/tracing.rs b/rust/lance-io/src/object_store/tracing.rs index 2de8241c2bf..f890254000f 100644 --- a/rust/lance-io/src/object_store/tracing.rs +++ b/rust/lance-io/src/object_store/tracing.rs @@ -55,6 +55,7 @@ impl std::fmt::Display for TracedObjectStore { } #[async_trait::async_trait] +#[deny(clippy::missing_trait_methods)] impl object_store::ObjectStore for TracedObjectStore { #[instrument(level = "debug", skip(self, bytes))] async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { @@ -71,6 +72,17 @@ impl object_store::ObjectStore for TracedObjectStore { self.target.put_opts(location, bytes, opts).await } + async fn put_multipart( + &self, + location: &Path, + ) -> OSResult> { + let upload = self.target.put_multipart(location).await?; + Ok(Box::new(TracedMultipartUpload { + target: upload, + write_span: debug_span!("put_multipart"), + })) + } + async fn put_multipart_opts( &self, location: &Path, @@ -83,6 +95,11 @@ impl object_store::ObjectStore for TracedObjectStore { })) } + #[instrument(level = "debug", skip(self, location))] + async fn get(&self, location: &Path) -> OSResult { + self.target.get(location).await + } + #[instrument(level = "debug", skip(self, options))] async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { self.target.get_opts(location, options).await @@ -121,6 +138,15 @@ impl object_store::ObjectStore for TracedObjectStore { self.target.list(prefix) } + #[instrument(level = "debug", skip(self))] + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'_, OSResult> { + self.target.list_with_offset(prefix, offset) + } + #[instrument(level = "debug", skip(self))] async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { self.target.list_with_delimiter(prefix).await @@ -136,6 +162,11 @@ impl object_store::ObjectStore for TracedObjectStore { self.target.rename(from, to).await } + #[instrument(level = "debug", skip(self))] + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.target.rename_if_not_exists(from, to).await + } + #[instrument(level = "debug", skip(self))] async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { self.target.copy_if_not_exists(from, to).await diff --git a/rust/lance-io/src/object_writer.rs b/rust/lance-io/src/object_writer.rs index fa0edb972c2..1106c69d298 100644 --- a/rust/lance-io/src/object_writer.rs +++ b/rust/lance-io/src/object_writer.rs @@ -18,9 +18,10 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::task::JoinSet; use lance_core::{Error, Result}; +use tracing::Instrument; use crate::traits::Writer; -use snafu::{location, Location}; +use snafu::location; /// Start at 5MB. const INITIAL_UPLOAD_STEP: usize = 1024 * 1024 * 5; @@ -81,6 +82,12 @@ pub struct ObjectWriter { use_constant_size_upload_parts: bool, } +#[derive(Debug, Clone, Default)] +pub struct WriteResult { + pub size: usize, + pub e_tag: Option, +} + enum UploadState { /// The writer has been opened but no data has been written yet. Will be in /// this state until the buffer is full or the writer is shut down. @@ -95,23 +102,27 @@ enum UploadState { }, /// The writer is in the process of uploading data in a single PUT request. /// This happens when shutdown is called before the buffer is full. - PuttingSingle(BoxFuture<'static, OSResult<()>>), + PuttingSingle(BoxFuture<'static, OSResult>), /// The writer is in the process of completing the multipart upload. - Completing(BoxFuture<'static, OSResult<()>>), + Completing(BoxFuture<'static, OSResult>), /// The writer has been shut down and all data has been written. - Done, + Done(WriteResult), } /// Methods for state transitions. impl UploadState { fn started_to_completing(&mut self, path: Arc, buffer: Vec) { // To get owned self, we temporarily swap with Done. - let this = std::mem::replace(self, Self::Done); + let this = std::mem::replace(self, Self::Done(WriteResult::default())); *self = match this { Self::Started(store) => { let fut = async move { - store.put(&path, buffer.into()).await?; - Ok(()) + let size = buffer.len(); + let res = store.put(&path, buffer.into()).await?; + Ok(WriteResult { + size, + e_tag: res.e_tag, + }) }; Self::PuttingSingle(Box::pin(fut)) } @@ -121,7 +132,7 @@ impl UploadState { fn in_progress_to_completing(&mut self) { // To get owned self, we temporarily swap with Done. - let this = std::mem::replace(self, Self::Done); + let this = std::mem::replace(self, Self::Done(WriteResult::default())); *self = match this { Self::InProgress { mut upload, @@ -130,8 +141,11 @@ impl UploadState { } => { debug_assert!(futures.is_empty()); let fut = async move { - upload.complete().await?; - Ok(()) + let res = upload.complete().await?; + Ok(WriteResult { + size: 0, // This will be set properly later. + e_tag: res.e_tag, + }) }; Self::Completing(Box::pin(fut)) } @@ -198,7 +212,7 @@ impl ObjectWriter { let mut_self = &mut *self; loop { match &mut mut_self.state { - UploadState::Started(_) | UploadState::Done => break, + UploadState::Started(_) | UploadState::Done(_) => break, UploadState::CreatingUpload(ref mut fut) => match fut.poll_unpin(cx) { Poll::Ready(Ok(mut upload)) => { let mut futures = JoinSet::new(); @@ -274,7 +288,10 @@ impl ObjectWriter { } UploadState::PuttingSingle(ref mut fut) | UploadState::Completing(ref mut fut) => { match fut.poll_unpin(cx) { - Poll::Ready(Ok(())) => mut_self.state = UploadState::Done, + Poll::Ready(Ok(mut res)) => { + res.size = mut_self.cursor; + mut_self.state = UploadState::Done(res) + } Poll::Ready(Err(e)) => { return Err(std::io::Error::new(std::io::ErrorKind::Other, e)) } @@ -286,14 +303,19 @@ impl ObjectWriter { Ok(()) } - pub async fn shutdown(&mut self) -> Result<()> { + pub async fn shutdown(&mut self) -> Result { AsyncWriteExt::shutdown(self).await.map_err(|e| { Error::io( format!("failed to shutdown object writer for {}: {}", self.path, e), // and wrap it in here. location!(), ) - }) + })?; + if let UploadState::Done(result) = &self.state { + Ok(result.clone()) + } else { + unreachable!() + } } } @@ -302,7 +324,8 @@ impl Drop for ObjectWriter { // If there is a multipart upload started but not finished, we should abort it. if matches!(self.state, UploadState::InProgress { .. }) { // Take ownership of the state. - let state = std::mem::replace(&mut self.state, UploadState::Done); + let state = + std::mem::replace(&mut self.state, UploadState::Done(WriteResult::default())); if let UploadState::InProgress { mut upload, .. } = state { tokio::task::spawn(async move { let _ = upload.abort().await; @@ -375,7 +398,10 @@ impl AsyncWrite for ObjectWriter { *part_idx, mut_self.use_constant_size_upload_parts, ); - futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None)); + futures.spawn( + Self::put_part(upload.as_mut(), data, *part_idx, None) + .instrument(tracing::Span::current()), + ); *part_idx += 1; } } @@ -398,7 +424,7 @@ impl AsyncWrite for ObjectWriter { self.as_mut().poll_tasks(cx)?; match &self.state { - UploadState::Started(_) | UploadState::Done => Poll::Ready(Ok(())), + UploadState::Started(_) | UploadState::Done(_) => Poll::Ready(Ok(())), UploadState::CreatingUpload(_) | UploadState::Completing(_) | UploadState::PuttingSingle(_) => Poll::Pending, @@ -423,7 +449,7 @@ impl AsyncWrite for ObjectWriter { // through a Pin. let mut_self = &mut *self; match &mut mut_self.state { - UploadState::Done => return Poll::Ready(Ok(())), + UploadState::Done(_) => return Poll::Ready(Ok(())), UploadState::CreatingUpload(_) | UploadState::PuttingSingle(_) | UploadState::Completing(_) => return Poll::Pending, @@ -442,7 +468,10 @@ impl AsyncWrite for ObjectWriter { if !mut_self.buffer.is_empty() && futures.len() < max_upload_parallelism() { // We can just use `take` since we don't need the buffer anymore. let data = Bytes::from(std::mem::take(&mut mut_self.buffer)); - futures.spawn(Self::put_part(upload.as_mut(), data, *part_idx, None)); + futures.spawn( + Self::put_part(upload.as_mut(), data, *part_idx, None) + .instrument(tracing::Span::current()), + ); // We need to go back to beginning of loop to poll the // new feature and get the waker registered on the ctx. continue; @@ -492,6 +521,22 @@ mod tests { assert_eq!(object_writer.write(buf.as_slice()).await.unwrap(), 256); assert_eq!(object_writer.tell().await.unwrap(), 256 * 3); - object_writer.shutdown().await.unwrap(); + let res = object_writer.shutdown().await.unwrap(); + assert_eq!(res.size, 256 * 3); + + // Trigger multi part upload + let mut object_writer = ObjectWriter::new(&store, &Path::from("/bar")) + .await + .unwrap(); + let buf = vec![0; INITIAL_UPLOAD_STEP / 3 * 2]; + for i in 0..5 { + // Write more data to trigger the multipart upload + // This should be enough to trigger a multipart upload + object_writer.write_all(buf.as_slice()).await.unwrap(); + // Check the cursor + assert_eq!(object_writer.tell().await.unwrap(), (i + 1) * buf.len()); + } + let res = object_writer.shutdown().await.unwrap(); + assert_eq!(res.size, buf.len() * 5); } } diff --git a/rust/lance-io/src/scheduler.rs b/rust/lance-io/src/scheduler.rs index b6cfff300a8..c6262b55a4b 100644 --- a/rust/lance-io/src/scheduler.rs +++ b/rust/lance-io/src/scheduler.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use futures::channel::oneshot; use futures::{FutureExt, TryFutureExt}; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use std::collections::BinaryHeap; use std::fmt::Debug; use std::future::Future; @@ -25,6 +25,19 @@ const BACKPRESSURE_MIN: u64 = 5; // Don't log backpressure warnings more than once / minute const BACKPRESSURE_DEBOUNCE: u64 = 60; +// Global counter of how many IOPS we have issued +static IOPS_COUNTER: AtomicU64 = AtomicU64::new(0); +// Global counter of how many bytes were read by the scheduler +static BYTES_READ_COUNTER: AtomicU64 = AtomicU64::new(0); + +pub fn iops_counter() -> u64 { + IOPS_COUNTER.load(Ordering::Acquire) +} + +pub fn bytes_read_counter() -> u64 { + BYTES_READ_COUNTER.load(Ordering::Acquire) +} + // There are two structures that control the I/O scheduler concurrency. First, // we have a hard limit on the number of IOPS that can be issued concurrently. // This limit is process-wide. @@ -64,7 +77,7 @@ struct IopsReservation<'a> { value: Option>, } -impl<'a> IopsReservation<'a> { +impl IopsReservation<'_> { // Forget the reservation, so it won't be released on drop fn forget(&mut self) { if let Some(value) = self.value.take() { @@ -208,8 +221,8 @@ impl IoQueueState { && seconds_elapsed < BACKPRESSURE_DEBOUNCE) || since_last_warn > BACKPRESSURE_DEBOUNCE { - tracing::event!(tracing::Level::WARN, "Backpressure throttle exceeded"); - log::warn!("Backpressure throttle is full, I/O will pause until buffer is drained. Max I/O bandwidth will not be achieved because CPU is falling behind"); + tracing::event!(tracing::Level::DEBUG, "Backpressure throttle exceeded"); + log::debug!("Backpressure throttle is full, I/O will pause until buffer is drained. Max I/O bandwidth will not be achieved because CPU is falling behind"); self.last_warn .store(seconds_elapsed.max(1), Ordering::Release); } @@ -456,6 +469,8 @@ impl IoTask { let bytes_fut = self .reader .get_range(self.to_read.start as usize..self.to_read.end as usize); + IOPS_COUNTER.fetch_add(1, Ordering::Release); + BYTES_READ_COUNTER.fetch_add(self.num_bytes(), Ordering::Release); bytes_fut.await.map_err(Error::from) }; IOPS_QUOTA.release(); @@ -482,6 +497,60 @@ async fn run_io_loop(tasks: Arc) { } } +#[derive(Debug)] +struct StatsCollector { + iops: AtomicU64, + requests: AtomicU64, + bytes_read: AtomicU64, +} + +impl StatsCollector { + fn new() -> Self { + Self { + iops: AtomicU64::new(0), + requests: AtomicU64::new(0), + bytes_read: AtomicU64::new(0), + } + } + + fn iops(&self) -> u64 { + self.iops.load(Ordering::Relaxed) + } + + fn bytes_read(&self) -> u64 { + self.bytes_read.load(Ordering::Relaxed) + } + + fn requests(&self) -> u64 { + self.requests.load(Ordering::Relaxed) + } + + fn record_request(&self, request: &[Range]) { + self.requests.fetch_add(1, Ordering::Relaxed); + self.iops.fetch_add(request.len() as u64, Ordering::Relaxed); + self.bytes_read.fetch_add( + request.iter().map(|r| r.end - r.start).sum::(), + Ordering::Relaxed, + ); + } +} + +pub struct ScanStats { + pub iops: u64, + pub requests: u64, + pub bytes_read: u64, +} + +impl ScanStats { + fn new(stats: &StatsCollector) -> Self { + Self { + iops: stats.iops(), + requests: stats.requests(), + bytes_read: stats.bytes_read(), + } + } +} + /// An I/O scheduler which wraps an ObjectStore and throttles the amount of /// parallel I/O that can be run. /// @@ -489,6 +558,7 @@ async fn run_io_loop(tasks: Arc) { pub struct ScanScheduler { object_store: Arc, io_queue: Arc, + stats: Arc, } impl Debug for ScanScheduler { @@ -547,6 +617,7 @@ impl ScanScheduler { let scheduler = Self { object_store, io_queue: io_queue.clone(), + stats: Arc::new(StatsCollector::new()), }; tokio::task::spawn(async move { run_io_loop(io_queue).await }); Arc::new(scheduler) @@ -558,8 +629,8 @@ impl ScanScheduler { /// /// * path - the path to the file to open /// * base_priority - the base priority for I/O requests submitted to this file scheduler - /// this will determine the upper 64 bits of priority (the lower 64 bits - /// come from `submit_request` and `submit_single`) + /// this will determine the upper 64 bits of priority (the lower 64 bits + /// come from `submit_request` and `submit_single`) pub async fn open_file_with_priority( self: &Arc, path: &Path, @@ -646,6 +717,10 @@ impl ScanScheduler { rsp.data }) } + + pub fn stats(&self) -> ScanStats { + ScanStats::new(self.stats.as_ref()) + } } impl Drop for ScanScheduler { @@ -690,6 +765,8 @@ impl FileScheduler { request: Vec>, priority: u64, ) -> impl Future>> + Send { + self.root.stats.record_request(&request); + // The final priority is a combination of the row offset and the file number let priority = ((self.base_priority as u128) << 64) + priority as u128; @@ -744,6 +821,15 @@ impl FileScheduler { } } + pub fn with_priority(&self, priority: u64) -> Self { + Self { + reader: self.reader.clone(), + root: self.root.clone(), + block_size: self.block_size, + base_priority: priority, + } + } + /// Submit a single IOP to the reader /// /// If you have multiple IOPS to perform then [`Self::submit_request`] is going diff --git a/rust/lance-io/src/utils.rs b/rust/lance-io/src/utils.rs index 37253339a5c..63fc300f724 100644 --- a/rust/lance-io/src/utils.rs +++ b/rust/lance-io/src/utils.rs @@ -12,7 +12,7 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::Bytes; use lance_arrow::*; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use crate::{ encodings::{binary::BinaryDecoder, plain::PlainDecoder, AsyncIndex, Decoder}, @@ -104,7 +104,6 @@ pub async fn read_message(reader: &dyn Reader, pos: usize) /// Read a Protobuf-backed struct at file position: `pos`. // TODO: pub(crate) pub async fn read_struct< - 'm, M: Message + Default + 'static, T: ProtoStruct + TryFrom, >( @@ -118,11 +117,7 @@ pub async fn read_struct< pub async fn read_last_block(reader: &dyn Reader) -> object_store::Result { let file_size = reader.size().await?; let block_size = reader.block_size(); - let begin = if file_size < block_size { - 0 - } else { - file_size - block_size - }; + let begin = file_size.saturating_sub(block_size); reader.get_range(begin..file_size).await } diff --git a/rust/lance-io/tests/gcs_integration.rs b/rust/lance-io/tests/gcs_integration.rs index 92137322441..959a3a61086 100644 --- a/rust/lance-io/tests/gcs_integration.rs +++ b/rust/lance-io/tests/gcs_integration.rs @@ -4,13 +4,15 @@ //! They do not work against any local emulator right now. #![cfg(feature = "gcs-test")] +use std::sync::Arc; + // TODO: Once we re-use this logic for S3, we can instead use tests against // Minio to validate the multipart upload logic. use lance_io::object_store::ObjectStore; use object_store::path::Path; use tokio::io::AsyncWriteExt; -async fn get_store() -> ObjectStore { +async fn get_store() -> Arc { let bucket_name = std::env::var("OBJECT_STORE_BUCKET").unwrap_or_else(|_| "test-bucket".into()); ObjectStore::from_uri(&format!("gs://{}/object", bucket_name)) .await diff --git a/rust/lance-linalg/benches/compute_partition.rs b/rust/lance-linalg/benches/compute_partition.rs index 7b155a9aa5b..5cdda57158a 100644 --- a/rust/lance-linalg/benches/compute_partition.rs +++ b/rust/lance-linalg/benches/compute_partition.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use arrow_array::types::Float32Type; use criterion::{criterion_group, criterion_main, Criterion}; +use lance_linalg::kmeans::KMeansAlgoFloat; use lance_linalg::{distance::MetricType, kmeans::compute_partitions}; use lance_testing::datagen::generate_random_array_with_seed; #[cfg(target_os = "linux")] @@ -24,9 +25,9 @@ fn bench_compute_partitions(c: &mut Criterion) { c.bench_function("compute_centroids(L2)", |b| { b.iter(|| { - compute_partitions( - centroids.values(), - input.values(), + compute_partitions::>( + centroids.as_ref(), + &input, DIMENSION, MetricType::L2, ) @@ -35,9 +36,9 @@ fn bench_compute_partitions(c: &mut Criterion) { c.bench_function("compute_centroids(Cosine)", |b| { b.iter(|| { - compute_partitions( - centroids.values(), - input.values(), + compute_partitions::>( + centroids.as_ref(), + &input, DIMENSION, MetricType::Cosine, ) diff --git a/rust/lance-linalg/build.rs b/rust/lance-linalg/build.rs index 40c28c54997..2ee88c65495 100644 --- a/rust/lance-linalg/build.rs +++ b/rust/lance-linalg/build.rs @@ -27,20 +27,26 @@ fn main() -> Result<(), String> { return Ok(()); } - if cfg!(target_os = "windows") { + // Important: we don't use `cfg!(target_arch)` here because that is the target_arch + // for the build script, not the target_arch for the library. Similar story for + // target_os. + let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap(); + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap(); + + if target_os == "windows" { println!( "cargo:warning=fp16 kernels are not supported on Windows. Skipping compilation of kernels." ); return Ok(()); } - if cfg!(all(target_arch = "aarch64", target_os = "macos")) { + if target_arch == "aarch64" && target_os == "macos" { // Build a version with NEON build_f16_with_flags("neon", &["-mtune=apple-m1"]).unwrap(); - } else if cfg!(all(target_arch = "aarch64", target_os = "linux")) { + } else if target_arch == "aarch64" && target_os == "linux" { // Build a version with NEON build_f16_with_flags("neon", &["-march=armv8.2-a+fp16"]).unwrap(); - } else if cfg!(target_arch = "x86_64") { + } else if target_arch == "x86_64" { // Build a version with AVX512 if let Err(err) = build_f16_with_flags("avx512", &["-march=sapphirerapids", "-mavx512fp16"]) { @@ -63,7 +69,7 @@ fn main() -> Result<(), String> { return Err(format!("Unable to build AVX2 f16 kernels. Please use Clang >= 6 or GCC >= 12 or remove the fp16kernels feature. Received error: {}", err)); }; // There is no SSE instruction set for f16 -> f32 float conversion - } else if cfg!(target_arch = "loongarch64") { + } else if target_arch == "loongarch64" { // Build a version with LSX and LASX build_f16_with_flags("lsx", &["-mlsx"]).unwrap(); build_f16_with_flags("lasx", &["-mlasx"]).unwrap(); diff --git a/rust/lance-linalg/src/clustering.rs b/rust/lance-linalg/src/clustering.rs index 0bbfea3953d..99a166f8d5a 100644 --- a/rust/lance-linalg/src/clustering.rs +++ b/rust/lance-linalg/src/clustering.rs @@ -33,7 +33,7 @@ pub trait Clustering { /// ## Parameters: /// * `data`: an `N * D` of D-dimensional vectors. /// * `nprobes`: If provided, the number of partitions per vector to return. - /// If not provided, return 1 partition per vector. + /// If not provided, return 1 partition per vector. /// /// ## Returns: /// * An `N * nprobes` matrix of partition IDs. diff --git a/rust/lance-linalg/src/distance.rs b/rust/lance-linalg/src/distance.rs index fdb9226a5aa..6e79c7d8b03 100644 --- a/rust/lance-linalg/src/distance.rs +++ b/rust/lance-linalg/src/distance.rs @@ -11,8 +11,10 @@ use std::sync::Arc; -use arrow_array::{Array, FixedSizeListArray, Float32Array}; -use arrow_schema::ArrowError; +use arrow_array::cast::AsArray; +use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type}; +use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray}; +use arrow_schema::{ArrowError, DataType}; pub mod cosine; pub mod dot; @@ -23,6 +25,7 @@ pub mod norm_l2; pub use cosine::*; use deepsize::DeepSizeOf; pub use dot::*; +use hamming::hamming_distance_arrow_batch; pub use l2::*; pub use norm_l2::*; @@ -55,7 +58,7 @@ impl DistanceType { Self::L2 => l2_distance_arrow_batch, Self::Cosine => cosine_distance_arrow_batch, Self::Dot => dot_distance_arrow_batch, - Self::Hamming => todo!(), + Self::Hamming => hamming_distance_arrow_batch, } } @@ -100,3 +103,104 @@ impl TryFrom<&str> for DistanceType { } } } + +pub fn multivec_distance( + query: &dyn Array, + vectors: &ListArray, + distance_type: DistanceType, +) -> Result> { + let dim = if let DataType::FixedSizeList(_, dim) = vectors.value_type() { + dim as usize + } else { + return Err(ArrowError::InvalidArgumentError( + "vectors must be a list of fixed size list".to_string(), + )); + }; + + // check the query vectors type first + // because we don't want to check the vectors type for each vector + match query.data_type() { + DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {} + _ => { + return Err(ArrowError::InvalidArgumentError( + "query must be a float array or binary array".to_string(), + )); + } + } + + let dists = vectors + .iter() + .map(|v| { + v.map(|v| { + let multivector = v.as_fixed_size_list(); + match distance_type { + DistanceType::Hamming => { + let query = query.as_primitive::().values(); + query + .chunks_exact(dim) + .map(|q| { + multivector + .values() + .as_primitive::() + .values() + .chunks_exact(dim) + .map(|v| hamming::hamming(q, v)) + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap() + }) + .sum() + } + _ => match query.data_type() { + DataType::Float16 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + DataType::Float32 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + DataType::Float64 => multivec_distance_impl::( + query, + multivector, + dim, + distance_type, + ), + _ => unreachable!("missed to check query type"), + }, + } + }) + .unwrap_or(f32::NAN) + }) + .map(|sim| 1.0 - sim) + .collect(); + Ok(dists) +} + +fn multivec_distance_impl( + query: &dyn Array, + multivector: &FixedSizeListArray, + dim: usize, + distance_type: DistanceType, +) -> f32 +where + T::Native: L2 + Cosine + Dot, +{ + let query = query.as_primitive::().values(); + query + .chunks_exact(dim) + .map(|q| { + multivector + .values() + .as_primitive::() + .values() + .chunks_exact(dim) + .map(|v| 1.0 - distance_type.func()(q, v)) + .max_by(|a, b| a.total_cmp(b)) + .unwrap() + }) + .sum() +} diff --git a/rust/lance-linalg/src/distance/cosine.rs b/rust/lance-linalg/src/distance/cosine.rs index 9d1ce0a756e..864f5b962f6 100644 --- a/rust/lance-linalg/src/distance/cosine.rs +++ b/rust/lance-linalg/src/distance/cosine.rs @@ -11,12 +11,12 @@ use std::sync::Arc; use arrow_array::{ cast::AsArray, - types::{Float16Type, Float32Type, Float64Type}, + types::{Float16Type, Float32Type, Float64Type, Int8Type}, Array, FixedSizeListArray, Float32Array, }; use arrow_schema::DataType; use half::{bf16, f16}; -use lance_arrow::{ArrowFloatType, FloatArray}; +use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray}; #[cfg(feature = "fp16kernels")] use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::FP16_SIMD_SUPPORT; @@ -320,6 +320,14 @@ pub fn cosine_distance_arrow_batch( DataType::Float16 => do_cosine_distance_arrow_batch::(from.as_primitive(), to), DataType::Float32 => do_cosine_distance_arrow_batch::(from.as_primitive(), to), DataType::Float64 => do_cosine_distance_arrow_batch::(from.as_primitive(), to), + DataType::Int8 => do_cosine_distance_arrow_batch::( + &from + .as_primitive::() + .into_iter() + .map(|x| x.unwrap() as f32) + .collect(), + &to.convert_to_floating_point()?, + ), _ => Err(Error::InvalidArgumentError(format!( "Unsupported data type {:?}", from.data_type() diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index c8f8db86165..9ca42d8a78d 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -8,11 +8,11 @@ use std::ops::AddAssign; use std::sync::Arc; use crate::Error; -use arrow_array::types::{Float16Type, Float64Type}; +use arrow_array::types::{Float16Type, Float64Type, Int8Type}; use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray, Float32Array}; use arrow_schema::DataType; use half::{bf16, f16}; -use lance_arrow::{ArrowFloatType, FloatArray}; +use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray}; #[cfg(feature = "fp16kernels")] use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::FP16_SIMD_SUPPORT; @@ -278,6 +278,14 @@ pub fn dot_distance_arrow_batch( DataType::Float16 => do_dot_distance_arrow_batch::(from.as_primitive(), to), DataType::Float32 => do_dot_distance_arrow_batch::(from.as_primitive(), to), DataType::Float64 => do_dot_distance_arrow_batch::(from.as_primitive(), to), + DataType::Int8 => do_dot_distance_arrow_batch::( + &from + .as_primitive::() + .into_iter() + .map(|x| x.unwrap() as f32) + .collect(), + &to.convert_to_floating_point()?, + ), _ => Err(Error::InvalidArgumentError(format!( "Unsupported data type: {:?}", from.data_type() diff --git a/rust/lance-linalg/src/distance/hamming.rs b/rust/lance-linalg/src/distance/hamming.rs index 0b94f867bc0..03fda1467cc 100644 --- a/rust/lance-linalg/src/distance/hamming.rs +++ b/rust/lance-linalg/src/distance/hamming.rs @@ -3,6 +3,14 @@ //! Hamming distance. +use std::sync::Arc; + +use crate::{Error, Result}; +use arrow_array::cast::AsArray; +use arrow_array::types::UInt8Type; +use arrow_array::{Array, FixedSizeListArray, Float32Array}; +use arrow_schema::DataType; + pub trait Hamming { /// Hamming distance between two vectors. fn hamming(x: &[u8], y: &[u8]) -> f32; @@ -44,6 +52,40 @@ pub fn hamming_scalar(x: &[u8], y: &[u8]) -> f32 { .sum::() as f32 } +pub fn hamming_distance_batch<'a>( + from: &'a [u8], + to: &'a [u8], + dimension: usize, +) -> Box + 'a> { + debug_assert_eq!(from.len(), dimension); + debug_assert_eq!(to.len() % dimension, 0); + Box::new(to.chunks_exact(dimension).map(|v| hamming(from, v))) +} + +pub fn hamming_distance_arrow_batch( + from: &dyn Array, + to: &FixedSizeListArray, +) -> Result> { + let dists = match *from.data_type() { + DataType::UInt8 => hamming_distance_batch( + from.as_primitive::().values(), + to.values().as_primitive::().values(), + from.len(), + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Unsupported data type: {:?}", + from.data_type() + ))) + } + }; + + Ok(Arc::new(Float32Array::new( + dists.collect(), + to.nulls().cloned(), + ))) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index f3c98be7093..c52c565a8c7 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -10,12 +10,12 @@ use std::sync::Arc; use arrow_array::{ cast::AsArray, - types::{Float16Type, Float32Type, Float64Type}, + types::{Float16Type, Float32Type, Float64Type, Int8Type}, Array, FixedSizeListArray, Float32Array, }; use arrow_schema::DataType; use half::{bf16, f16}; -use lance_arrow::{ArrowFloatType, FloatArray}; +use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray}; #[cfg(feature = "fp16kernels")] use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::FP16_SIMD_SUPPORT; @@ -293,6 +293,14 @@ pub fn l2_distance_arrow_batch( DataType::Float16 => do_l2_distance_arrow_batch::(from.as_primitive(), to), DataType::Float32 => do_l2_distance_arrow_batch::(from.as_primitive(), to), DataType::Float64 => do_l2_distance_arrow_batch::(from.as_primitive(), to), + DataType::Int8 => do_l2_distance_arrow_batch::( + &from + .as_primitive::() + .into_iter() + .map(|x| x.unwrap() as f32) + .collect(), + &to.convert_to_floating_point()?, + ), _ => Err(Error::ComputeError(format!( "Unsupported data type: {}", from.data_type() diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs index 57c8f16839a..fb484851a84 100644 --- a/rust/lance-linalg/src/kmeans.rs +++ b/rust/lance-linalg/src/kmeans.rs @@ -23,12 +23,13 @@ use arrow_array::{ArrowNumericType, UInt8Array}; use arrow_ord::sort::sort_to_indices; use arrow_schema::{ArrowError, DataType}; use bitvec::prelude::*; +use lance_arrow::FixedSizeListArrayExt; use log::{info, warn}; use num_traits::{AsPrimitive, Float, FromPrimitive, Num, Zero}; use rand::prelude::*; use rayon::prelude::*; -use crate::distance::hamming::hamming; +use crate::distance::hamming::{hamming, hamming_distance_batch}; use crate::distance::{dot_distance_batch, DistanceType}; use crate::kernels::{argmax, argmin_value_float}; use crate::{ @@ -41,10 +42,10 @@ use crate::{ use crate::{Error, Result}; /// KMean initialization method. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] pub enum KMeanInit { Random, - KMeanPlusPlus, + Incremental(Arc), } /// KMean Training Parameters @@ -80,11 +81,21 @@ impl Default for KMeansParams { } impl KMeansParams { - /// Create a new KMeansParams with cosine distance. - #[allow(dead_code)] - fn cosine() -> Self { + pub fn new( + centroids: Option>, + max_iters: u32, + redos: usize, + distance_type: DistanceType, + ) -> Self { + let init = match centroids { + Some(centroids) => KMeanInit::Incremental(centroids), + None => KMeanInit::Random, + }; Self { - distance_type: DistanceType::Cosine, + max_iters, + redos, + distance_type, + init, ..Default::default() } } @@ -103,6 +114,9 @@ pub struct KMeans { /// How to calculate distance between two vectors. pub distance_type: DistanceType, + + /// The loss of the last training. + pub loss: f64, } /// Randomly initialize kmeans centroids. @@ -127,6 +141,7 @@ fn kmeans_random_init( centroids: Arc::new(centroids), dimension, distance_type, + loss: f64::MAX, } } @@ -170,7 +185,7 @@ fn hist_stddev(k: usize, membership: &[Option]) -> f32 { .sqrt() } -trait KMeansAlgo { +pub trait KMeansAlgo { /// Recompute the membership of each vector. /// /// Parameters: @@ -191,10 +206,11 @@ trait KMeansAlgo { k: usize, membership: &[Option], distance_type: DistanceType, + loss: f64, ) -> KMeans; } -struct KMeansAlgoFloat +pub struct KMeansAlgoFloat where T::Native: Float + Num, { @@ -245,6 +261,7 @@ where k: usize, membership: &[Option], distance_type: DistanceType, + loss: f64, ) -> KMeans { let mut cluster_cnts = vec![0_u64; k]; let mut new_centroids = vec![T::Native::zero(); k * dimension]; @@ -293,6 +310,7 @@ where centroids: Arc::new(PrimitiveArray::::from_iter_values(new_centroids)), dimension, distance_type, + loss, } } } @@ -337,6 +355,7 @@ impl KMeansAlgo for KModeAlgo { k: usize, membership: &[Option], distance_type: DistanceType, + loss: f64, ) -> KMeans { assert_eq!(distance_type, DistanceType::Hamming); @@ -379,6 +398,7 @@ impl KMeansAlgo for KModeAlgo { centroids: Arc::new(UInt8Array::from(centroids)), dimension, distance_type, + loss, } } } @@ -389,6 +409,7 @@ impl KMeans { centroids: arrow_array::array::new_empty_array(&DataType::Float32), dimension, distance_type, + loss: f64::MAX, } } @@ -398,6 +419,7 @@ impl KMeans { centroids: ArrayRef, dimension: usize, distance_type: DistanceType, + loss: f64, ) -> Self { assert!(matches!( centroids.data_type(), @@ -407,6 +429,7 @@ impl KMeans { centroids, dimension, distance_type, + loss, } } @@ -462,7 +485,7 @@ impl KMeans { // TODO: use seed for Rng. let rng = SmallRng::from_entropy(); for redo in 1..=params.redos { - let mut kmeans: Self = match params.init { + let mut kmeans: Self = match ¶ms.init { KMeanInit::Random => Self::init_random::( data.values(), dimension, @@ -470,9 +493,12 @@ impl KMeans { rng.clone(), params.distance_type, ), - KMeanInit::KMeanPlusPlus => { - unimplemented!() - } + KMeanInit::Incremental(centroids) => Self::with_centroids( + centroids.values().clone(), + dimension, + params.distance_type, + f64::MAX, + ), }; let mut loss = f64::MAX; @@ -496,6 +522,7 @@ impl KMeans { k, &membership, params.distance_type, + last_loss, ); last_membership = Some(membership); if (loss - last_loss).abs() / last_loss < params.tolerance { @@ -596,6 +623,12 @@ pub fn kmeans_find_partitions_arrow_array( nprobes, distance_type, )?), + (DataType::UInt8, DataType::UInt8) => kmeans_find_partitions_binary( + centroids.values().as_primitive::().values(), + query.as_primitive::().values(), + nprobes, + distance_type, + ), _ => Err(ArrowError::InvalidArgumentError(format!( "Centroids and vectors have different types: {} != {}", centroids.value_type(), @@ -637,38 +670,83 @@ pub fn kmeans_find_partitions( sort_to_indices(&dists_arr, None, Some(nprobes)) } +pub fn kmeans_find_partitions_binary( + centroids: &[u8], + query: &[u8], + nprobes: usize, + distance_type: DistanceType, +) -> Result { + let dists: Vec = match distance_type { + DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(), + _ => { + panic!( + "KMeans::find_partitions: {} is not supported", + distance_type + ); + } + }; + + // TODO: use heap to just keep nprobes smallest values. + let dists_arr = Float32Array::from(dists); + sort_to_indices(&dists_arr, None, Some(nprobes)) +} + /// Compute partitions from Arrow FixedSizeListArray. pub fn compute_partitions_arrow_array( centroids: &FixedSizeListArray, vectors: &FixedSizeListArray, distance_type: DistanceType, -) -> Result>> { +) -> Result<(Vec>, f64)> { if centroids.value_length() != vectors.value_length() { return Err(ArrowError::InvalidArgumentError( "Centroids and vectors have different dimensions".to_string(), )); } match (centroids.value_type(), vectors.value_type()) { - (DataType::Float16, DataType::Float16) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float16, DataType::Float16) => Ok(compute_partitions::< + Float16Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), - (DataType::Float32, DataType::Float32) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float32, DataType::Float32) => Ok(compute_partitions::< + Float32Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), - (DataType::Float64, DataType::Float64) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float32, DataType::Int8) => Ok(compute_partitions::< + Float32Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.convert_to_floating_point()?.values().as_primitive(), + centroids.value_length(), + distance_type, + )), + (DataType::Float64, DataType::Float64) => Ok(compute_partitions::< + Float64Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), + centroids.value_length(), + distance_type, + )), + (DataType::UInt8, DataType::UInt8) => Ok(compute_partitions::( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), _ => Err(ArrowError::InvalidArgumentError( - "Centroids and vectors have different types".to_string(), + "Centroids and vectors have incompatible types".to_string(), )), } } @@ -676,17 +754,22 @@ pub fn compute_partitions_arrow_array( /// Compute partition ID of each vector in the KMeans. /// /// If returns `None`, means the vector is not valid, i.e., all `NaN`. -pub fn compute_partitions( - centroids: &[T], - vectors: &[T], +pub fn compute_partitions>( + centroids: &PrimitiveArray, + vectors: &PrimitiveArray, dimension: impl AsPrimitive, distance_type: DistanceType, -) -> Vec> { +) -> (Vec>, f64) +where + T::Native: Num, +{ let dimension = dimension.as_(); - vectors - .par_chunks(dimension) - .map(|vec| compute_partition(centroids, vec, distance_type)) - .collect::>() + K::compute_membership_and_loss( + centroids.values(), + vectors.values(), + dimension, + distance_type, + ) } #[inline] @@ -713,7 +796,7 @@ pub fn compute_partition( #[cfg(test)] mod tests { - use std::iter::repeat; + use std::iter::repeat_n; use lance_arrow::*; use lance_testing::datagen::generate_random_array; @@ -752,7 +835,12 @@ mod tests { ) }) .collect::>(); - let actual = compute_partitions(centroids.values(), data.values(), DIM, DistanceType::L2); + let (actual, _) = compute_partitions::>( + ¢roids, + &data, + DIM, + DistanceType::L2, + ); assert_eq!(expected, actual); } @@ -780,13 +868,19 @@ mod tests { const K: usize = 32; const NUM_CENTROIDS: usize = 16 * 2048; let centroids = generate_random_array(DIM * NUM_CENTROIDS); - let values = Float32Array::from_iter_values(repeat(f32::NAN).take(DIM * K)); + let values = Float32Array::from_iter_values(repeat_n(f32::NAN, DIM * K)); - compute_partitions::(centroids.values(), values.values(), DIM, DistanceType::L2) - .iter() - .for_each(|cd| { - assert!(cd.is_none()); - }); + compute_partitions::>( + ¢roids, + &values, + DIM, + DistanceType::L2, + ) + .0 + .iter() + .for_each(|cd| { + assert!(cd.is_none()); + }); } #[tokio::test] @@ -795,7 +889,7 @@ mod tests { const K: usize = 32; const NUM_CENTROIDS: usize = 16 * 2048; let centroids = generate_random_array(DIM * NUM_CENTROIDS); - let values = repeat(f32::NAN).take(DIM * K).collect::>(); + let values = repeat_n(f32::NAN, DIM * K).collect::>(); let (membership, _) = KMeansAlgoFloat::::compute_membership_and_loss( centroids.as_slice(), diff --git a/rust/lance-linalg/src/simd.rs b/rust/lance-linalg/src/simd.rs index 74c3b56d3b5..dc3b6b680ee 100644 --- a/rust/lance-linalg/src/simd.rs +++ b/rust/lance-linalg/src/simd.rs @@ -16,8 +16,10 @@ use std::ops::{Add, AddAssign, Mul, Sub, SubAssign}; pub mod f32; pub mod i32; +pub mod u8; use num_traits::{Float, Num}; +use u8::u8x16; /// Lance SIMD lib /// @@ -41,8 +43,6 @@ pub trait SIMD: /// Create a new instance with all lanes set to zero. fn zeros() -> Self; - /// Gather elements from the slice, using i32 indices. - /// Load aligned data from aligned memory. /// /// # Safety @@ -95,3 +95,7 @@ pub trait FloatSimd: SIMD { /// c = a * b + c fn multiply_add(&mut self, a: Self, b: Self); } + +pub trait Shuffle { + fn shuffle(&self, indices: u8x16) -> Self; +} diff --git a/rust/lance-linalg/src/simd/f32.rs b/rust/lance-linalg/src/simd/f32.rs index 8deb50338b0..8091bc83a10 100644 --- a/rust/lance-linalg/src/simd/f32.rs +++ b/rust/lance-linalg/src/simd/f32.rs @@ -485,7 +485,6 @@ impl<'a> From<&'a [f32; 16]> for f32x16 { impl SIMD for f32x16 { #[inline] - fn splat(val: f32) -> Self { #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] unsafe { @@ -602,7 +601,6 @@ impl SIMD for f32x16 { } #[inline] - unsafe fn store_unaligned(&self, ptr: *mut f32) { #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] unsafe { diff --git a/rust/lance-linalg/src/simd/u8.rs b/rust/lance-linalg/src/simd/u8.rs new file mode 100644 index 00000000000..aa1b3f3c677 --- /dev/null +++ b/rust/lance-linalg/src/simd/u8.rs @@ -0,0 +1,440 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! `u8x8`, 8 of `u8` values + +use std::fmt::Formatter; + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; +use std::ops::{Add, AddAssign, Mul, Sub, SubAssign}; + +use super::{Shuffle, SIMD}; + +/// 16 of 8-bit `u8` values. +#[allow(non_camel_case_types)] +#[cfg(target_arch = "x86_64")] +#[derive(Clone, Copy)] +pub struct u8x16(pub __m128i); + +/// 16 of 8-bit `u8` values. +#[allow(non_camel_case_types)] +#[cfg(target_arch = "aarch64")] +#[derive(Clone, Copy)] +pub struct u8x16(pub uint8x16_t); + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] +#[derive(Clone, Copy)] +pub struct u8x16(pub [u8; 16]); + +impl u8x16 { + #[inline] + pub fn bit_and(self, mask: u8) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_and_si128(self.0, _mm_set1_epi8(mask as i8))) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vandq_u8(self.0, vdupq_n_u8(mask))) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..16 { + self.0[i] &= mask; + } + } + } + + #[inline] + pub fn right_shift(self) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + let shifted = _mm_srli_epi16(self.0, N); + let mask = _mm_set1_epi8((1_i8 << (8 - N)) - 1); + Self(_mm_and_si128(shifted, mask)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vshrq_n_u8::(self.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i] >> N; + } + Self(result) + } + } +} + +impl std::fmt::Debug for u8x16 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut arr = [0u8; 16]; + unsafe { + self.store_unaligned(arr.as_mut_ptr()); + } + write!(f, "u8x16({:?})", arr) + } +} + +impl From<&[u8]> for u8x16 { + fn from(value: &[u8]) -> Self { + unsafe { Self::load_unaligned(value.as_ptr()) } + } +} + +impl<'a> From<&'a [u8; 16]> for u8x16 { + fn from(value: &'a [u8; 16]) -> Self { + unsafe { Self::load_unaligned(value.as_ptr()) } + } +} + +impl SIMD for u8x16 { + #[inline] + fn splat(val: u8) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_set1_epi8(val as i8)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vdupq_n_u8(val)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = val; + } + Self(result) + } + } + + #[inline] + fn zeros() -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_setzero_si128()) + } + #[cfg(target_arch = "aarch64")] + { + Self::splat(0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + Self([0; 16]) + } + } + + #[inline] + unsafe fn load(ptr: *const u8) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_loadu_si128(ptr as *const __m128i)) + } + #[cfg(target_arch = "aarch64")] + { + Self::load_unaligned(ptr) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + Self::load_unaligned(ptr) + } + } + + #[inline] + unsafe fn load_unaligned(ptr: *const u8) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_loadu_si128(ptr as *const __m128i)) + } + #[cfg(target_arch = "aarch64")] + { + Self(vld1q_u8(ptr)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = *ptr.add(i); + } + Self(result) + } + } + + #[inline] + unsafe fn store(&self, ptr: *mut u8) { + #[cfg(target_arch = "x86_64")] + unsafe { + _mm_storeu_si128(ptr as *mut __m128i, self.0) + } + #[cfg(target_arch = "aarch64")] + unsafe { + vst1q_u8(ptr, self.0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + self.store_unaligned(ptr); + } + } + + #[inline] + unsafe fn store_unaligned(&self, ptr: *mut u8) { + #[cfg(target_arch = "x86_64")] + unsafe { + _mm_storeu_si128(ptr as *mut __m128i, self.0) + } + #[cfg(target_arch = "aarch64")] + unsafe { + vst1q_u8(ptr, self.0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..16 { + *ptr.add(i) = self.0[i]; + } + } + } + + fn reduce_sum(&self) -> u8 { + todo!("it is not implemented yet"); + } + + #[inline] + fn reduce_min(&self) -> u8 { + #[cfg(target_arch = "x86_64")] + unsafe { + let low = _mm_and_si128(self.0, _mm_set1_epi8(0xFF_u8 as i8)); + let high = _mm_srli_si128(self.0, 8); + let min_low = _mm_min_epu8(low, high); + let min_low = _mm_min_epu8(min_low, _mm_srli_si128(min_low, 4)); + let min_low = _mm_min_epu8(min_low, _mm_srli_si128(min_low, 2)); + let min_low = _mm_min_epu8(min_low, _mm_srli_si128(min_low, 1)); + _mm_extract_epi8(min_low, 0) as u8 + } + #[cfg(target_arch = "aarch64")] + unsafe { + vminvq_u8(self.0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut min = self.0[0]; + for i in 1..16 { + min = std::cmp::min(min, self.0[i]); + } + min + } + } + + #[inline] + fn min(&self, rhs: &Self) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_min_epu8(self.0, rhs.0)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vminq_u8(self.0, rhs.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = std::cmp::min(self.0[i], rhs.0[i]); + } + Self(result) + } + } + + fn find(&self, _val: u8) -> Option { + todo!() + } +} + +impl Shuffle for u8x16 { + fn shuffle(&self, indices: u8x16) -> Self { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_shuffle_epi8(self.0, indices.0)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vqtbl1q_u8(self.0, indices.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[indices.0[i] as usize]; + } + Self(result) + } + } +} + +impl Add for u8x16 { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_adds_epu8(self.0, rhs.0)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vqaddq_u8(self.0, rhs.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i].saturating_add(rhs.0[i]); + } + Self(result) + } + } +} + +impl AddAssign for u8x16 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + #[cfg(target_arch = "x86_64")] + unsafe { + self.0 = _mm_adds_epu8(self.0, rhs.0) + } + #[cfg(target_arch = "aarch64")] + unsafe { + self.0 = vqaddq_u8(self.0, rhs.0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..16 { + self.0[i] = self.0[i].saturating_add(rhs.0[i]); + } + } + } +} + +impl Mul for u8x16 { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + let a_lo = _mm_unpacklo_epi8(self.0, _mm_setzero_si128()); + let a_hi = _mm_unpackhi_epi8(self.0, _mm_setzero_si128()); + let b_lo = _mm_unpacklo_epi8(rhs.0, _mm_setzero_si128()); + let b_hi = _mm_unpackhi_epi8(rhs.0, _mm_setzero_si128()); + + let res_lo = _mm_mullo_epi16(a_lo, b_lo); + let res_hi = _mm_mullo_epi16(a_hi, b_hi); + + Self(_mm_packus_epi16(res_lo, res_hi)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vmulq_u8(self.0, rhs.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i].wrapping_mul(rhs.0[i]); + } + Self(result) + } + } +} + +impl Sub for u8x16 { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self::Output { + #[cfg(target_arch = "x86_64")] + unsafe { + Self(_mm_sub_epi8(self.0, rhs.0)) + } + #[cfg(target_arch = "aarch64")] + unsafe { + Self(vsubq_u8(self.0, rhs.0)) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let mut result = [0u8; 16]; + for i in 0..16 { + result[i] = self.0[i].wrapping_sub(rhs.0[i]); + } + Self(result) + } + } +} + +impl SubAssign for u8x16 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + #[cfg(target_arch = "x86_64")] + unsafe { + self.0 = _mm_sub_epi8(self.0, rhs.0) + } + #[cfg(target_arch = "aarch64")] + unsafe { + self.0 = vsubq_u8(self.0, rhs.0) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + for i in 0..16 { + self.0[i] = self.0[i].wrapping_sub(rhs.0[i]); + } + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_basic_u8x16_ops() { + let a = (0..16).map(|f| f as u8).collect::>(); + let b = (16..32).map(|f| f as u8).collect::>(); + + let simd_a = unsafe { u8x16::load_unaligned(a.as_ptr()) }; + let simd_b = unsafe { u8x16::load_unaligned(b.as_ptr()) }; + + let simd_add = simd_a + simd_b; + (0..16) + .zip(simd_add.as_array().iter()) + .for_each(|(x, &y)| assert_eq!((x + x + 16) as u8, y)); + + // on x86_64, the result of simd_mul is saturated + // on aarch64, the result of simd_mul is not saturated + let simd_mul = simd_a * simd_b; + (0..16).zip(simd_mul.as_array().iter()).for_each(|(x, &y)| { + #[cfg(target_arch = "x86_64")] + assert_eq!(std::cmp::min(x * (x + 16), 255_i32) as u8, y); + #[cfg(target_arch = "aarch64")] + assert_eq!((x * (x + 16_i32)) as u8, y); + }); + } + + #[test] + fn test_saturating_add() { + let a = u8x16::splat(200); + let b = u8x16::splat(100); + let mut result = a + b; + + let expected = (0..16).map(|_| 255).collect::>(); + assert_eq!(result.as_array(), expected.as_slice()); + + result += b; + assert_eq!(result.as_array(), expected.as_slice()); + } +} diff --git a/rust/lance-table/Cargo.toml b/rust/lance-table/Cargo.toml index f4696760419..b0de3e6b873 100644 --- a/rust/lance-table/Cargo.toml +++ b/rust/lance-table/Cargo.toml @@ -22,7 +22,7 @@ arrow-buffer.workspace = true arrow-ipc.workspace = true arrow-schema.workspace = true async-trait.workspace = true -aws-credential-types.workspace = true +aws-credential-types = { workspace = true, optional = true } aws-sdk-dynamodb = { workspace = true, optional = true } byteorder.workspace = true bytes.workspace = true @@ -57,10 +57,15 @@ pprof = { workspace = true } [build-dependencies] prost-build.workspace = true +protobuf-src = { version = "2.1", optional = true } [features] -dynamodb = ["aws-sdk-dynamodb", "lazy_static"] -dynamodb_tests = ["dynamodb"] +dynamodb = ["aws-sdk-dynamodb", "lazy_static", "aws-credential-types", "lance-io/aws"] +protoc = ["dep:protobuf-src"] + +[package.metadata.docs.rs] +# docs.rs uses an older version of Ubuntu that does not have the necessary protoc version +features = ["protoc"] [[bench]] name = "row_id_index" diff --git a/rust/lance-table/build.rs b/rust/lance-table/build.rs index e0d0c153936..c4b2cc52dc5 100644 --- a/rust/lance-table/build.rs +++ b/rust/lance-table/build.rs @@ -6,6 +6,10 @@ use std::io::Result; fn main() -> Result<()> { println!("cargo:rerun-if-changed=protos"); + #[cfg(feature = "protoc")] + // Use vendored protobuf compiler if requested. + std::env::set_var("PROTOC", protobuf_src::protoc()); + let mut prost_build = prost_build::Config::new(); prost_build.extern_path(".lance.file", "::lance_file::format::pb"); prost_build.protoc_arg("--experimental_allow_proto3_optional"); diff --git a/rust/lance-table/src/feature_flags.rs b/rust/lance-table/src/feature_flags.rs index 1edeb520d5e..5ace6300d65 100644 --- a/rust/lance-table/src/feature_flags.rs +++ b/rust/lance-table/src/feature_flags.rs @@ -3,7 +3,7 @@ //! Feature flags -use snafu::{location, Location}; +use snafu::location; use crate::format::Manifest; use lance_core::{Error, Result}; diff --git a/rust/lance-table/src/format.rs b/rust/lance-table/src/format.rs index 5636f9138b3..0fa3ddbee9f 100644 --- a/rust/lance-table/src/format.rs +++ b/rust/lance-table/src/format.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use arrow_buffer::ToByteSlice; -use snafu::{location, Location}; +use snafu::location; use uuid::Uuid; mod fragment; diff --git a/rust/lance-table/src/format/fragment.rs b/rust/lance-table/src/format/fragment.rs index 475b2fb23e7..01b05c6ce0d 100644 --- a/rust/lance-table/src/format/fragment.rs +++ b/rust/lance-table/src/format/fragment.rs @@ -7,7 +7,7 @@ use lance_file::format::{MAJOR_VERSION, MINOR_VERSION}; use lance_file::version::LanceFileVersion; use object_store::path::Path; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use crate::format::pb; diff --git a/rust/lance-table/src/format/index.rs b/rust/lance-table/src/format/index.rs index a2a48c36136..480131f0b18 100644 --- a/rust/lance-table/src/format/index.rs +++ b/rust/lance-table/src/format/index.rs @@ -5,7 +5,7 @@ use deepsize::DeepSizeOf; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use uuid::Uuid; use super::pb; diff --git a/rust/lance-table/src/format/manifest.rs b/rust/lance-table/src/format/manifest.rs index 0546e040f44..6e0d3c21606 100644 --- a/rust/lance-table/src/format/manifest.rs +++ b/rust/lance-table/src/format/manifest.rs @@ -23,7 +23,7 @@ use lance_core::datatypes::{Schema, StorageClass}; use lance_core::{Error, Result}; use lance_io::object_store::ObjectStore; use lance_io::utils::read_struct; -use snafu::{location, Location}; +use snafu::location; /// Manifest of a dataset /// @@ -204,6 +204,20 @@ impl Manifest { .retain(|key, _| !delete_keys.contains(&key.as_str())); } + /// Replaces the schema metadata with the given key-value pairs. + pub fn update_schema_metadata(&mut self, new_metadata: HashMap) { + self.schema.metadata = new_metadata; + } + + /// Replaces the metadata of the field with the given id with the given key-value pairs. + /// + /// If the field does not exist in the schema, this is a no-op. + pub fn update_field_metadata(&mut self, field_id: i32, new_metadata: HashMap) { + if let Some(field) = self.schema.field_by_id_mut(field_id) { + field.metadata = new_metadata; + } + } + /// Check the current fragment list and update the high water mark pub fn update_max_fragment_id(&mut self) { let max_fragment_id = self @@ -779,8 +793,8 @@ mod tests { /*blob_dataset_version= */ None, ); - let mut config = HashMap::new(); - config.insert("lance:test".to_string(), "value".to_string()); + let mut config = manifest.config.clone(); + config.insert("lance.test".to_string(), "value".to_string()); config.insert("other-key".to_string(), "other-value".to_string()); manifest.update_config(config.clone()); diff --git a/rust/lance-table/src/io/commit.rs b/rust/lance-table/src/io/commit.rs index 0c5dd46628b..afdc716b0be 100644 --- a/rust/lance-table/src/io/commit.rs +++ b/rust/lance-table/src/io/commit.rs @@ -32,9 +32,11 @@ use futures::{ stream::BoxStream, StreamExt, TryStreamExt, }; +use lance_io::object_writer::WriteResult; use log::warn; +use object_store::PutOptions; use object_store::{path::Path, Error as ObjectStoreError, ObjectStore as OSObjectStore}; -use snafu::{location, Location}; +use snafu::location; use url::Url; #[cfg(feature = "dynamodb")] @@ -49,7 +51,7 @@ use { self::external_manifest::{ExternalManifestCommitHandler, ExternalManifestStore}, aws_credential_types::provider::error::CredentialsError, aws_credential_types::provider::ProvideCredentials, - lance_io::object_store::{build_aws_credential, StorageOptions}, + lance_io::object_store::{providers::aws::build_aws_credential, StorageOptions}, object_store::aws::AmazonS3ConfigKey, object_store::aws::AwsCredentialProvider, std::borrow::Cow, @@ -169,12 +171,14 @@ pub async fn migrate_scheme_to_v2(object_store: &ObjectStore, dataset_base: &Pat } /// Function that writes the manifest to the object store. +/// +/// Returns the size of the written manifest. pub type ManifestWriter = for<'a> fn( object_store: &'a ObjectStore, manifest: &'a mut Manifest, indices: Option>, path: &'a Path, -) -> BoxFuture<'a, Result<()>>; +) -> BoxFuture<'a, Result>; #[derive(Debug)] pub struct ManifestLocation { @@ -186,6 +190,39 @@ pub struct ManifestLocation { pub size: Option, /// Naming scheme of the manifest file. pub naming_scheme: ManifestNamingScheme, + /// Optional e-tag, used for integrity checks. Manifests should be immutable, so + /// if we detect a change in the e-tag, it means the manifest was tampered with. + /// This might happen if the dataset was deleted and then re-created. + pub e_tag: Option, +} + +impl TryFrom for ManifestLocation { + type Error = Error; + + fn try_from(meta: object_store::ObjectMeta) -> Result { + let filename = meta.location.filename().ok_or_else(|| Error::Internal { + message: "ObjectMeta location does not have a filename".to_string(), + location: location!(), + })?; + let scheme = + ManifestNamingScheme::detect_scheme(filename).ok_or_else(|| Error::Internal { + message: format!("Invalid manifest filename: '{}'", filename), + location: location!(), + })?; + let version = scheme + .parse_version(filename) + .ok_or_else(|| Error::Internal { + message: format!("Invalid manifest filename: '{}'", filename), + location: location!(), + })?; + Ok(Self { + version, + path: meta.location, + size: Some(meta.size as u64), + naming_scheme: scheme, + e_tag: meta.e_tag, + }) + } } /// Get the latest manifest path @@ -199,7 +236,7 @@ async fn current_manifest_path( } } - let manifest_files = object_store.inner.list(Some(&base.child(VERSIONS_DIR))); + let manifest_files = object_store.list(Some(base.child(VERSIONS_DIR))); let mut valid_manifests = manifest_files.try_filter_map(|res| { if let Some(scheme) = ManifestNamingScheme::detect_scheme(res.location.filename().unwrap()) @@ -248,6 +285,7 @@ async fn current_manifest_path( path: meta.location, size: Some(meta.size as u64), naming_scheme: scheme, + e_tag: meta.e_tag, }) } // If the first valid manifest we see if V1, assume for now that we are @@ -279,6 +317,7 @@ async fn current_manifest_path( path: current_meta.location, size: Some(current_meta.size as u64), naming_scheme: scheme, + e_tag: current_meta.e_tag, }) } (None, _) => Err(Error::NotFound { @@ -340,45 +379,66 @@ fn current_manifest_local(base: &Path) -> std::io::Result String { + let inode = get_inode(metadata); + let size = metadata.len(); + let mtime = metadata + .modified() + .ok() + .and_then(|mtime| mtime.duration_since(std::time::SystemTime::UNIX_EPOCH).ok()) + .unwrap_or_default() + .as_micros(); + + // Use an ETag scheme based on that used by many popular HTTP servers + // + // + format!("{inode:x}-{mtime:x}-{size:x}") +} + +#[cfg(unix)] +/// We include the inode when available to yield an ETag more resistant to collisions +/// and as used by popular web servers such as [Apache](https://httpd.apache.org/docs/2.2/mod/core.html#fileetag) +fn get_inode(metadata: &std::fs::Metadata) -> u64 { + std::os::unix::fs::MetadataExt::ino(metadata) +} + +#[cfg(not(unix))] +/// On platforms where an inode isn't available, fallback to just relying on size and mtime +fn get_inode(_metadata: &std::fs::Metadata) -> u64 { + 0 +} + async fn list_manifests<'a>( base_path: &Path, object_store: &'a dyn OSObjectStore, -) -> Result>> { +) -> Result>> { Ok(object_store .read_dir_all(&base_path.child(VERSIONS_DIR), None) .await? - .try_filter_map(|obj_meta| { - if obj_meta.location.extension() == Some(MANIFEST_EXTENSION) { - future::ready(Ok(Some(obj_meta.location))) - } else { - future::ready(Ok(None)) - } + .filter_map(|obj_meta| { + futures::future::ready( + obj_meta + .map(|m| ManifestLocation::try_from(m).ok()) + .transpose(), + ) }) .boxed()) } -pub fn parse_version_from_path(path: &Path) -> Result { - path.filename() - .and_then(|name| name.split_once('.')) - .filter(|(_, extension)| *extension == MANIFEST_EXTENSION) - .and_then(|(version, _)| version.parse::().ok()) - .ok_or(Error::Internal { - message: format!("Expected manifest file, but found {}", path), - location: location!(), - }) -} - fn make_staging_manifest_path(base: &Path) -> Result { let id = uuid::Uuid::new_v4().to_string(); Path::parse(format!("{base}-{id}")).map_err(|e| Error::IO { @@ -408,41 +468,6 @@ pub trait CommitHandler: Debug + Send + Sync { Ok(current_manifest_path(object_store, base_path).await?) } - /// Get the path to the latest version manifest of a dataset at the base_path - async fn resolve_latest_version( - &self, - base_path: &Path, - object_store: &ObjectStore, - ) -> std::result::Result { - // TODO: we need to pade 0's to the version number on the manifest file path - Ok(current_manifest_path(object_store, base_path).await?.path) - } - - // for default implementation, parse the version from the path - async fn resolve_latest_version_id( - &self, - base_path: &Path, - object_store: &ObjectStore, - ) -> Result { - Ok(current_manifest_path(object_store, base_path) - .await? - .version) - } - - /// Get the path to a specific versioned manifest of a dataset at the base_path - /// - /// The version must already exist. - async fn resolve_version( - &self, - base_path: &Path, - version: u64, - object_store: &dyn OSObjectStore, - ) -> std::result::Result { - Ok(default_resolve_version(base_path, version, object_store) - .await? - .path) - } - async fn resolve_version_location( &self, base_path: &Path, @@ -452,12 +477,11 @@ pub trait CommitHandler: Debug + Send + Sync { default_resolve_version(base_path, version, object_store).await } - /// List manifests that are available for a dataset at the base_path - async fn list_manifests<'a>( + async fn list_manifest_locations<'a>( &self, base_path: &Path, object_store: &'a dyn OSObjectStore, - ) -> Result>> { + ) -> Result>> { list_manifests(base_path, object_store).await } @@ -473,7 +497,7 @@ pub trait CommitHandler: Debug + Send + Sync { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result; + ) -> std::result::Result; /// Delete the recorded manifest information for a dataset at the base_path async fn delete(&self, _base_path: &Path) -> Result<()> { @@ -495,6 +519,7 @@ async fn default_resolve_version( // Both V1 and V2 should give the same path for detached versions path: ManifestNamingScheme::V2.manifest_path(base_path, version), size: None, + e_tag: None, }); } @@ -507,6 +532,7 @@ async fn default_resolve_version( path, size: Some(meta.size as u64), naming_scheme: scheme, + e_tag: meta.e_tag, }), Err(ObjectStoreError::NotFound { .. }) => { // fallback to V1 @@ -516,6 +542,7 @@ async fn default_resolve_version( path: scheme.manifest_path(base_path, version), size: None, naming_scheme: scheme, + e_tag: None, }) } Err(e) => Err(e.into()), @@ -603,9 +630,9 @@ pub async fn commit_handler_from_url( }; match url.scheme() { - // TODO: for Cloudflare R2 and Minio, we can provide a PutIfNotExist commit handler - // See: https://docs.rs/object_store/latest/object_store/aws/enum.S3ConditionalPut.html#variant.ETagMatch - "s3" => Ok(Arc::new(UnsafeCommitHandler)), + "s3" | "gs" | "az" | "memory" | "file" | "file-object-store" => { + Ok(Arc::new(ConditionalPutCommitHandler)) + } #[cfg(not(feature = "dynamodb"))] "s3+ddb" => Err(Error::InvalidInput { source: "`s3+ddb://` scheme requires `dynamodb` feature to be enabled".into(), @@ -669,7 +696,6 @@ pub async fn commit_handler_from_url( .await?, })) } - "gs" | "az" | "file" | "file-object-store" | "memory" => Ok(Arc::new(RenameCommitHandler)), _ => Ok(Arc::new(UnsafeCommitHandler)), } } @@ -728,7 +754,7 @@ impl CommitHandler for UnsafeCommitHandler { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { // Log a one-time warning if !WARNED_ON_UNSAFE_COMMIT.load(std::sync::atomic::Ordering::Relaxed) { WARNED_ON_UNSAFE_COMMIT.store(true, std::sync::atomic::Ordering::Relaxed); @@ -740,9 +766,15 @@ impl CommitHandler for UnsafeCommitHandler { let version_path = naming_scheme.manifest_path(base_path, manifest.version); // Write the manifest naively - manifest_writer(object_store, manifest, indices, &version_path).await?; - - Ok(version_path) + let res = manifest_writer(object_store, manifest, indices, &version_path).await?; + + Ok(ManifestLocation { + version: manifest.version, + size: Some(res.size as u64), + naming_scheme, + path: version_path, + e_tag: res.e_tag, + }) } } @@ -788,7 +820,7 @@ impl CommitHandler for T { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { let path = naming_scheme.manifest_path(base_path, manifest.version); // NOTE: once we have the lease we cannot use ? to return errors, since // we must release the lease before returning. @@ -817,7 +849,14 @@ impl CommitHandler for T { // Release the lock lease.release(res.is_ok()).await?; - res.map_err(|err| err.into()).map(|_| path) + let res = res?; + Ok(ManifestLocation { + version: manifest.version, + size: Some(res.size as u64), + naming_scheme, + path, + e_tag: res.e_tag, + }) } } @@ -831,7 +870,7 @@ impl CommitHandler for Arc { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { self.as_ref() .commit( manifest, @@ -860,7 +899,7 @@ impl CommitHandler for RenameCommitHandler { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { // Create a temporary object, then use `rename_if_not_exists` to commit. // If failed, clean up the temporary object. @@ -868,14 +907,23 @@ impl CommitHandler for RenameCommitHandler { let tmp_path = make_staging_manifest_path(&path)?; // Write the manifest to the temporary path - manifest_writer(object_store, manifest, indices, &tmp_path).await?; + let res = manifest_writer(object_store, manifest, indices, &tmp_path).await?; match object_store .inner .rename_if_not_exists(&tmp_path, &path) .await { - Ok(_) => Ok(path), + Ok(_) => { + // Successfully committed + Ok(ManifestLocation { + version: manifest.version, + path, + size: Some(res.size as u64), + naming_scheme, + e_tag: None, // Re-name can change e-tag. + }) + } Err(ObjectStoreError::AlreadyExists { .. }) => { // Another transaction has already been committed // Attempt to clean up temporary object, but ignore errors if we can't @@ -897,6 +945,60 @@ impl Debug for RenameCommitHandler { } } +pub struct ConditionalPutCommitHandler; + +#[async_trait::async_trait] +impl CommitHandler for ConditionalPutCommitHandler { + async fn commit( + &self, + manifest: &mut Manifest, + indices: Option>, + base_path: &Path, + object_store: &ObjectStore, + manifest_writer: ManifestWriter, + naming_scheme: ManifestNamingScheme, + ) -> std::result::Result { + let path = naming_scheme.manifest_path(base_path, manifest.version); + + let memory_store = ObjectStore::memory(); + let dummy_path = "dummy"; + manifest_writer(&memory_store, manifest, indices, &dummy_path.into()).await?; + let dummy_data = memory_store.read_one_all(&dummy_path.into()).await?; + let size = dummy_data.len() as u64; + let res = object_store + .inner + .put_opts( + &path, + dummy_data.into(), + PutOptions { + mode: object_store::PutMode::Create, + ..Default::default() + }, + ) + .await + .map_err(|err| match err { + ObjectStoreError::AlreadyExists { .. } | ObjectStoreError::Precondition { .. } => { + CommitError::CommitConflict + } + _ => CommitError::OtherError(err.into()), + })?; + + Ok(ManifestLocation { + version: manifest.version, + path, + size: Some(size), + naming_scheme, + e_tag: res.e_tag, + }) + } +} + +impl Debug for ConditionalPutCommitHandler { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConditionalPutCommitHandler").finish() + } +} + #[derive(Debug, Clone)] pub struct CommitConfig { pub num_retries: u32, diff --git a/rust/lance-table/src/io/commit/dynamodb.rs b/rust/lance-table/src/io/commit/dynamodb.rs index ceac8aec86b..a46adfa2dda 100644 --- a/rust/lance-table/src/io/commit/dynamodb.rs +++ b/rust/lance-table/src/io/commit/dynamodb.rs @@ -17,8 +17,9 @@ use aws_sdk_dynamodb::operation::{ }; use aws_sdk_dynamodb::types::{AttributeValue, KeyType}; use aws_sdk_dynamodb::Client; +use object_store::path::Path; +use snafu::location; use snafu::OptionExt; -use snafu::{location, Location}; use tokio::sync::RwLock; use crate::io::commit::external_manifest::ExternalManifestStore; @@ -26,6 +27,9 @@ use lance_core::error::box_error; use lance_core::error::NotFoundSnafu; use lance_core::{Error, Result}; +use super::external_manifest::detect_naming_scheme_from_path; +use super::ManifestLocation; + #[derive(Debug)] struct WrappedSdkError(SdkError); @@ -259,6 +263,7 @@ impl ExternalManifestStore for DynamoDBExternalManifestStore { "dynamodb not found: base_uri: {}; version: {}", base_uri, version ), + location: location!(), })?; let path = item @@ -274,8 +279,63 @@ impl ExternalManifestStore for DynamoDBExternalManifestStore { } } + async fn get_manifest_location( + &self, + base_uri: &str, + version: u64, + ) -> Result { + let get_item_result = self + .ddb_get() + .key(base_uri!(), AttributeValue::S(base_uri.into())) + .key(version!(), AttributeValue::N(version.to_string())) + .send() + .await + .wrap_err()?; + + let item = get_item_result.item.context(NotFoundSnafu { + uri: format!( + "dynamodb not found: base_uri: {}; version: {}", + base_uri, version + ), + location: location!(), + })?; + + let path = item + .get(path!()) + .ok_or_else(|| Error::io(format!("key {} is not present", path!()), location!()))? + .as_s() + .map_err(|_| Error::io(format!("key {} is not a string", path!()), location!()))? + .as_str(); + let path = Path::from(path); + + let size = item + .get("size") + .and_then(|attr| attr.as_n().ok().and_then(|v| v.parse().ok())); + + let e_tag = item.get("e_tag").and_then(|attr| attr.as_s().ok().cloned()); + + let naming_scheme = detect_naming_scheme_from_path(&path)?; + + Ok(ManifestLocation { + version, + path, + size, + naming_scheme, + e_tag, + }) + } + /// Get the latest version of a dataset at the base_uri async fn get_latest_version(&self, base_uri: &str) -> Result> { + self.get_latest_manifest_location(base_uri) + .await + .map(|location| location.map(|loc| (loc.version, loc.path.to_string()))) + } + + async fn get_latest_manifest_location( + &self, + base_uri: &str, + ) -> Result> { let query_result = self .ddb_query() .key_condition_expression(format!("{} = :{}", base_uri!(), base_uri!())) @@ -323,14 +383,30 @@ impl ExternalManifestStore for DynamoDBExternalManifestStore { ) )?; + let size = item.get("size").and_then(|attr| match attr { + AttributeValue::N(size) => size.parse().ok(), + _ => None, + }); + + let e_tag = item.get("e_tag").and_then(|attr| attr.as_s().ok().cloned()); + match (version_attribute, path_attribute) { - (AttributeValue::N(version), AttributeValue::S(path)) => Ok(Some(( - version.parse().map_err(|e| Error::io( + (AttributeValue::N(version), AttributeValue::S(path)) => { + let version = version.parse().map_err(|e| Error::io( format!("dynamodb error: could not parse the version number returned {}, error: {}", version, e), location!(), - ))?, - path.clone(), - ))), + ))?; + let path = Path::from(path.as_str()); + let naming_scheme = detect_naming_scheme_from_path(&path)?; + let location = ManifestLocation { + version, + path, + size, + naming_scheme, + e_tag, + }; + Ok(Some(location)) + }, _ => Err(Error::io( format!("dynamodb error: found entries for {base_uri} but the returned data is not number type"), location!(), @@ -342,12 +418,27 @@ impl ExternalManifestStore for DynamoDBExternalManifestStore { } /// Put the manifest path for a given base_uri and version, should fail if the version already exists - async fn put_if_not_exists(&self, base_uri: &str, version: u64, path: &str) -> Result<()> { - self.ddb_put() + async fn put_if_not_exists( + &self, + base_uri: &str, + version: u64, + path: &str, + size: u64, + e_tag: Option, + ) -> Result<()> { + let mut put_item = self + .ddb_put() .item(base_uri!(), AttributeValue::S(base_uri.into())) .item(version!(), AttributeValue::N(version.to_string())) .item(path!(), AttributeValue::S(path.to_string())) .item(committer!(), AttributeValue::S(self.committer_name.clone())) + .item("size", AttributeValue::N(size.to_string())); + + if let Some(e_tag) = e_tag { + put_item = put_item.item("e_tag", AttributeValue::S(e_tag)); + } + + put_item .condition_expression(format!( "attribute_not_exists({}) AND attribute_not_exists({})", base_uri!(), @@ -361,12 +452,27 @@ impl ExternalManifestStore for DynamoDBExternalManifestStore { } /// Put the manifest path for a given base_uri and version, should fail if the version **does not** already exist - async fn put_if_exists(&self, base_uri: &str, version: u64, path: &str) -> Result<()> { - self.ddb_put() + async fn put_if_exists( + &self, + base_uri: &str, + version: u64, + path: &str, + size: u64, + e_tag: Option, + ) -> Result<()> { + let mut put_item = self + .ddb_put() .item(base_uri!(), AttributeValue::S(base_uri.into())) .item(version!(), AttributeValue::N(version.to_string())) .item(path!(), AttributeValue::S(path.to_string())) .item(committer!(), AttributeValue::S(self.committer_name.clone())) + .item("size", AttributeValue::N(size.to_string())); + + if let Some(e_tag) = e_tag { + put_item = put_item.item("e_tag", AttributeValue::S(e_tag)); + } + + put_item .condition_expression(format!( "attribute_exists({}) AND attribute_exists({})", base_uri!(), diff --git a/rust/lance-table/src/io/commit/external_manifest.rs b/rust/lance-table/src/io/commit/external_manifest.rs index c03c4a99607..24f275ca5dd 100644 --- a/rust/lance-table/src/io/commit/external_manifest.rs +++ b/rust/lance-table/src/io/commit/external_manifest.rs @@ -9,10 +9,11 @@ use std::sync::Arc; use async_trait::async_trait; use lance_core::{Error, Result}; -use lance_io::object_store::{ObjectStore, ObjectStoreExt}; +use lance_io::object_store::ObjectStore; use log::warn; +use object_store::ObjectMeta; use object_store::{path::Path, Error as ObjectStoreError, ObjectStore as OSObjectStore}; -use snafu::{location, Location}; +use snafu::location; use super::{ current_manifest_path, default_resolve_version, make_staging_manifest_path, ManifestLocation, @@ -38,6 +39,23 @@ pub trait ExternalManifestStore: std::fmt::Debug + Send + Sync { /// Get the manifest path for a given base_uri and version async fn get(&self, base_uri: &str, version: u64) -> Result; + async fn get_manifest_location( + &self, + base_uri: &str, + version: u64, + ) -> Result { + let path = self.get(base_uri, version).await?; + let path = Path::from(path); + let naming_scheme = detect_naming_scheme_from_path(&path)?; + Ok(ManifestLocation { + version, + path, + size: None, + naming_scheme, + e_tag: None, + }) + } + /// Get the latest version of a dataset at the base_uri, and the path to the manifest. /// The path is provided as an optimization. The path is deterministic based on /// the version and the store should not customize it. @@ -61,6 +79,7 @@ pub trait ExternalManifestStore: std::fmt::Debug + Send + Sync { path, size: None, naming_scheme, + e_tag: None, }) }) .transpose() @@ -68,10 +87,24 @@ pub trait ExternalManifestStore: std::fmt::Debug + Send + Sync { } /// Put the manifest path for a given base_uri and version, should fail if the version already exists - async fn put_if_not_exists(&self, base_uri: &str, version: u64, path: &str) -> Result<()>; + async fn put_if_not_exists( + &self, + base_uri: &str, + version: u64, + path: &str, + size: u64, + e_tag: Option, + ) -> Result<()>; /// Put the manifest path for a given base_uri and version, should fail if the version **does not** already exist - async fn put_if_exists(&self, base_uri: &str, version: u64, path: &str) -> Result<()>; + async fn put_if_exists( + &self, + base_uri: &str, + version: u64, + path: &str, + size: u64, + e_tag: Option, + ) -> Result<()>; /// Delete the manifest information for given base_uri from the store async fn delete(&self, _base_uri: &str) -> Result<()> { @@ -79,9 +112,12 @@ pub trait ExternalManifestStore: std::fmt::Debug + Send + Sync { } } -fn detect_naming_scheme_from_path(path: &Path) -> Result { +pub(crate) fn detect_naming_scheme_from_path(path: &Path) -> Result { path.filename() - .and_then(ManifestNamingScheme::detect_scheme) + .and_then(|name| { + ManifestNamingScheme::detect_scheme(name) + .or_else(|| Some(ManifestNamingScheme::detect_scheme_staging(name))) + }) .ok_or_else(|| { Error::corrupt_file( path.clone(), @@ -109,28 +145,61 @@ impl ExternalManifestCommitHandler { /// by any number of readers or writers, so care should be taken to ensure /// that the manifest is not lost nor any errors occur due to duplicate /// operations. + #[allow(clippy::too_many_arguments)] async fn finalize_manifest( &self, base_path: &Path, staging_manifest_path: &Path, version: u64, + size: u64, + e_tag: Option, store: &dyn OSObjectStore, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { // step 1: copy the manifest to the final location let final_manifest_path = naming_scheme.manifest_path(base_path, version); - match store + + let copied = match store .copy(staging_manifest_path, &final_manifest_path) .await { - Ok(_) => {} - Err(ObjectStoreError::NotFound { .. }) => return Ok(final_manifest_path), // Another writer beat us to it. + Ok(_) => true, + Err(ObjectStoreError::NotFound { .. }) => false, // Another writer beat us to it. Err(e) => return Err(e.into()), }; + // On S3, the etag can change if originally was MultipartUpload and later was Copy + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_Object.html#AmazonS3-Type-Object-ETag + // We only do MultipartUpload for > 5MB files, so we can skip this check + // if size < 5MB + let e_tag = if size < 5 * 1024 * 1024 { + e_tag + } else { + let meta = store.head(&final_manifest_path).await?; + meta.e_tag + }; + + let location = ManifestLocation { + version, + path: final_manifest_path, + size: Some(size), + naming_scheme, + e_tag, + }; + + if !copied { + return Ok(location); + } + // step 2: flip the external store to point to the final location self.external_manifest_store - .put_if_exists(base_path.as_ref(), version, final_manifest_path.as_ref()) + .put_if_exists( + base_path.as_ref(), + version, + location.path.as_ref(), + size, + location.e_tag.clone(), + ) .await?; // step 3: delete the staging manifest @@ -140,7 +209,7 @@ impl ExternalManifestCommitHandler { Err(e) => return Err(e.into()), } - Ok(final_manifest_path) + Ok(location) } } @@ -151,86 +220,69 @@ impl CommitHandler for ExternalManifestCommitHandler { base_path: &Path, object_store: &ObjectStore, ) -> std::result::Result { - let path = self.resolve_latest_version(base_path, object_store).await?; - let naming_scheme = detect_naming_scheme_from_path(&path)?; - Ok(ManifestLocation { - version: self - .resolve_latest_version_id(base_path, object_store) - .await?, - path, - size: None, - naming_scheme, - }) - } - - /// Get the latest version of a dataset at the path - async fn resolve_latest_version( - &self, - base_path: &Path, - object_store: &ObjectStore, - ) -> std::result::Result { - let version = self + let location = self .external_manifest_store - .get_latest_version(base_path.as_ref()) + .get_latest_manifest_location(base_path.as_ref()) .await?; - match version { - Some((version, path)) => { + match location { + Some(ManifestLocation { + version, + path, + size, + naming_scheme, + e_tag, + }) => { // The path is finalized, no need to check object store - if path.ends_with(&format!(".{MANIFEST_EXTENSION}")) { - return Ok(Path::parse(path)?); + if path.extension() == Some(MANIFEST_EXTENSION) { + return Ok(ManifestLocation { + version, + path, + size, + naming_scheme, + e_tag, + }); } - // Detect naming scheme based on presence of zero padding. - let staged_path = Path::parse(&path)?; - let naming_scheme = - ManifestNamingScheme::detect_scheme_staging(staged_path.filename().unwrap()); - - self.finalize_manifest( - base_path, - &staged_path, - version, - &object_store.inner, - naming_scheme, - ) - .await + let (size, e_tag) = if let Some(size) = size { + (size, e_tag) + } else { + let meta = object_store.inner.head(&path).await?; + (meta.size as u64, meta.e_tag) + }; + + let final_location = self + .finalize_manifest( + base_path, + &path, + version, + size, + e_tag.clone(), + &object_store.inner, + naming_scheme, + ) + .await?; + + Ok(final_location) } // Dataset not found in the external store, this could be because the dataset did not // use external store for commit before. In this case, we search for the latest manifest - None => Ok(current_manifest_path(object_store, base_path).await?.path), + None => current_manifest_path(object_store, base_path).await, } } - async fn resolve_latest_version_id( - &self, - base_path: &Path, - object_store: &ObjectStore, - ) -> std::result::Result { - let version = self - .external_manifest_store - .get_latest_version(base_path.as_ref()) - .await?; - - match version { - Some((version, _)) => Ok(version), - None => Ok(current_manifest_path(object_store, base_path) - .await? - .version), - } - } - - async fn resolve_version( + async fn resolve_version_location( &self, base_path: &Path, version: u64, object_store: &dyn OSObjectStore, - ) -> std::result::Result { - let path_res = self + ) -> std::result::Result { + let location_res = self .external_manifest_store - .get(base_path.as_ref(), version) + .get_manifest_location(base_path.as_ref(), version) .await; - let path = match path_res { + let location = match location_res { Ok(p) => p, // not board external manifest yet, direct to object store Err(Error::NotFound { .. }) => { @@ -241,69 +293,73 @@ impl CommitHandler for ExternalManifestCommitHandler { location: location!(), })? .path; - if object_store.exists(&path).await? { - // best effort put, if it fails, it's okay - match self - .external_manifest_store - .put_if_not_exists(base_path.as_ref(), version, path.as_ref()) - .await - { - Ok(_) => {} - Err(e) => { + match object_store.head(&path).await { + Ok(ObjectMeta { size, e_tag, .. }) => { + let res = self + .external_manifest_store + .put_if_not_exists( + base_path.as_ref(), + version, + path.as_ref(), + size as u64, + e_tag.clone(), + ) + .await; + if let Err(e) = res { warn!( - "could not update external manifest store during load, with error: {}", - e - ); + "could not update external manifest store during load, with error: {}", + e + ); } + let naming_scheme = + ManifestNamingScheme::detect_scheme_staging(path.filename().unwrap()); + return Ok(ManifestLocation { + version, + path, + size: Some(size as u64), + naming_scheme, + e_tag, + }); } - return Ok(path); - } else { - return Err(Error::NotFound { - uri: path.to_string(), - location: location!(), - }); + Err(ObjectStoreError::NotFound { .. }) => { + return Err(Error::NotFound { + uri: path.to_string(), + location: location!(), + }); + } + Err(e) => return Err(e.into()), } } Err(e) => return Err(e), }; // finalized path, just return - let current_path = Path::parse(path)?; - if current_path.extension() == Some(MANIFEST_EXTENSION) { - return Ok(current_path); + if location.path.extension() == Some(MANIFEST_EXTENSION) { + return Ok(location); } let naming_scheme = - ManifestNamingScheme::detect_scheme_staging(current_path.filename().unwrap()); + ManifestNamingScheme::detect_scheme_staging(location.path.filename().unwrap()); + + let (size, e_tag) = if let Some(size) = location.size { + (size, location.e_tag.clone()) + } else { + let meta = object_store.head(&location.path).await?; + (meta.size as u64, meta.e_tag) + }; self.finalize_manifest( base_path, - &Path::parse(¤t_path)?, + &location.path, version, + size, + e_tag, object_store, naming_scheme, ) .await } - async fn resolve_version_location( - &self, - base_path: &Path, - version: u64, - object_store: &dyn OSObjectStore, - ) -> std::result::Result { - let path = self - .resolve_version(base_path, version, object_store) - .await?; - let naming_scheme = detect_naming_scheme_from_path(&path)?; - Ok(ManifestLocation { - version, - path, - size: None, - naming_scheme, - }) - } - async fn commit( &self, manifest: &mut Manifest, @@ -312,19 +368,25 @@ impl CommitHandler for ExternalManifestCommitHandler { object_store: &ObjectStore, manifest_writer: ManifestWriter, naming_scheme: ManifestNamingScheme, - ) -> std::result::Result { + ) -> std::result::Result { // path we get here is the path to the manifest we want to write // use object_store.base_path.as_ref() for getting the root of the dataset // step 1: Write the manifest we want to commit to object store with a temporary name let path = naming_scheme.manifest_path(base_path, manifest.version); let staging_path = make_staging_manifest_path(&path)?; - manifest_writer(object_store, manifest, indices, &staging_path).await?; + let write_res = manifest_writer(object_store, manifest, indices, &staging_path).await?; // step 2 & 3: Try to commit this version to external store, return err on failure let res = self .external_manifest_store - .put_if_not_exists(base_path.as_ref(), manifest.version, staging_path.as_ref()) + .put_if_not_exists( + base_path.as_ref(), + manifest.version, + staging_path.as_ref(), + write_res.size as u64, + write_res.e_tag.clone(), + ) .await .map_err(|_| CommitError::CommitConflict {}); @@ -338,15 +400,15 @@ impl CommitHandler for ExternalManifestCommitHandler { return Err(err); } - let scheme = detect_naming_scheme_from_path(&path)?; - Ok(self .finalize_manifest( base_path, &staging_path, manifest.version, + write_res.size as u64, + write_res.e_tag, &object_store.inner, - scheme, + naming_scheme, ) .await?) } diff --git a/rust/lance-table/src/io/deletion.rs b/rust/lance-table/src/io/deletion.rs index da113276469..2d43dff1d41 100644 --- a/rust/lance-table/src/io/deletion.rs +++ b/rust/lance-table/src/io/deletion.rs @@ -11,13 +11,14 @@ use arrow_schema::{ArrowError, DataType, Field, Schema}; use bytes::Buf; use lance_core::error::{box_error, CorruptFileSnafu}; use lance_core::utils::deletion::DeletionVector; +use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DELETION, TRACE_FILE_AUDIT}; use lance_core::{Error, Result}; use lance_io::object_store::ObjectStore; use object_store::path::Path; use rand::Rng; use roaring::bitmap::RoaringBitmap; -use snafu::{location, Location, ResultExt}; -use tracing::instrument; +use snafu::{location, ResultExt}; +use tracing::{info, instrument}; use crate::format::{DeletionFile, DeletionFileType, Fragment}; @@ -56,8 +57,8 @@ pub async fn write_deletion_file( removed_rows: &DeletionVector, object_store: &ObjectStore, ) -> Result> { - match removed_rows { - DeletionVector::NoDeletions => Ok(None), + let deletion_file = match removed_rows { + DeletionVector::NoDeletions => None, DeletionVector::Set(set) => { let id = rand::thread_rng().gen::(); let deletion_file = DeletionFile { @@ -90,7 +91,9 @@ pub async fn write_deletion_file( object_store.put(&path, &out).await?; - Ok(Some(deletion_file)) + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, type=AUDIT_TYPE_DELETION, path = path.to_string()); + + Some(deletion_file) } DeletionVector::Bitmap(bitmap) => { let id = rand::thread_rng().gen::(); @@ -107,9 +110,12 @@ pub async fn write_deletion_file( object_store.put(&path, &out).await?; - Ok(Some(deletion_file)) + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, type=AUDIT_TYPE_DELETION, path = path.to_string()); + + Some(deletion_file) } - } + }; + Ok(deletion_file) } /// Read a deletion file for a fragment. @@ -136,7 +142,10 @@ pub async fn read_deletion_file( let mut batches: Vec = ArrowFileReader::try_new(data, None)? .collect::>() .map_err(box_error) - .context(CorruptFileSnafu { path: path.clone() })?; + .context(CorruptFileSnafu { + path: path.clone(), + location: location!(), + })?; if batches.len() != 1 { return Err(Error::corrupt_file( @@ -189,7 +198,10 @@ pub async fn read_deletion_file( let reader = data.reader(); let bitmap = RoaringBitmap::deserialize_from(reader) .map_err(box_error) - .context(CorruptFileSnafu { path })?; + .context(CorruptFileSnafu { + path, + location: location!(), + })?; Ok(Some(DeletionVector::Bitmap(bitmap))) } diff --git a/rust/lance-table/src/io/manifest.rs b/rust/lance-table/src/io/manifest.rs index 766b665a336..ede3d95e5ad 100644 --- a/rust/lance-table/src/io/manifest.rs +++ b/rust/lance-table/src/io/manifest.rs @@ -10,7 +10,7 @@ use lance_arrow::DataTypeExt; use lance_file::{version::LanceFileVersion, writer::ManifestProvider}; use object_store::path::Path; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use lance_core::{datatypes::Schema, Error, Result}; @@ -45,6 +45,13 @@ pub async fn read_manifest( end: file_size, }; let buf = object_store.inner.get_range(path, range).await?; + + // In case of corruption, the known_size might be wrong. We can retry without + // the size to be more robust. + if (buf.len() < 16 || !buf.ends_with(MAGIC)) && known_size.is_some() { + return Box::pin(read_manifest(object_store, path, None)).await; + } + if buf.len() < 16 { return Err(Error::io( "Invalid format: file size is smaller than 16 bytes".to_string(), diff --git a/rust/lance-table/src/rowids.rs b/rust/lance-table/src/rowids.rs index 38ee381f8d2..54330873e13 100644 --- a/rust/lance-table/src/rowids.rs +++ b/rust/lance-table/src/rowids.rs @@ -29,7 +29,7 @@ use lance_core::{utils::mask::RowIdTreeMap, Error, Result}; use lance_io::ReadBatchParams; pub use serde::{read_row_ids, write_row_ids}; -use snafu::{location, Location}; +use snafu::location; use segment::U64Segment; @@ -204,10 +204,7 @@ impl RowIdSequence { // If we've cycled through all segments, we know the row id is not in the sequence. while i < self.0.len() { let (segment_idx, segment) = segment_iter.next().unwrap(); - if segment - .range() - .map_or(false, |range| range.contains(&row_id)) - { + if segment.range().is_some_and(|range| range.contains(&row_id)) { if let Some(offset) = segment.position(row_id) { segment_matches.get_mut(segment_idx).unwrap().push(offset); } @@ -343,7 +340,7 @@ pub struct RowIdSeqSlice<'a> { offset_last: usize, } -impl<'a> RowIdSeqSlice<'a> { +impl RowIdSeqSlice<'_> { pub fn iter(&self) -> impl Iterator + '_ { let mut known_size = self.segments.iter().map(|segment| segment.len()).sum(); known_size -= self.offset_start; diff --git a/rust/lance-table/src/rowids/bitmap.rs b/rust/lance-table/src/rowids/bitmap.rs index dc628ddcf8f..bee46ada8fa 100644 --- a/rust/lance-table/src/rowids/bitmap.rs +++ b/rust/lance-table/src/rowids/bitmap.rs @@ -21,12 +21,12 @@ impl std::fmt::Debug for Bitmap { impl Bitmap { pub fn new_empty(len: usize) -> Self { - let data = vec![0; (len + 7) / 8]; + let data = vec![0; len.div_ceil(8)]; Self { data, len } } pub fn new_full(len: usize) -> Self { - let mut data = vec![0xff; (len + 7) / 8]; + let mut data = vec![0xff; len.div_ceil(8)]; // Zero past the end of len let remainder = len % 8; if remainder != 0 { @@ -92,7 +92,7 @@ pub struct BitmapSlice<'a> { len: usize, } -impl<'a> BitmapSlice<'a> { +impl BitmapSlice<'_> { pub fn count_ones(&self) -> usize { if self.len == 0 { return 0; @@ -138,7 +138,7 @@ impl<'a> BitmapSlice<'a> { } } -impl<'a> From> for Bitmap { +impl From> for Bitmap { fn from(slice: BitmapSlice) -> Self { let mut bitmap = Self::new_empty(slice.len); for i in 0..slice.len { diff --git a/rust/lance-table/src/rowids/index.rs b/rust/lance-table/src/rowids/index.rs index e9d954f4d6c..16c872adfd1 100644 --- a/rust/lance-table/src/rowids/index.rs +++ b/rust/lance-table/src/rowids/index.rs @@ -8,7 +8,7 @@ use deepsize::DeepSizeOf; use lance_core::utils::address::RowAddress; use lance_core::{Error, Result}; use rangemap::RangeInclusiveMap; -use snafu::{location, Location}; +use snafu::location; use super::{RowIdSequence, U64Segment}; diff --git a/rust/lance-table/src/rowids/serde.rs b/rust/lance-table/src/rowids/serde.rs index 6713411553d..75c4c45278e 100644 --- a/rust/lance-table/src/rowids/serde.rs +++ b/rust/lance-table/src/rowids/serde.rs @@ -3,7 +3,7 @@ use crate::{format::pb, rowids::bitmap::Bitmap}; use lance_core::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; use super::{encoded_array::EncodedU64Array, RowIdSequence, U64Segment}; use prost::Message; diff --git a/rust/lance-table/src/utils/stream.rs b/rust/lance-table/src/utils/stream.rs index 9edea6341a4..e5474c2e920 100644 --- a/rust/lance-table/src/utils/stream.rs +++ b/rust/lance-table/src/utils/stream.rs @@ -335,12 +335,14 @@ mod tests { let left = batch_task_stream( lance_datagen::gen() .col("x", lance_datagen::array::step::()) - .into_reader_stream(RowCount::from(100), BatchCount::from(10)), + .into_reader_stream(RowCount::from(100), BatchCount::from(10)) + .0, ); let right = batch_task_stream( lance_datagen::gen() .col("y", lance_datagen::array::step::()) - .into_reader_stream(RowCount::from(100), BatchCount::from(10)), + .into_reader_stream(RowCount::from(100), BatchCount::from(10)) + .0, ); let merged = super::merge_streams(vec![left, right]) @@ -370,7 +372,9 @@ mod tests { datagen = datagen.col("x", lance_datagen::array::rand::()); } let data = batch_task_stream( - datagen.into_reader_stream(RowCount::from(10), BatchCount::from(10)), + datagen + .into_reader_stream(RowCount::from(10), BatchCount::from(10)) + .0, ); let config = RowIdAndDeletesConfig { @@ -465,7 +469,8 @@ mod tests { // 100 rows across 10 batches of 10 rows let data = batch_task_stream( datagen - .into_reader_stream(RowCount::from(10), BatchCount::from(10)), + .into_reader_stream(RowCount::from(10), BatchCount::from(10)) + .0, ); let config = RowIdAndDeletesConfig { diff --git a/rust/lance-testing/src/datagen.rs b/rust/lance-testing/src/datagen.rs index db5bba8bc61..df38f6cacd6 100644 --- a/rust/lance-testing/src/datagen.rs +++ b/rust/lance-testing/src/datagen.rs @@ -9,7 +9,8 @@ use std::{iter::repeat_with, ops::Range}; use arrow_array::types::ArrowPrimitiveType; use arrow_array::{ - Float32Array, Int32Array, PrimitiveArray, RecordBatch, RecordBatchIterator, RecordBatchReader, + Float32Array, Int32Array, Int8Array, PrimitiveArray, RecordBatch, RecordBatchIterator, + RecordBatchReader, }; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::{fixed_size_list_type, ArrowFloatType, FixedSizeListArrayExt}; @@ -222,6 +223,13 @@ pub fn generate_random_array(n: usize) -> Float32Array { Float32Array::from_iter_values(repeat_with(|| rng.gen::()).take(n)) } +/// Create a random float32 array where each element is uniformly +/// distributed between [0..1] +pub fn generate_random_int8_array(n: usize) -> Int8Array { + let mut rng = rand::thread_rng(); + Int8Array::from_iter_values(repeat_with(|| rng.gen::()).take(n)) +} + /// Create a random primitive array where each element is uniformly distributed a /// given range. pub fn generate_random_array_with_range( diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 05b0a0a3817..924a673032f 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -28,6 +28,7 @@ lance-table = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ipc = { workspace = true } arrow-ord = { workspace = true } arrow-row = { workspace = true } arrow-schema = { workspace = true } @@ -39,12 +40,12 @@ bytes.workspace = true chrono.workspace = true clap = { version = "4.1.1", features = ["derive"], optional = true } # This is already used by datafusion -dashmap = "5" +dashmap = "6" deepsize.workspace = true # matches arrow-rs use half.workspace = true itertools.workspace = true -object_store = { workspace = true, features = ["aws", "gcp", "azure"] } +object_store = { workspace = true } aws-credential-types.workspace = true pin-project.workspace = true prost.workspace = true @@ -60,6 +61,8 @@ arrow.workspace = true datafusion.workspace = true datafusion-functions.workspace = true datafusion-physical-expr.workspace = true +datafusion-expr.workspace = true +either.workspace = true lapack = { version = "0.19.0", optional = true } snafu = { workspace = true } log = { workspace = true } @@ -69,10 +72,12 @@ moka.workspace = true permutation = { version = "0.4.0" } tantivy.workspace = true tfrecord = { version = "0.15.0", optional = true, features = ["async"] } +prost_old = { version = "0.12.6", package = "prost", optional = true } aws-sdk-dynamodb = { workspace = true, optional = true } tempfile.workspace = true tracing.workspace = true lazy_static = { workspace = true } +humantime = { workspace = true } async_cell = "0.2.2" [target.'cfg(target_os = "linux")'.dev-dependencies] @@ -80,9 +85,6 @@ pprof.workspace = true # Need this so we can prevent dynamic linking in binaries (see cli feature) lzma-sys = { version = "0.1" } -[build-dependencies] -prost-build.workspace = true - [dev-dependencies] lance-test-macros = { workspace = true } lance-datagen = { workspace = true } @@ -95,20 +97,33 @@ all_asserts = "2.3.1" mock_instant.workspace = true lance-testing = { workspace = true } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -env_logger = "0.10.0" +env_logger = "0.11.7" tracing-chrome = "0.7.1" rstest = { workspace = true } random_word = { version = "0.4.3", features = ["en"] } +# For S3 / DynamoDB tests +aws-config = { workspace = true } +aws-sdk-s3 = { workspace = true } [features] +default = ["aws", "azure", "gcp"] fp16kernels = ["lance-linalg/fp16kernels"] # Prevent dynamic linking of lzma, which comes from datafusion cli = ["clap", "lzma-sys/static"] -tensorflow = ["tfrecord"] +tensorflow = ["tfrecord", "prost_old"] dynamodb = ["lance-table/dynamodb", "aws-sdk-dynamodb"] dynamodb_tests = ["dynamodb"] substrait = ["lance-datafusion/substrait"] +protoc = [ + "lance-encoding/protoc", + "lance-file/protoc", + "lance-index/protoc", + "lance-table/protoc", +] +aws = ["lance-io/aws"] +gcp = ["lance-io/gcp"] +azure = ["lance-io/azure"] [[bin]] name = "lq" diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs index 58c261ccf50..7cf852fcd04 100644 --- a/rust/lance/benches/scalar_index.rs +++ b/rust/lance/benches/scalar_index.rs @@ -10,16 +10,17 @@ use arrow_array::{ use async_trait::async_trait; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::{physical_plan::SendableRecordBatchStream, scalar::ScalarValue}; -use futures::TryStreamExt; +use futures::{FutureExt, TryStreamExt}; use lance::{io::ObjectStore, Dataset}; use lance_core::{cache::FileMetadataCache, Result}; use lance_datafusion::utils::reader_to_stream; use lance_datagen::{array, gen, BatchCount, RowCount}; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::scalar::{ - btree::{train_btree_index, BTreeIndex, TrainingSource}, + btree::{train_btree_index, BTreeIndex, TrainingSource, DEFAULT_BTREE_BATCH_SIZE}, flat::FlatIndexMetadata, lance_format::LanceIndexStore, - IndexStore, SargableQuery, ScalarIndex, + IndexStore, SargableQuery, ScalarIndex, SearchResult, }; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; @@ -60,11 +61,13 @@ impl TrainingSource for BenchmarkDataSource { } impl BenchmarkFixture { - #[allow(dead_code)] fn test_store(tempdir: &TempDir) -> Arc { let test_path = tempdir.path(); let (object_store, test_path) = - ObjectStore::from_path(test_path.as_os_str().to_str().unwrap()).unwrap(); + ObjectStore::from_uri(test_path.as_os_str().to_str().unwrap()) + .now_or_never() + .unwrap() + .unwrap(); Arc::new(LanceIndexStore::new( object_store, test_path, @@ -72,16 +75,6 @@ impl BenchmarkFixture { )) } - fn legacy_test_store(tempdir: &TempDir) -> Arc { - let test_path = tempdir.path(); - let (object_store, test_path) = - ObjectStore::from_path(test_path.as_os_str().to_str().unwrap()).unwrap(); - Arc::new( - LanceIndexStore::new(object_store, test_path, FileMetadataCache::no_cache()) - .with_legacy_format(true), - ) - } - async fn write_baseline_data(tempdir: &TempDir) -> Arc { let test_path = tempdir.path().as_os_str().to_str().unwrap(); Arc::new( @@ -98,6 +91,7 @@ impl BenchmarkFixture { Box::new(BenchmarkDataSource {}), &sub_index_trainer, index_store.as_ref(), + DEFAULT_BTREE_BATCH_SIZE as u32, ) .await .unwrap(); @@ -105,7 +99,7 @@ impl BenchmarkFixture { async fn open() -> Self { let tempdir = tempfile::tempdir().unwrap(); - let index_store = Self::legacy_test_store(&tempdir); + let index_store = Self::test_store(&tempdir); let baseline_dataset = Self::write_baseline_data(&tempdir).await; Self::train_scalar_index(&index_store).await; @@ -135,10 +129,16 @@ async fn baseline_equality_search(fixture: &BenchmarkFixture) { } async fn warm_indexed_equality_search(index: &BTreeIndex) { - let row_ids = index - .search(&SargableQuery::Equals(ScalarValue::UInt32(Some(10000)))) + let result = index + .search( + &SargableQuery::Equals(ScalarValue::UInt32(Some(10000))), + &NoOpMetricsCollector, + ) .await .unwrap(); + let SearchResult::Exact(row_ids) = result else { + panic!("Expected exact results") + }; assert_eq!(row_ids.len(), Some(1)); } @@ -161,27 +161,41 @@ async fn baseline_inequality_search(fixture: &BenchmarkFixture) { } async fn warm_indexed_inequality_search(index: &BTreeIndex) { - let row_ids = index - .search(&SargableQuery::Range( - std::ops::Bound::Included(ScalarValue::UInt32(Some(50_000_000))), - std::ops::Bound::Unbounded, - )) + let result = index + .search( + &SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::UInt32(Some(50_000_000))), + std::ops::Bound::Unbounded, + ), + &NoOpMetricsCollector, + ) .await .unwrap(); + let SearchResult::Exact(row_ids) = result else { + panic!("Expected exact results") + }; + // 100Mi - 50M = 54,857,600 assert_eq!(row_ids.len(), Some(54857600)); } async fn warm_indexed_isin_search(index: &BTreeIndex) { - let row_ids = index - .search(&SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(10000)), - ScalarValue::UInt32(Some(50000000)), - ScalarValue::UInt32(Some(150000000)), // Not found - ScalarValue::UInt32(Some(287123)), - ])) + let result = index + .search( + &SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(10000)), + ScalarValue::UInt32(Some(50000000)), + ScalarValue::UInt32(Some(150000000)), // Not found + ScalarValue::UInt32(Some(287123)), + ]), + &NoOpMetricsCollector, + ) .await .unwrap(); + let SearchResult::Exact(row_ids) = result else { + panic!("Expected exact results") + }; + // Only 3 because 150M is not in dataset assert_eq!(row_ids.len(), Some(3)); } diff --git a/rust/lance/benches/take.rs b/rust/lance/benches/take.rs index f90812b3411..8b4af6402d3 100644 --- a/rust/lance/benches/take.rs +++ b/rust/lance/benches/take.rs @@ -12,15 +12,12 @@ use lance::{ dataset::{builder::DatasetBuilder, ProjectionRequest}, }; use lance_file::version::LanceFileVersion; -use lance_table::io::commit::RenameCommitHandler; -use object_store::ObjectStore; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; use rand::Rng; use std::sync::Arc; #[cfg(target_os = "linux")] use std::time::Duration; -use url::Url; use lance::dataset::{Dataset, WriteMode, WriteParams}; @@ -95,7 +92,7 @@ async fn create_dataset( num_batches: i32, file_size: i32, ) -> Dataset { - let store = create_file( + create_file( std::path::Path::new(path), WriteMode::Create, data_storage_version, @@ -104,15 +101,7 @@ async fn create_dataset( ) .await; - DatasetBuilder::from_uri(path) - .with_object_store( - store, - Url::parse(path).unwrap(), - Arc::new(RenameCommitHandler), - ) - .load() - .await - .unwrap() + DatasetBuilder::from_uri(path).load().await.unwrap() } async fn create_file( @@ -121,7 +110,7 @@ async fn create_file( data_storage_version: LanceFileVersion, num_batches: i32, file_size: i32, -) -> Arc { +) { let schema = Arc::new(ArrowSchema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("f", DataType::Float32, false), @@ -183,10 +172,9 @@ async fn create_file( ..Default::default() }; let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - let ds = Dataset::write(reader, test_uri, Some(write_params)) + Dataset::write(reader, test_uri, Some(write_params)) .await .unwrap(); - ds.object_store.inner.clone() } #[cfg(target_os = "linux")] diff --git a/rust/lance/examples/full_text_search.rs b/rust/lance/examples/full_text_search.rs index 29327426fbb..135861aa80d 100644 --- a/rust/lance/examples/full_text_search.rs +++ b/rust/lance/examples/full_text_search.rs @@ -74,8 +74,8 @@ async fn main() { } let dataset = Dataset::open(dataset_dir.as_ref()).await.unwrap(); - let query = tokens[0]; - let query = FullTextSearchQuery::new(query.to_owned()).limit(Some(10)); + let query_string = tokens[0]; + let query = FullTextSearchQuery::new(query_string.to_owned()).limit(Some(10)); println!("query: {:?}", query); let batch = dataset .scan() @@ -108,7 +108,7 @@ async fn main() { .try_into_batch() .await .unwrap(); - let flat_results = flat_full_text_search(&[&batch], "doc", &query.query, None) + let flat_results = flat_full_text_search(&[&batch], "doc", query_string, None) .unwrap() .into_iter() .collect::>(); diff --git a/rust/lance/examples/hnsw.rs b/rust/lance/examples/hnsw.rs index 414038167fa..9c8b9d558ae 100644 --- a/rust/lance/examples/hnsw.rs +++ b/rust/lance/examples/hnsw.rs @@ -16,7 +16,7 @@ use futures::StreamExt; use lance::Dataset; use lance_index::vector::v3::subindex::IvfSubIndex; use lance_index::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, hnsw::{builder::HnswBuildParams, HNSW}, }; use lance_linalg::distance::DistanceType; @@ -79,7 +79,7 @@ async fn main() { let fsl = concat(&arrs).unwrap().as_fixed_size_list().clone(); println!("Loaded {:?} batches", fsl.len()); - let vector_store = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let vector_store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let q = fsl.value(0); let k = 10; diff --git a/rust/lance/src/arrow/json.rs b/rust/lance/src/arrow/json.rs index 2a0612ed578..49165912742 100644 --- a/rust/lance/src/arrow/json.rs +++ b/rust/lance/src/arrow/json.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::sync::Arc; -use snafu::{location, Location}; +use snafu::location; use arrow_schema::{DataType, Field, Schema}; use serde::{Deserialize, Serialize}; @@ -82,6 +82,13 @@ impl TryFrom<&DataType> for JsonDataType { length: Some(*len as usize), }); } + DataType::FixedSizeBinary(len) => { + return Ok(Self { + type_: "fixed_size_binary".to_string(), + fields: None, + length: Some(*len as usize), + }); + } DataType::Struct(fields) => { let fields = fields .iter() @@ -157,6 +164,13 @@ impl TryFrom<&JsonDataType> for DataType { _ => unreachable!(), } } + "fixed_size_binary" => { + let length = value.length.ok_or_else(|| Error::Arrow { + message: "Json conversion: FixedSizeBinary type requires a length".to_string(), + location: location!(), + })?; + Ok(Self::FixedSizeBinary(length as i32)) + } _ => Err(Error::Arrow { message: format!("Json conversion: Unsupported type: {value:?}"), location: location!(), @@ -172,6 +186,9 @@ pub struct JsonField { #[serde(rename = "type")] type_: JsonDataType, nullable: bool, + + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option>, } impl TryFrom<&Field> for JsonField { @@ -180,10 +197,17 @@ impl TryFrom<&Field> for JsonField { fn try_from(field: &Field) -> Result { let data_type = JsonDataType::try_new(field.data_type())?; + let metadata = if field.metadata().is_empty() { + None + } else { + Some(field.metadata().clone()) + }; + Ok(Self { name: field.name().to_string(), nullable: field.is_nullable(), type_: data_type, + metadata, }) } } @@ -193,7 +217,11 @@ impl TryFrom<&JsonField> for Field { fn try_from(value: &JsonField) -> Result { let data_type = DataType::try_from(&value.type_)?; - Ok(Self::new(&value.name, data_type, value.nullable)) + let mut field = Self::new(&value.name, data_type, value.nullable); + if let Some(metadata) = value.metadata.clone() { + field.set_metadata(metadata); + } + Ok(field) } } @@ -361,6 +389,14 @@ mod test { ), ); + assert_type_json_str( + DataType::FixedSizeBinary(32), + json!({ + "type": "fixed_size_binary", + "length": 32 + }), + ); + assert_type_json_str( DataType::Struct( vec![ @@ -445,4 +481,55 @@ mod test { let actual = Schema::from_json(&json_str).unwrap(); assert_eq!(schema, actual); } + + #[test] + fn test_metadata_roundtrip() { + let mut schema_metadata = HashMap::new(); + schema_metadata.insert("sk_1".to_string(), "sv_1".to_string()); + + let mut field1_metadata = HashMap::new(); + field1_metadata.insert("fk_1".to_string(), "fv_1".to_string()); + + let field1 = Field::new("a", DataType::UInt8, false).with_metadata(field1_metadata.clone()); + let field2 = Field::new("b", DataType::Int32, true); + + let schema = Schema::new_with_metadata(vec![field1, field2], schema_metadata.clone()); + + let json_str = schema.to_json().unwrap(); + assert_eq!( + serde_json::from_str::(&json_str).unwrap(), + json!({ + "fields": [ + { + "name": "a", + "type": { + "type": "uint8" + }, + "nullable": false, + "metadata": { + "fk_1": "fv_1" + } + }, + { + "name": "b", + "type": { + "type": "int32" + }, + "nullable": true + } + ], + "metadata": { + "sk_1": "sv_1" + } + }) + ); + + let actual = Schema::from_json(&json_str).unwrap(); + assert_eq!(schema, actual); + + assert_eq!(actual.metadata, schema_metadata); + + assert_eq!(actual.field(0).metadata(), &field1_metadata); + assert_eq!(actual.field(1).metadata(), &HashMap::new()); + } } diff --git a/rust/lance/src/bin/lq.rs b/rust/lance/src/bin/lq.rs index 9cbe0fac334..2615d5e6085 100644 --- a/rust/lance/src/bin/lq.rs +++ b/rust/lance/src/bin/lq.rs @@ -8,7 +8,7 @@ use arrow_array::RecordBatch; use clap::{Parser, Subcommand, ValueEnum}; use futures::stream::StreamExt; use futures::TryStreamExt; -use snafu::{location, Location}; +use snafu::location; use lance::dataset::Dataset; use lance::index::vector::VectorIndexParams; diff --git a/rust/lance/src/datafusion/dataframe.rs b/rust/lance/src/datafusion/dataframe.rs index e5b12b006c0..c1ede20b51e 100644 --- a/rust/lance/src/datafusion/dataframe.rs +++ b/rust/lance/src/datafusion/dataframe.rs @@ -9,9 +9,9 @@ use std::{ use arrow_schema::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion::{ - catalog::Session, + catalog::{streaming::StreamingTable, Session}, dataframe::DataFrame, - datasource::{streaming::StreamingTable, TableProvider}, + datasource::TableProvider, error::DataFusionError, execution::{context::SessionContext, TaskContext}, logical_expr::{Expr, TableProviderFilterPushDown, TableType}, @@ -22,6 +22,7 @@ use lance_core::{ROW_ADDR_FIELD, ROW_ID_FIELD}; use crate::Dataset; +#[derive(Debug)] pub struct LanceTableProvider { dataset: Arc, full_schema: Arc, @@ -153,6 +154,14 @@ impl OneShotPartitionStream { } } +impl std::fmt::Debug for OneShotPartitionStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OneShotPartitionStream") + .field("schema", &self.schema) + .finish() + } +} + impl PartitionStream for OneShotPartitionStream { fn schema(&self) -> &SchemaRef { &self.schema diff --git a/rust/lance/src/datafusion/logical_plan.rs b/rust/lance/src/datafusion/logical_plan.rs index b45bdedbe2b..9c20a3d43c9 100644 --- a/rust/lance/src/datafusion/logical_plan.rs +++ b/rust/lance/src/datafusion/logical_plan.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{any::Any, sync::Arc}; +use std::{any::Any, borrow::Cow, sync::Arc}; use arrow_schema::Schema as ArrowSchema; use async_trait::async_trait; @@ -13,6 +13,7 @@ use datafusion::{ physical_plan::ExecutionPlan, prelude::Expr, }; +use lance_core::datatypes::{OnMissing, OnTypeMismatch}; use crate::Dataset; @@ -34,7 +35,7 @@ impl TableProvider for Dataset { None } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { + fn get_logical_plan(&self) -> Option> { None } @@ -52,7 +53,11 @@ impl TableProvider for Dataset { if projection.len() != schema_ref.fields.len() { let arrow_schema: ArrowSchema = schema_ref.into(); let arrow_schema = arrow_schema.project(projection)?; - schema_ref.project_by_schema(&arrow_schema)? + schema_ref.project_by_schema( + &arrow_schema, + OnMissing::Error, + OnTypeMismatch::Error, + )? } else { schema_ref.clone() } diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index a4271f9d758..18ae38ac7e0 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -12,21 +12,25 @@ use futures::future::BoxFuture; use futures::stream::{self, StreamExt, TryStreamExt}; use futures::{FutureExt, Stream}; use itertools::Itertools; +use lance_core::datatypes::{OnMissing, OnTypeMismatch, Projectable, Projection}; use lance_core::traits::DatasetTakeRows; use lance_core::utils::address::RowAddress; use lance_core::utils::tokio::get_num_compute_intensive_cpus; +use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_MANIFEST, TRACE_FILE_AUDIT}; +use lance_core::ROW_ADDR; use lance_datafusion::projection::ProjectionPlan; use lance_file::datatypes::populate_schema_dictionary; use lance_file::version::LanceFileVersion; -use lance_io::object_store::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry}; -use lance_io::object_writer::ObjectWriter; +use lance_index::DatasetIndexExt; +use lance_io::object_store::{ObjectStore, ObjectStoreParams}; +use lance_io::object_writer::{ObjectWriter, WriteResult}; use lance_io::traits::WriteExt; use lance_io::utils::{read_last_block, read_metadata_offset, read_struct}; use lance_table::format::{ DataStorageFormat, Fragment, Index, Manifest, MAGIC, MAJOR_VERSION, MINOR_VERSION, }; use lance_table::io::commit::{ - migrate_scheme_to_v2, CommitError, CommitHandler, CommitLock, ManifestLocation, + migrate_scheme_to_v2, CommitConfig, CommitError, CommitHandler, CommitLock, ManifestLocation, ManifestNamingScheme, }; use lance_table::io::manifest::{read_manifest, write_manifest}; @@ -34,13 +38,13 @@ use object_store::path::Path; use prost::Message; use rowids::get_row_id_index; use serde::{Deserialize, Serialize}; -use snafu::{location, Location}; +use snafu::location; use std::borrow::Cow; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; -use tracing::instrument; +use tracing::{info, instrument}; mod blob; pub mod builder; @@ -54,6 +58,7 @@ pub mod refs; pub(crate) mod rowids; pub mod scanner; mod schema_evolution; +pub mod statistics; mod take; pub mod transaction; pub mod updater; @@ -69,7 +74,10 @@ use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; use crate::datatypes::Schema; use crate::error::box_error; -use crate::io::commit::{commit_detached_transaction, commit_new_dataset, commit_transaction}; +use crate::io::commit::{ + commit_detached_transaction, commit_new_dataset, commit_transaction, + detect_overlapping_fragments, +}; use crate::session::Session; use crate::utils::temporal::{timestamp_to_nanos, utc_now, SystemTime}; use crate::{Error, Result}; @@ -82,7 +90,8 @@ pub use schema_evolution::{ }; pub use take::TakeBuilder; pub use write::merge_insert::{ - MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, + MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched, + WhenNotMatchedBySource, }; pub use write::update::{UpdateBuilder, UpdateJob}; #[allow(deprecated)] @@ -98,7 +107,7 @@ pub(crate) const DEFAULT_INDEX_CACHE_SIZE: usize = 256; pub(crate) const DEFAULT_METADATA_CACHE_SIZE: usize = 256; /// Lance Dataset -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct Dataset { pub object_store: Arc, pub(crate) commit_handler: Arc, @@ -115,6 +124,18 @@ pub struct Dataset { pub(crate) session: Arc, pub tags: Tags, pub manifest_naming_scheme: ManifestNamingScheme, + pub manifest_e_tag: Option, +} + +impl std::fmt::Debug for Dataset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Dataset") + .field("uri", &self.uri) + .field("base", &self.base) + .field("version", &self.manifest.version) + .field("cache_num_items", &self.session.approx_num_items()) + .finish() + } } /// Dataset Version @@ -176,8 +197,6 @@ pub struct ReadParams { /// If a custom object store is provided (via store_params.object_store) then this /// must also be provided. pub commit_handler: Option>, - - pub object_store_registry: Arc, } impl ReadParams { @@ -199,15 +218,6 @@ impl ReadParams { self } - /// Provide an object store registry for custom object stores - pub fn with_object_store_registry( - &mut self, - object_store_registry: Arc, - ) -> &mut Self { - self.object_store_registry = object_store_registry; - self - } - /// Use the explicit locking to resolve the latest version pub fn set_commit_lock(&mut self, lock: Arc) { self.commit_handler = Some(Arc::new(lock)); @@ -222,7 +232,6 @@ impl Default for ReadParams { session: None, store_options: None, commit_handler: None, - object_store_registry: Arc::new(ObjectStoreRegistry::default()), } } } @@ -254,7 +263,7 @@ impl ProjectionRequest { /// /// # Parameters /// - `columns`: A list of tuples where the first element is resulted column name and the second - /// element is the SQL expression. + /// element is the SQL expression. pub fn from_sql( columns: impl IntoIterator, impl Into)>, ) -> Self { @@ -269,7 +278,11 @@ impl ProjectionRequest { pub fn into_projection_plan(self, dataset_schema: &Schema) -> Result { match self { Self::Schema(schema) => Ok(ProjectionPlan::new_empty( - Arc::new(dataset_schema.project_by_schema(schema.as_ref())?), + Arc::new(dataset_schema.project_by_schema( + schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?), /*load_blobs=*/ false, )), Self::Sql(columns) => { @@ -317,12 +330,28 @@ impl Dataset { Ok(()) } + fn already_checked_out(&self, location: &ManifestLocation) -> bool { + // We check the e_tag here just in case it has been overwritten. This can + // happen if the table has been dropped then re-created recently. + self.manifest.version == location.version + && location.e_tag.as_ref().is_some_and(|e_tag| { + self.manifest_e_tag + .as_ref() + .is_some_and(|current_e_tag| e_tag == current_e_tag) + }) + } + async fn checkout_by_version_number(&self, version: u64) -> Result { let base_path = self.base.clone(); let manifest_location = self .commit_handler .resolve_version_location(&base_path, version, &self.object_store.inner) .await?; + + if self.already_checked_out(&manifest_location) { + return Ok(self.clone()); + } + let manifest = Self::load_manifest(self.object_store.as_ref(), &manifest_location).await?; Self::checkout_manifest( self.object_store.clone(), @@ -333,6 +362,7 @@ impl Dataset { self.session.clone(), self.commit_handler.clone(), manifest_location.naming_scheme, + manifest_location.e_tag, ) .await } @@ -418,6 +448,7 @@ impl Dataset { session: Arc, commit_handler: Arc, manifest_naming_scheme: ManifestNamingScheme, + e_tag: Option, ) -> Result { let tags = Tags::new( object_store.clone(), @@ -434,6 +465,7 @@ impl Dataset { session, tags, manifest_naming_scheme, + manifest_e_tag: e_tag, }) } @@ -514,6 +546,7 @@ impl Dataset { self.session.clone(), self.commit_handler.clone(), ManifestNamingScheme::V2, + blob_manifest_location.e_tag, ) .await?; Ok(Some(Arc::new(blobs_dataset))) @@ -535,7 +568,8 @@ impl Dataset { .commit_handler .resolve_latest_location(&self.base, &self.object_store) .await?; - if location.version == self.manifest.version { + + if self.already_checked_out(&location) { return Ok((self.manifest.as_ref().clone(), self.manifest_file.clone())); } let mut manifest = read_manifest(&self.object_store, &location.path, location.size).await?; @@ -580,19 +614,8 @@ impl Dataset { None, ); - let (restored_manifest, path) = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; - - self.manifest = Arc::new(restored_manifest); - self.manifest_file = path; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -617,6 +640,7 @@ impl Dataset { /// # Returns /// /// * `RemovalStats` - Statistics about the removal operation + #[instrument(level = "debug", skip(self))] pub fn cleanup_old_versions( &self, older_than: Duration, @@ -641,7 +665,7 @@ impl Dataset { read_version: Option, store_params: Option, commit_handler: Option>, - object_store_registry: Arc, + session: Arc, enable_v2_manifest_paths: bool, detached: bool, ) -> Result { @@ -659,8 +683,8 @@ impl Dataset { let transaction = Transaction::new(read_version, operation, blobs_op, None); let mut builder = CommitBuilder::new(base_uri) - .with_object_store_registry(object_store_registry) .enable_v2_manifest_paths(enable_v2_manifest_paths) + .with_session(session) .with_detached(detached); if let Some(store_params) = store_params { @@ -714,7 +738,7 @@ impl Dataset { read_version: Option, store_params: Option, commit_handler: Option>, - object_store_registry: Arc, + session: Arc, enable_v2_manifest_paths: bool, ) -> Result { Self::do_commit( @@ -726,7 +750,7 @@ impl Dataset { read_version, store_params, commit_handler, - object_store_registry, + session, enable_v2_manifest_paths, /*detached=*/ false, ) @@ -747,7 +771,7 @@ impl Dataset { read_version: Option, store_params: Option, commit_handler: Option>, - object_store_registry: Arc, + session: Arc, enable_v2_manifest_paths: bool, ) -> Result { Self::do_commit( @@ -759,13 +783,37 @@ impl Dataset { read_version, store_params, commit_handler, - object_store_registry, + session, enable_v2_manifest_paths, /*detached=*/ true, ) .await } + pub(crate) async fn apply_commit( + &mut self, + transaction: Transaction, + write_config: &ManifestWriteConfig, + commit_config: &CommitConfig, + ) -> Result<()> { + let (manifest, manifest_path, manifest_e_tag) = commit_transaction( + self, + self.object_store(), + self.commit_handler.as_ref(), + &transaction, + write_config, + commit_config, + self.manifest_naming_scheme, + ) + .await?; + + self.manifest = Arc::new(manifest); + self.manifest_file = manifest_path; + self.manifest_e_tag = manifest_e_tag; + + Ok(()) + } + /// Create a Scanner to scan the dataset. pub fn scan(&self) -> Scanner { Scanner::new(Arc::new(self.clone())) @@ -792,7 +840,7 @@ impl Dataset { pub(crate) async fn count_all_rows(&self) -> Result { let cnts = stream::iter(self.get_fragments()) - .map(|f| async move { f.count_rows().await }) + .map(|f| async move { f.count_rows(None).await }) .buffer_unordered(16) .try_collect::>() .await?; @@ -928,19 +976,8 @@ impl Dataset { None, ); - let (manifest, path) = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; - - self.manifest = Arc::new(manifest); - self.manifest_file = path; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -995,10 +1032,10 @@ impl Dataset { pub async fn versions(&self) -> Result> { let mut versions: Vec = self .commit_handler - .list_manifests(&self.base, &self.object_store.inner) + .list_manifest_locations(&self.base, &self.object_store.inner) .await? - .try_filter_map(|path| async move { - match read_manifest(&self.object_store, &path, None).await { + .try_filter_map(|location| async move { + match read_manifest(&self.object_store, &location.path, location.size).await { Ok(manifest) => Ok(Some(Version::from(&manifest))), Err(e) => Err(e), } @@ -1016,9 +1053,11 @@ impl Dataset { /// This is meant to be a fast path for checking if a dataset has changed. This is why /// we don't return the full version struct. pub async fn latest_version_id(&self) -> Result { - self.commit_handler - .resolve_latest_version_id(&self.base, &self.object_store) - .await + Ok(self + .commit_handler + .resolve_latest_location(&self.base, &self.object_store) + .await? + .version) } pub fn count_fragments(&self) -> usize { @@ -1035,6 +1074,11 @@ impl Dataset { &self.manifest.local_schema } + /// Creates a new empty projection into the dataset schema + pub fn empty_projection(self: &Arc) -> Projection { + Projection::empty(self.clone()) + } + /// Get fragments. /// /// If `filter` is provided, only fragments with the given name will be returned. @@ -1288,6 +1332,45 @@ impl Dataset { .try_collect::>() .await?; + // Validate indices + let indices = self.load_indices().await?; + self.validate_indices(&indices)?; + + Ok(()) + } + + fn validate_indices(&self, indices: &[Index]) -> Result<()> { + // Make sure there are no duplicate ids + let mut index_ids = HashSet::new(); + for index in indices.iter() { + if !index_ids.insert(&index.uuid) { + return Err(Error::corrupt_file( + self.manifest_file.clone(), + format!( + "Duplicate index id {} found in dataset {:?}", + &index.uuid, self.base + ), + location!(), + )); + } + } + + // For each index name, make sure there is no overlap in fragment bitmaps + if let Err(err) = detect_overlapping_fragments(indices) { + let mut message = "Overlapping fragments detected in dataset.".to_string(); + for (index_name, overlapping_frags) in err.bad_indices { + message.push_str(&format!( + "\nIndex {:?} has overlapping fragments: {:?}", + index_name, overlapping_frags + )); + } + return Err(Error::corrupt_file( + self.manifest_file.clone(), + message, + location!(), + )); + }; + Ok(()) } @@ -1337,7 +1420,7 @@ impl Dataset { /// - [Self::add_columns()]: Add new columns to the dataset, similar to `ALTER TABLE ADD COLUMN`. /// - [Self::drop_columns()]: Drop columns from the dataset, similar to `ALTER TABLE DROP COLUMN`. /// - [Self::alter_columns()]: Modify columns in the dataset, changing their name, type, or nullability. -/// Similar to `ALTER TABLE ALTER COLUMN`. +/// Similar to `ALTER TABLE ALTER COLUMN`. /// /// In addition, one operation is unique to Lance: [`merge`](Self::merge). This /// operation allows inserting precomputed data into the dataset. @@ -1395,7 +1478,7 @@ impl Dataset { right_on: &str, ) -> Result<()> { // Sanity check. - if self.schema().field(left_on).is_none() { + if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR { return Err(Error::invalid_input( format!("Column {} does not exist in the left side dataset", left_on), location!(), @@ -1457,19 +1540,8 @@ impl Dataset { None, ); - let (manifest, manifest_path) = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; - - self.manifest = Arc::new(manifest); - self.manifest_file = manifest_path; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -1495,65 +1567,68 @@ impl Dataset { self.merge_impl(stream, left_on, right_on).await } + async fn update_op(&mut self, op: Operation) -> Result<()> { + let transaction = + Transaction::new(self.manifest.version, op, /*blobs_op=*/ None, None); + + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; + + Ok(()) + } + /// Update key-value pairs in config. pub async fn update_config( &mut self, upsert_values: impl IntoIterator, ) -> Result<()> { - let transaction = Transaction::new( - self.manifest.version, - Operation::UpdateConfig { - upsert_values: Some(HashMap::from_iter(upsert_values)), - delete_keys: None, - }, - /*blobs_op=*/ None, - None, - ); - - let (manifest, manifest_path) = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; - - self.manifest = Arc::new(manifest); - self.manifest_file = manifest_path; - - Ok(()) + self.update_op(Operation::UpdateConfig { + upsert_values: Some(HashMap::from_iter(upsert_values)), + delete_keys: None, + schema_metadata: None, + field_metadata: None, + }) + .await } /// Delete keys from the config. pub async fn delete_config_keys(&mut self, delete_keys: &[&str]) -> Result<()> { - let transaction = Transaction::new( - self.manifest.version, - Operation::UpdateConfig { - upsert_values: None, - delete_keys: Some(Vec::from_iter(delete_keys.iter().map(ToString::to_string))), - }, - /*blob_op=*/ None, - None, - ); - - let (manifest, manifest_path) = commit_transaction( - self, - &self.object_store, - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; + self.update_op(Operation::UpdateConfig { + upsert_values: None, + delete_keys: Some(Vec::from_iter(delete_keys.iter().map(ToString::to_string))), + schema_metadata: None, + field_metadata: None, + }) + .await + } - self.manifest = Arc::new(manifest); - self.manifest_file = manifest_path; + /// Update schema metadata + pub async fn replace_schema_metadata( + &mut self, + upsert_values: impl IntoIterator, + ) -> Result<()> { + self.update_op(Operation::UpdateConfig { + upsert_values: None, + delete_keys: None, + schema_metadata: Some(HashMap::from_iter(upsert_values)), + field_metadata: None, + }) + .await + } - Ok(()) + /// Update field metadata + pub async fn replace_field_metadata( + &mut self, + new_values: impl IntoIterator)>, + ) -> Result<()> { + let new_values = new_values.into_iter().collect::>(); + self.update_op(Operation::UpdateConfig { + upsert_values: None, + delete_keys: None, + schema_metadata: None, + field_metadata: Some(new_values), + }) + .await } } @@ -1598,7 +1673,7 @@ pub(crate) async fn write_manifest_file( indices: Option>, config: &ManifestWriteConfig, naming_scheme: ManifestNamingScheme, -) -> std::result::Result { +) -> std::result::Result { if config.auto_set_feature_flags { apply_feature_flags(manifest, config.use_move_stable_row_ids)?; } @@ -1624,18 +1699,25 @@ fn write_manifest_file_to_path<'a>( manifest: &'a mut Manifest, indices: Option>, path: &'a Path, -) -> BoxFuture<'a, Result<()>> { +) -> BoxFuture<'a, Result> { Box::pin(async { let mut object_writer = ObjectWriter::new(object_store, path).await?; let pos = write_manifest(&mut object_writer, manifest, indices).await?; object_writer .write_magics(pos, MAJOR_VERSION, MINOR_VERSION, MAGIC) .await?; - object_writer.shutdown().await?; - Ok(()) + let res = object_writer.shutdown().await?; + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, type=AUDIT_TYPE_MANIFEST, path = path.to_string()); + Ok(res) }) } +impl Projectable for Dataset { + fn schema(&self) -> &Schema { + self.schema() + } +} + #[cfg(test)] mod tests { use std::vec; @@ -1643,12 +1725,14 @@ mod tests { use super::*; use crate::arrow::FixedSizeListArrayExt; use crate::dataset::optimize::{compact_files, CompactionOptions}; + use crate::dataset::transaction::DataReplacementGroup; use crate::dataset::WriteMode::Overwrite; use crate::index::vector::VectorIndexParams; use crate::utils::test::TestDatasetGenerator; - use arrow::array::as_struct_array; + use arrow::array::{as_struct_array, AsArray, GenericListBuilder, GenericStringBuilder}; use arrow::compute::concat_batches; + use arrow::datatypes::UInt64Type; use arrow_array::{ builder::StringDictionaryBuilder, cast::as_string_array, @@ -1658,7 +1742,7 @@ mod tests { }; use arrow_array::{ Array, FixedSizeListArray, GenericStringArray, Int16Array, Int16DictionaryArray, - StructArray, + StructArray, UInt64Array, }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{ @@ -1667,19 +1751,22 @@ mod tests { use lance_arrow::bfloat16::{self, ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY, BFLOAT16_EXT_NAME}; use lance_core::datatypes::LANCE_STORAGE_CLASS_SCHEMA_META_KEY; use lance_datagen::{array, gen, BatchCount, Dimension, RowCount}; + use lance_file::v2::writer::FileWriter; use lance_file::version::LanceFileVersion; + use lance_index::scalar::inverted::query::{MatchQuery, Operator, PhraseQuery}; + use lance_index::scalar::inverted::TokenizerConfig; use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams}; use lance_index::{scalar::ScalarIndexParams, vector::DIST_COL, DatasetIndexExt, IndexType}; use lance_linalg::distance::MetricType; use lance_table::feature_flags; - use lance_table::format::WriterVersion; - use lance_table::io::commit::RenameCommitHandler; + use lance_table::format::{DataFile, WriterVersion}; + use lance_table::io::deletion::read_deletion_file; use lance_testing::datagen::generate_random_array; use pretty_assertions::assert_eq; + use rand::seq::SliceRandom; use rstest::rstest; use tempfile::{tempdir, TempDir}; - use url::Url; // Used to validate that futures returned are Send. fn require_send(t: T) -> T { @@ -1919,6 +2006,8 @@ mod tests { // Need to use in-memory for accurate IOPS tracking. use crate::utils::test::IoTrackingStore; + // Use consistent session so memory store can be reused. + let session = Arc::new(Session::default()); let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "i", DataType::Int32, @@ -1930,26 +2019,33 @@ mod tests { ) .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); - let dataset = Dataset::write(batches, "memory://test", None) - .await - .unwrap(); - - // Then open with wrapping store. - let memory_store = dataset.object_store.inner.clone(); let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + let _original_ds = Dataset::write( + batches, + "memory://test", + Some(WriteParams { + store_params: Some(ObjectStoreParams { + object_store_wrapper: Some(io_stats_wrapper.clone()), + ..Default::default() + }), + session: Some(session.clone()), + ..Default::default() + }), + ) + .await + .unwrap(); + + io_stats.lock().unwrap().read_iops = 0; + let _dataset = DatasetBuilder::from_uri("memory://test") .with_read_params(ReadParams { store_options: Some(ObjectStoreParams { object_store_wrapper: Some(io_stats_wrapper), ..Default::default() }), + session: Some(session), ..Default::default() }) - .with_object_store( - memory_store, - Url::parse("memory://test").unwrap(), - Arc::new(RenameCommitHandler), - ) .load() .await .unwrap(); @@ -2004,9 +2100,9 @@ mod tests { assert_eq!(fragments.len(), 10); assert_eq!(dataset.count_fragments(), 10); for fragment in &fragments { - assert_eq!(fragment.count_rows().await.unwrap(), 100); + assert_eq!(fragment.count_rows(None).await.unwrap(), 100); let reader = fragment - .open(dataset.schema(), FragReadConfig::default(), None) + .open(dataset.schema(), FragReadConfig::default()) .await .unwrap(); // No group / batch concept in v2 @@ -2047,6 +2143,7 @@ mod tests { test_uri, Some(WriteParams { data_storage_version: Some(data_storage_version), + auto_cleanup: None, ..Default::default() }), ); @@ -2058,9 +2155,10 @@ mod tests { dataset.object_store(), &dataset .commit_handler - .resolve_latest_version(&dataset.base, dataset.object_store()) + .resolve_latest_location(&dataset.base, dataset.object_store()) .await - .unwrap(), + .unwrap() + .path, None, ) .await @@ -2081,9 +2179,10 @@ mod tests { dataset.object_store(), &dataset .commit_handler - .resolve_latest_version(&dataset.base, dataset.object_store()) + .resolve_latest_location(&dataset.base, dataset.object_store()) .await - .unwrap(), + .unwrap() + .path, None, ) .await @@ -2798,45 +2897,163 @@ mod tests { assert_eq!(batch.num_rows(), 0); } + #[rstest] #[tokio::test] - async fn test_create_fts_index_with_empty_strings() { + async fn test_create_int8_index( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, + ) { + use lance_testing::datagen::generate_random_int8_array; + let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); + let dimension = 16; let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( - "text", - DataType::Utf8, + "embeddings", + DataType::FixedSizeList( + Arc::new(ArrowField::new("item", DataType::Int8, true)), + dimension, + ), false, )])); - let batches: Vec = vec![RecordBatch::try_new( - schema.clone(), - vec![Arc::new(StringArray::from(vec!["", "", ""]))], - ) - .unwrap()]; + let int8_arr = generate_random_int8_array(512 * dimension as usize); + let vectors = Arc::new( + ::try_new_from_values( + int8_arr, dimension, + ) + .unwrap(), + ); + let batches = vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]; + + let test_uri = test_dir.path().to_str().unwrap(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); - let mut dataset = Dataset::write(reader, test_uri, None) - .await - .expect("write dataset"); - let params = InvertedIndexParams::default(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + // Make sure valid arguments should create index successfully + let params = VectorIndexParams::ivf_pq(10, 8, 2, MetricType::L2, 50); dataset - .create_index(&["text"], IndexType::Inverted, None, ¶ms, true) + .create_index(&["embeddings"], IndexType::Vector, None, ¶ms, true) .await .unwrap(); + dataset.validate().await.unwrap(); - let batch = dataset - .scan() - .full_text_search(FullTextSearchQuery::new("lance".to_owned())) - .unwrap() - .try_into_batch() + // The version should match the table version it was created from. + let indices = dataset.load_indices().await.unwrap(); + let actual = indices.first().unwrap().dataset_version; + let expected = dataset.manifest.version - 1; + assert_eq!(actual, expected); + let fragment_bitmap = indices.first().unwrap().fragment_bitmap.as_ref().unwrap(); + assert_eq!(fragment_bitmap.len(), 1); + assert!(fragment_bitmap.contains(0)); + + // Append should inherit index + let write_params = WriteParams { + mode: WriteMode::Append, + data_storage_version: Some(data_storage_version), + ..Default::default() + }; + let batches = vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]; + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write(reader, test_uri, Some(write_params)) .await .unwrap(); - assert_eq!(batch.num_rows(), 0); - } + let indices = dataset.load_indices().await.unwrap(); + let actual = indices.first().unwrap().dataset_version; + let expected = dataset.manifest.version - 2; + assert_eq!(actual, expected); + dataset.validate().await.unwrap(); + // Fragment bitmap should show the original fragments, and not include + // the newly appended fragment. + let fragment_bitmap = indices.first().unwrap().fragment_bitmap.as_ref().unwrap(); + assert_eq!(fragment_bitmap.len(), 1); + assert!(fragment_bitmap.contains(0)); - #[rstest] - #[tokio::test] + let actual_statistics: serde_json::Value = + serde_json::from_str(&dataset.index_statistics("embeddings_idx").await.unwrap()) + .unwrap(); + let actual_statistics = actual_statistics.as_object().unwrap(); + assert_eq!(actual_statistics["index_type"].as_str().unwrap(), "IVF_PQ"); + + let deltas = actual_statistics["indices"].as_array().unwrap(); + assert_eq!(deltas.len(), 1); + assert_eq!(deltas[0]["metric_type"].as_str().unwrap(), "l2"); + assert_eq!(deltas[0]["num_partitions"].as_i64().unwrap(), 10); + + assert!(dataset.index_statistics("non-existent_idx").await.is_err()); + assert!(dataset.index_statistics("").await.is_err()); + + // Overwrite should invalidate index + let write_params = WriteParams { + mode: WriteMode::Overwrite, + data_storage_version: Some(data_storage_version), + ..Default::default() + }; + let batches = vec![RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap()]; + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + let dataset = Dataset::write(reader, test_uri, Some(write_params)) + .await + .unwrap(); + assert!(dataset.manifest.index_section.is_none()); + assert!(dataset.load_indices().await.unwrap().is_empty()); + dataset.validate().await.unwrap(); + + let fragment_bitmap = indices.first().unwrap().fragment_bitmap.as_ref().unwrap(); + assert_eq!(fragment_bitmap.len(), 1); + assert!(fragment_bitmap.contains(0)); + } + + #[tokio::test] + async fn test_create_fts_index_with_empty_strings() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "text", + DataType::Utf8, + false, + )])); + + let batches: Vec = vec![RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec!["", "", ""]))], + ) + .unwrap()]; + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone()); + let mut dataset = Dataset::write(reader, test_uri, None) + .await + .expect("write dataset"); + + let params = InvertedIndexParams::default(); + dataset + .create_index(&["text"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + + let batch = dataset + .scan() + .full_text_search(FullTextSearchQuery::new("lance".to_owned())) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(batch.num_rows(), 0); + } + + #[rstest] + #[tokio::test] async fn test_bad_field_name( #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, @@ -2939,7 +3156,7 @@ mod tests { None, None, None, - Arc::new(ObjectStoreRegistry::default()), + Default::default(), true, // enable_v2_manifest_paths ) .await @@ -2950,6 +3167,41 @@ mod tests { assert_all_manifests_use_scheme(&test_dir, ManifestNamingScheme::V2); } + #[tokio::test] + async fn test_strict_overwrite() { + let schema = Schema::try_from(&ArrowSchema::new(vec![ArrowField::new( + "x", + DataType::Int32, + false, + )])) + .unwrap(); + let operation = Operation::Overwrite { + fragments: vec![], + schema, + config_upsert_values: None, + }; + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let read_version_0_transaction = Transaction::new(0, operation, None, None); + let strict_builder = CommitBuilder::new(test_uri).with_max_retries(0); + let unstrict_builder = CommitBuilder::new(test_uri).with_max_retries(1); + strict_builder + .clone() + .execute(read_version_0_transaction.clone()) + .await + .expect("Strict overwrite should succeed when writing a new dataset"); + strict_builder + .clone() + .execute(read_version_0_transaction.clone()) + .await + .expect_err("Strict overwrite should fail when committing to a stale version"); + unstrict_builder + .clone() + .execute(read_version_0_transaction.clone()) + .await + .expect("Unstrict overwrite should succeed when committing to a stale version"); + } + #[rstest] #[tokio::test] async fn test_merge( @@ -3117,6 +3369,137 @@ mod tests { dataset.validate().await.unwrap(); } + #[rstest] + #[tokio::test] + async fn test_merge_on_row_id( + #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, + #[values(false, true)] use_stable_row_id: bool, + ) { + // Tests a merge on _rowid + + let data = lance_datagen::gen() + .col("key", array::step::()) + .col("value", array::fill_utf8("value".to_string())) + .into_reader_rows(RowCount::from(1_000), BatchCount::from(10)); + + let write_params = WriteParams { + mode: WriteMode::Append, + data_storage_version: Some(data_storage_version), + max_rows_per_file: 1024, + max_rows_per_group: 150, + enable_move_stable_row_ids: use_stable_row_id, + ..Default::default() + }; + let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone())) + .await + .unwrap(); + assert_eq!(dataset.fragments().len(), 10); + assert_eq!(dataset.manifest.max_fragment_id(), Some(9)); + + let data = dataset.scan().with_row_id().try_into_batch().await.unwrap(); + let row_ids: Arc = data[ROW_ID].clone(); + let key = data["key"].as_primitive::(); + let new_schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("rowid", DataType::UInt64, false), + ArrowField::new("new_value", DataType::Int32, false), + ])); + let new_value = Arc::new( + key.into_iter() + .map(|v| v.unwrap() + 1) + .collect::(), + ); + let len = new_value.len() as u32; + let new_batch = RecordBatch::try_new(new_schema.clone(), vec![row_ids, new_value]).unwrap(); + // shuffle new_batch + let mut rng = rand::thread_rng(); + let mut indices: Vec = (0..len).collect(); + indices.shuffle(&mut rng); + let indices = arrow_array::UInt32Array::from_iter_values(indices); + let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap(); + let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone()); + dataset.merge(new_data, ROW_ID, "rowid").await.unwrap(); + dataset.validate().await.unwrap(); + assert_eq!(dataset.schema().fields.len(), 3); + assert!(dataset.schema().field("key").is_some()); + assert!(dataset.schema().field("value").is_some()); + assert!(dataset.schema().field("new_value").is_some()); + let batch = dataset.scan().try_into_batch().await.unwrap(); + let key = batch["key"].as_primitive::(); + let new_value = batch["new_value"].as_primitive::(); + for i in 0..key.len() { + assert_eq!(key.value(i) + 1, new_value.value(i)); + } + } + + #[rstest] + #[tokio::test] + async fn test_merge_on_row_addr( + #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, + #[values(false, true)] use_stable_row_id: bool, + ) { + // Tests a merge on _rowaddr + + let data = lance_datagen::gen() + .col("key", array::step::()) + .col("value", array::fill_utf8("value".to_string())) + .into_reader_rows(RowCount::from(1_000), BatchCount::from(10)); + + let write_params = WriteParams { + mode: WriteMode::Append, + data_storage_version: Some(data_storage_version), + max_rows_per_file: 1024, + max_rows_per_group: 150, + enable_move_stable_row_ids: use_stable_row_id, + ..Default::default() + }; + let mut dataset = Dataset::write(data, "memory://", Some(write_params.clone())) + .await + .unwrap(); + + assert_eq!(dataset.fragments().len(), 10); + assert_eq!(dataset.manifest.max_fragment_id(), Some(9)); + + let data = dataset + .scan() + .with_row_address() + .try_into_batch() + .await + .unwrap(); + let row_addrs = data[ROW_ADDR].clone(); + let key = data["key"].as_primitive::(); + let new_schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("rowaddr", DataType::UInt64, false), + ArrowField::new("new_value", DataType::Int32, false), + ])); + let new_value = Arc::new( + key.into_iter() + .map(|v| v.unwrap() + 1) + .collect::(), + ); + let len = new_value.len() as u32; + let new_batch = + RecordBatch::try_new(new_schema.clone(), vec![row_addrs, new_value]).unwrap(); + // shuffle new_batch + let mut rng = rand::thread_rng(); + let mut indices: Vec = (0..len).collect(); + indices.shuffle(&mut rng); + let indices = arrow_array::UInt32Array::from_iter_values(indices); + let new_batch = arrow::compute::take_record_batch(&new_batch, &indices).unwrap(); + let new_data = RecordBatchIterator::new(vec![Ok(new_batch)], new_schema.clone()); + dataset.merge(new_data, ROW_ADDR, "rowaddr").await.unwrap(); + dataset.validate().await.unwrap(); + assert_eq!(dataset.schema().fields.len(), 3); + assert!(dataset.schema().field("key").is_some()); + assert!(dataset.schema().field("value").is_some()); + assert!(dataset.schema().field("new_value").is_some()); + let batch = dataset.scan().try_into_batch().await.unwrap(); + let key = batch["key"].as_primitive::(); + let new_value = batch["new_value"].as_primitive::(); + for i in 0..key.len() { + assert_eq!(key.value(i) + 1, new_value.value(i)); + } + } + #[rstest] #[tokio::test] async fn test_delete( @@ -3385,8 +3768,8 @@ mod tests { let reader = RecordBatchIterator::new(vec![data.unwrap()].into_iter().map(Ok), schema); let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap(); - let mut desired_config = HashMap::new(); - desired_config.insert("lance:test".to_string(), "value".to_string()); + let mut desired_config = dataset.manifest.config.clone(); + desired_config.insert("lance.test".to_string(), "value".to_string()); desired_config.insert("other-key".to_string(), "other-value".to_string()); dataset.update_config(desired_config.clone()).await.unwrap(); @@ -3397,6 +3780,68 @@ mod tests { assert_eq!(dataset.manifest.config, desired_config); } + #[rstest] + #[tokio::test] + async fn test_replace_schema_metadata_preserves_fragments() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "i", + DataType::UInt32, + false, + )])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(UInt32Array::from_iter_values(0..100))], + ); + + let reader = RecordBatchIterator::new(vec![data.unwrap()].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(reader, "memory://", None).await.unwrap(); + + let manifest_before = dataset.manifest.clone(); + + let mut new_schema_meta = HashMap::new(); + new_schema_meta.insert("new_key".to_string(), "new_value".to_string()); + dataset + .replace_schema_metadata(new_schema_meta.clone()) + .await + .unwrap(); + + let manifest_after = dataset.manifest.clone(); + + assert_eq!(manifest_before.fragments, manifest_after.fragments); + } + + #[rstest] + #[tokio::test] + async fn test_replace_fragment_metadata_preserves_fragments() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "i", + DataType::UInt32, + false, + )])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(UInt32Array::from_iter_values(0..100))], + ); + + let reader = RecordBatchIterator::new(vec![data.unwrap()].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(reader, "memory://", None).await.unwrap(); + + let manifest_before = dataset.manifest.clone(); + + let mut new_field_meta = HashMap::new(); + new_field_meta.insert("new_key".to_string(), "new_value".to_string()); + dataset + .replace_field_metadata(vec![(0, new_field_meta.clone())]) + .await + .unwrap(); + + let manifest_after = dataset.manifest.clone(); + + assert_eq!(manifest_before.fragments, manifest_after.fragments); + } + #[rstest] #[tokio::test] async fn test_tag( @@ -3590,7 +4035,7 @@ mod tests { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let data = gen().col("vec", array::rand_vec::(Dimension::from(32))); + let data = gen().col("vec", array::rand_vec::(Dimension::from(128))); let reader = data.into_reader_rows(RowCount::from(1000), BatchCount::from(10)); let mut dataset = Dataset::write( reader, @@ -4110,6 +4555,38 @@ mod tests { ); } + #[tokio::test] + async fn test_fix_v0_21_0_corrupt_fragment_bitmap() { + // In v0.21.0 and earlier, delta indices had a bug where the fragment bitmap + // could contain fragments that are part of other index deltas. + + // Copy over table + let test_dir = copy_test_data_to_tmp("v0.21.0/bad_index_fragment_bitmap").unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let mut dataset = Dataset::open(test_uri).await.unwrap(); + + let validate_res = dataset.validate().await; + assert!(validate_res.is_err()); + assert_eq!(dataset.load_indices().await.unwrap()[0].name, "vector_idx"); + + // Calling index statistics will force a migration + let stats = dataset.index_statistics("vector_idx").await.unwrap(); + let stats: serde_json::Value = serde_json::from_str(&stats).unwrap(); + assert_eq!(stats["num_indexed_fragments"], 2); + + dataset.checkout_latest().await.unwrap(); + dataset.validate().await.unwrap(); + + let indices = dataset.load_indices().await.unwrap(); + assert_eq!(indices.len(), 2); + fn get_bitmap(meta: &Index) -> Vec { + meta.fragment_bitmap.as_ref().unwrap().iter().collect() + } + assert_eq!(get_bitmap(&indices[0]), vec![0]); + assert_eq!(get_bitmap(&indices[1]), vec![1]); + } + #[rstest] #[tokio::test] async fn test_bfloat16_roundtrip( @@ -4295,13 +4772,66 @@ mod tests { ); } + #[tokio::test] + async fn test_fts_fuzzy_query() { + let tempdir = tempfile::tempdir().unwrap(); + + let params = InvertedIndexParams::default(); + let text_col = GenericStringArray::::from(vec![ + "fa", "fo", "fob", "focus", "foo", "food", "foul", // # spellchecker:disable-line + ]); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "text", + text_col.data_type().to_owned(), + false, + )]) + .into(), + vec![Arc::new(text_col) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None) + .await + .unwrap(); + dataset + .create_index(&["text"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + let results = dataset + .scan() + .full_text_search(FullTextSearchQuery::new_fuzzy("foo".to_owned(), Some(1))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 4); + let texts = results["text"] + .as_string::() + .iter() + .map(|s| s.unwrap().to_owned()) + .collect::>(); + assert_eq!( + texts, + vec![ + "foo".to_owned(), // 0 edits + "fo".to_owned(), // 1 deletion # spellchecker:disable-line + "fob".to_owned(), // 1 substitution # spellchecker:disable-line + "food".to_owned(), // 1 insertion # spellchecker:disable-line + ] + .into_iter() + .collect() + ); + } + #[tokio::test] async fn test_fts_on_multiple_columns() { let tempdir = tempfile::tempdir().unwrap(); let params = InvertedIndexParams::default(); let title_col = - GenericStringArray::::from(vec!["title hello", "title lance", "title common"]); + GenericStringArray::::from(vec!["title common", "title hello", "title lance"]); let content_col = GenericStringArray::::from(vec![ "content world", "content database", @@ -4359,9 +4889,35 @@ mod tests { .await .unwrap(); assert_eq!(results.num_rows(), 2); - } - #[tokio::test] + let results = dataset + .scan() + .full_text_search( + FullTextSearchQuery::new("common".to_owned()) + .with_column("title".to_owned()) + .unwrap(), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 1); + + let results = dataset + .scan() + .full_text_search( + FullTextSearchQuery::new("common".to_owned()) + .with_column("content".to_owned()) + .unwrap(), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 1); + } + + #[tokio::test] async fn test_fts_unindexed_data() { let tempdir = tempfile::tempdir().unwrap(); @@ -4451,6 +5007,390 @@ mod tests { assert_eq!(results.num_rows(), 1); } + #[tokio::test] + async fn test_fts_rank() { + let tempdir = tempfile::tempdir().unwrap(); + + let params = InvertedIndexParams::default(); + let text_col = + GenericStringArray::::from(vec!["score", "find score", "try to find score"]); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "text", + text_col.data_type().to_owned(), + false, + )]) + .into(), + vec![Arc::new(text_col) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, tempdir.path().to_str().unwrap(), None) + .await + .unwrap(); + dataset + .create_index(&["text"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + + let results = dataset + .scan() + .with_row_id() + .full_text_search(FullTextSearchQuery::new("score".to_owned())) + .unwrap() + .limit(Some(3), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 3); + let row_ids = results[ROW_ID].as_primitive::().values(); + assert_eq!(row_ids, &[0, 1, 2]); + + let results = dataset + .scan() + .with_row_id() + .full_text_search(FullTextSearchQuery::new("score".to_owned())) + .unwrap() + .limit(Some(2), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 2); + let row_ids = results[ROW_ID].as_primitive::().values(); + assert_eq!(row_ids, &[0, 1]); + + let results = dataset + .scan() + .with_row_id() + .full_text_search(FullTextSearchQuery::new("score".to_owned())) + .unwrap() + .limit(Some(1), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 1); + let row_ids = results[ROW_ID].as_primitive::().values(); + assert_eq!(row_ids, &[0]); + } + + async fn create_fts_dataset< + Offset: arrow::array::OffsetSizeTrait, + ListOffset: arrow::array::OffsetSizeTrait, + >( + is_list: bool, + with_position: bool, + tokenizer: TokenizerConfig, + ) -> Dataset { + let tempdir = tempfile::tempdir().unwrap(); + let uri = tempdir.path().to_str().unwrap().to_owned(); + tempdir.close().unwrap(); + + let mut params = InvertedIndexParams::default().with_position(with_position); + params.tokenizer_config = tokenizer; + let doc_col: Arc = if is_list { + let string_builder = GenericStringBuilder::::new(); + let mut list_col = GenericListBuilder::::new(string_builder); + // Create a list of strings + list_col.values().append_value("lance database"); // for testing phrase query + list_col.values().append_value("the"); + list_col.values().append_value("search"); + list_col.append(true); + list_col.values().append_value("lance database"); // for testing phrase query + list_col.append(true); + list_col.values().append_value("lance"); + list_col.values().append_value("search"); + list_col.append(true); + list_col.values().append_value("database"); + list_col.values().append_value("search"); + list_col.append(true); + list_col.values().append_value("unrelated doc"); + list_col.append(true); + list_col.values().append_value("unrelated"); + list_col.append(true); + list_col.values().append_value("mots"); + list_col.values().append_value("accentués"); + list_col.append(true); + list_col.append(false); + Arc::new(list_col.finish()) + } else { + Arc::new(GenericStringArray::::from(vec![ + "lance database the search", + "lance database", + "lance search", + "database search", + "unrelated doc", + "unrelated", + "mots accentués", + ])) + }; + let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true), + arrow_schema::Field::new("id", DataType::UInt64, false), + ]) + .into(), + vec![Arc::new(doc_col) as ArrayRef, Arc::new(ids) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, &uri, None).await.unwrap(); + + dataset + .create_index(&["doc"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + + dataset + } + + async fn test_fts_index< + Offset: arrow::array::OffsetSizeTrait, + ListOffset: arrow::array::OffsetSizeTrait, + >( + is_list: bool, + ) { + let ds = + create_fts_dataset::(is_list, false, TokenizerConfig::default()) + .await; + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("lance".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 3); + let ids = result["id"].as_primitive::().values(); + assert!(ids.contains(&0)); + assert!(ids.contains(&1)); + assert!(ids.contains(&2)); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("database".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 3); + let ids = result["id"].as_primitive::().values(); + assert!(ids.contains(&0)); + assert!(ids.contains(&1)); + assert!(ids.contains(&3)); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query( + MatchQuery::new("lance database".to_owned()) + .with_operator(Operator::And) + .into(), + ) + .limit(Some(3)), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 2); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("unknown null".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 0); + + // test phrase query + // for non-phrasal query, the order of the tokens doesn't matter + // so there should be 4 documents that contain "database" or "lance" + + // we built the index without position, so the phrase query will not work + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query( + PhraseQuery::new("lance database".to_owned()).into(), + ) + .limit(Some(10)), + ) + .unwrap() + .try_into_batch() + .await; + let err = result.unwrap_err().to_string(); + assert!(err.contains("position is not found but required for phrase queries, try recreating the index with position"),"{}",err); + + // recreate the index with position + let ds = + create_fts_dataset::(is_list, true, TokenizerConfig::default()) + .await; + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 4); + let ids = result["id"].as_primitive::().values(); + assert!(ids.contains(&0)); + assert!(ids.contains(&1)); + assert!(ids.contains(&2)); + assert!(ids.contains(&3)); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query( + PhraseQuery::new("lance database".to_owned()).into(), + ) + .limit(Some(10)), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let ids = result["id"].as_primitive::().values(); + assert_eq!(result.num_rows(), 2, "{:?}", ids); + assert!(ids.contains(&0)); + assert!(ids.contains(&1)); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query( + PhraseQuery::new("database lance".to_owned()).into(), + ) + .limit(Some(10)), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 0); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query(PhraseQuery::new("lance unknown".to_owned()).into()) + .limit(Some(10)), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 0); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search( + FullTextSearchQuery::new_query(PhraseQuery::new("unknown null".to_owned()).into()) + .limit(Some(3)), + ) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 0); + } + + #[tokio::test] + async fn test_fts_index_with_string() { + test_fts_index::(false).await; + test_fts_index::(true).await; + test_fts_index::(true).await; + } + + #[tokio::test] + async fn test_fts_index_with_large_string() { + test_fts_index::(false).await; + test_fts_index::(true).await; + test_fts_index::(true).await; + } + + #[tokio::test] + async fn test_fts_accented_chars() { + let ds = create_fts_dataset::(false, false, TokenizerConfig::default()).await; + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 1); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 0); + + // with ascii folding enabled, the search should be accent-insensitive + let ds = create_fts_dataset::( + false, + false, + TokenizerConfig::default().ascii_folding(true), + ) + .await; + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 1); + + let result = ds + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3))) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), 1); + } + #[tokio::test] async fn concurrent_create() { async fn write(uri: &str) -> Result<()> { @@ -4775,4 +5715,493 @@ mod tests { assert!(result.is_err()); assert!(matches!(result, Err(Error::SchemaMismatch { .. }))); } + + #[tokio::test] + async fn test_datafile_replacement() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])); + let empty_reader = RecordBatchIterator::new(vec![], schema.clone()); + let dataset = Arc::new( + Dataset::write(empty_reader, "memory://", None) + .await + .unwrap(), + ); + dataset.validate().await.unwrap(); + + // Test empty replacement should commit a new manifest and do nothing + let mut dataset = Dataset::commit( + WriteDestination::Dataset(dataset.clone()), + Operation::DataReplacement { + replacements: vec![], + }, + Some(1), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + assert_eq!(dataset.version().version, 2); + assert_eq!(dataset.get_fragments().len(), 0); + + // try the same thing on a non-empty dataset + let vals: Int32Array = vec![1, 2, 3].into(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(vals)]).unwrap(); + dataset + .append( + RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), + None, + ) + .await + .unwrap(); + + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::DataReplacement { + replacements: vec![], + }, + Some(3), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + assert_eq!(dataset.version().version, 4); + assert_eq!(dataset.get_fragments().len(), 1); + + let batch = dataset.scan().try_into_batch().await.unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!( + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[1, 2, 3] + ); + + // write a new datafile + let object_writer = dataset + .object_store + .create(&Path::from("data/test.lance")) + .await + .unwrap(); + let mut writer = FileWriter::try_new( + object_writer, + schema.as_ref().try_into().unwrap(), + Default::default(), + ) + .unwrap(); + + let vals: Int32Array = vec![4, 5, 6].into(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(vals)]).unwrap(); + writer.write_batch(&batch).await.unwrap(); + writer.finish().await.unwrap(); + + // find the datafile we want to replace + let frag = dataset.get_fragment(0).unwrap(); + let data_file = frag.data_file_for_field(0).unwrap(); + let mut new_data_file = data_file.clone(); + new_data_file.path = "test.lance".to_string(); + + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, new_data_file)], + }, + Some(5), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.version().version, 5); + assert_eq!(dataset.get_fragments().len(), 1); + assert_eq!(dataset.get_fragments()[0].metadata.files.len(), 1); + + let batch = dataset.scan().try_into_batch().await.unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!( + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[4, 5, 6] + ); + } + + #[tokio::test] + async fn test_datafile_partial_replacement() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])); + let empty_reader = RecordBatchIterator::new(vec![], schema.clone()); + let mut dataset = Dataset::write(empty_reader, "memory://", None) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + let vals: Int32Array = vec![1, 2, 3].into(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(vals)]).unwrap(); + dataset + .append( + RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), + None, + ) + .await + .unwrap(); + + let fragment = dataset.get_fragments().pop().unwrap().metadata; + + let extended_schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, true), + ArrowField::new("b", DataType::Int32, true), + ])); + + // add all null column + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::Merge { + fragments: vec![fragment], + schema: extended_schema.as_ref().try_into().unwrap(), + }, + Some(2), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + + let partial_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "b", + DataType::Int32, + true, + )])); + + // write a new datafile + let object_writer = dataset + .object_store + .create(&Path::from("data/test.lance")) + .await + .unwrap(); + let mut writer = FileWriter::try_new( + object_writer, + partial_schema.as_ref().try_into().unwrap(), + Default::default(), + ) + .unwrap(); + + let vals: Int32Array = vec![4, 5, 6].into(); + let batch = RecordBatch::try_new(partial_schema.clone(), vec![Arc::new(vals)]).unwrap(); + writer.write_batch(&batch).await.unwrap(); + writer.finish().await.unwrap(); + + // find the datafile we want to replace + let new_data_file = DataFile { + path: "test.lance".to_string(), + // the second column in the dataset + fields: vec![1], + // is located in the first column of this datafile + column_indices: vec![0], + file_major_version: 2, + file_minor_version: 0, + }; + + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, new_data_file)], + }, + Some(3), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.version().version, 4); + assert_eq!(dataset.get_fragments().len(), 1); + assert_eq!(dataset.get_fragments()[0].metadata.files.len(), 2); + assert_eq!(dataset.get_fragments()[0].metadata.files[0].fields, vec![0]); + assert_eq!(dataset.get_fragments()[0].metadata.files[1].fields, vec![1]); + + let batch = dataset.scan().try_into_batch().await.unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!( + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[1, 2, 3] + ); + assert_eq!( + batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[4, 5, 6] + ); + + // do it again but on the first column + // find the datafile we want to replace + let new_data_file = DataFile { + path: "test.lance".to_string(), + // the first column in the dataset + fields: vec![0], + // is located in the first column of this datafile + column_indices: vec![0], + file_major_version: 2, + file_minor_version: 0, + }; + + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, new_data_file)], + }, + Some(4), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + + assert_eq!(dataset.version().version, 5); + assert_eq!(dataset.get_fragments().len(), 1); + assert_eq!(dataset.get_fragments()[0].metadata.files.len(), 2); + + let batch = dataset.scan().try_into_batch().await.unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!( + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[4, 5, 6] + ); + assert_eq!( + batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .values(), + &[4, 5, 6] + ); + } + + #[tokio::test] + async fn test_datafile_replacement_error() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + true, + )])); + let empty_reader = RecordBatchIterator::new(vec![], schema.clone()); + let mut dataset = Dataset::write(empty_reader, "memory://", None) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + let vals: Int32Array = vec![1, 2, 3].into(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(vals)]).unwrap(); + dataset + .append( + RecordBatchIterator::new(vec![Ok(batch)], schema.clone()), + None, + ) + .await + .unwrap(); + + let fragment = dataset.get_fragments().pop().unwrap().metadata; + + let extended_schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, true), + ArrowField::new("b", DataType::Int32, true), + ])); + + // add all null column + let dataset = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset)), + Operation::Merge { + fragments: vec![fragment], + schema: extended_schema.as_ref().try_into().unwrap(), + }, + Some(2), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap(); + + // find the datafile we want to replace + let new_data_file = DataFile { + path: "test.lance".to_string(), + // the second column in the dataset + fields: vec![1], + // is located in the first column of this datafile + column_indices: vec![0], + file_major_version: 2, + file_minor_version: 0, + }; + + let new_data_file = DataFile { + fields: vec![0, 1], + ..new_data_file + }; + + let err = Dataset::commit( + WriteDestination::Dataset(Arc::new(dataset.clone())), + Operation::DataReplacement { + replacements: vec![DataReplacementGroup(0, new_data_file)], + }, + Some(4), + None, + None, + Arc::new(Default::default()), + false, + ) + .await + .unwrap_err(); + assert!(err + .to_string() + .contains("Expected to modify the fragment but no changes were made")); + } + + #[tokio::test] + async fn test_replace_dataset() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let data = gen() + .col("int", array::step::()) + .into_batch_rows(RowCount::from(20)) + .unwrap(); + let data1 = data.slice(0, 10); + let data2 = data.slice(10, 10); + let mut ds = InsertBuilder::new(test_uri) + .execute(vec![data1]) + .await + .unwrap(); + + let test_path = Path::from_filesystem_path(test_uri).unwrap(); + ds.object_store().remove_dir_all(test_path).await.unwrap(); + + let ds2 = InsertBuilder::new(test_uri) + .execute(vec![data2.clone()]) + .await + .unwrap(); + + ds.checkout_latest().await.unwrap(); + let roundtripped = ds.scan().try_into_batch().await.unwrap(); + assert_eq!(roundtripped, data2); + + ds.validate().await.unwrap(); + ds2.validate().await.unwrap(); + assert_eq!(ds.manifest.version, 1); + assert_eq!(ds2.manifest.version, 1); + } + + #[tokio::test] + async fn test_session_store_registry() { + // Create a session + let session = Arc::new(Session::default()); + let registry = session.store_registry(); + assert!(registry.active_stores().is_empty()); + + // Create a dataset with memory store + let write_params = WriteParams { + session: Some(session.clone()), + ..Default::default() + }; + let batch = RecordBatch::try_new( + Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + false, + )])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let dataset = InsertBuilder::new("memory://test") + .with_params(&write_params) + .execute(vec![batch.clone()]) + .await + .unwrap(); + + // Assert there is one active store. + assert_eq!(registry.active_stores().len(), 1); + + // If we create another dataset also in memory, it should re-use the + // existing store. + let dataset2 = InsertBuilder::new("memory://test2") + .with_params(&write_params) + .execute(vec![batch.clone()]) + .await + .unwrap(); + assert_eq!(registry.active_stores().len(), 1); + assert_eq!( + Arc::as_ptr(&dataset.object_store().inner), + Arc::as_ptr(&dataset2.object_store().inner) + ); + + // If we create another with **different parameters**, it should create a new store. + let write_params2 = WriteParams { + session: Some(session.clone()), + store_params: Some(ObjectStoreParams { + block_size: Some(10_000), + ..Default::default() + }), + ..Default::default() + }; + let dataset3 = InsertBuilder::new("memory://test3") + .with_params(&write_params2) + .execute(vec![batch.clone()]) + .await + .unwrap(); + assert_eq!(registry.active_stores().len(), 2); + assert_ne!( + Arc::as_ptr(&dataset.object_store().inner), + Arc::as_ptr(&dataset3.object_store().inner) + ); + + // Remove both datasets + drop(dataset3); + assert_eq!(registry.active_stores().len(), 1); + drop(dataset2); + drop(dataset); + assert_eq!(registry.active_stores().len(), 0); + } } diff --git a/rust/lance/src/dataset/blob.rs b/rust/lance/src/dataset/blob.rs index 67f8c7081bf..8185391e7c5 100644 --- a/rust/lance/src/dataset/blob.rs +++ b/rust/lance/src/dataset/blob.rs @@ -19,7 +19,7 @@ use lance_core::{ }; use lance_io::traits::Reader; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tokio::sync::Mutex; use crate::io::exec::{ShareableRecordBatchStream, ShareableRecordBatchStreamAdapter}; diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 342965852aa..81de2b42afb 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -4,8 +4,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use lance_file::datatypes::populate_schema_dictionary; use lance_io::object_store::{ - ObjectStore, ObjectStoreParams, ObjectStoreRegistry, StorageOptions, - DEFAULT_CLOUD_IO_PARALLELISM, + ObjectStore, ObjectStoreParams, StorageOptions, DEFAULT_CLOUD_IO_PARALLELISM, }; use lance_table::{ format::Manifest, @@ -13,7 +12,7 @@ use lance_table::{ }; use object_store::{aws::AwsCredentialProvider, path::Path, DynObjectStore}; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use url::Url; @@ -39,7 +38,6 @@ pub struct DatasetBuilder { options: ObjectStoreParams, version: Option, table_uri: String, - object_store_registry: Arc, } impl DatasetBuilder { @@ -53,7 +51,6 @@ impl DatasetBuilder { session: None, version: None, manifest: None, - object_store_registry: Arc::new(ObjectStoreRegistry::default()), } } } @@ -114,6 +111,8 @@ impl DatasetBuilder { } /// Directly set the object store to use. + #[deprecated(note = "Implement an ObjectStoreProvider instead")] + #[allow(deprecated)] pub fn with_object_store( mut self, object_store: Arc, @@ -179,8 +178,6 @@ impl DatasetBuilder { self.commit_handler = Some(commit_handler); } - self.object_store_registry = read_params.object_store_registry.clone(); - self } @@ -194,8 +191,6 @@ impl DatasetBuilder { self.commit_handler = Some(commit_handler); } - self.object_store_registry = write_params.object_store_registry.clone(); - self } @@ -209,13 +204,10 @@ impl DatasetBuilder { self } - pub fn with_object_store_registry(mut self, registry: Arc) -> Self { - self.object_store_registry = registry; - self - } - /// Build a lance object store for the given config - pub async fn build_object_store(self) -> Result<(ObjectStore, Path, Arc)> { + pub async fn build_object_store( + self, + ) -> Result<(Arc, Path, Arc)> { let commit_handler = match self.commit_handler { Some(commit_handler) => Ok(commit_handler), None => commit_handler_from_url(&self.table_uri, &Some(self.options.clone())).await, @@ -229,9 +221,16 @@ impl DatasetBuilder { .unwrap_or_default(); let download_retry_count = storage_options.download_retry_count(); + let store_registry = self + .session + .as_ref() + .map(|s| s.store_registry()) + .unwrap_or_default(); + + #[allow(deprecated)] match &self.options.object_store { Some(store) => Ok(( - ObjectStore::new( + Arc::new(ObjectStore::new( store.0.clone(), store.1.clone(), self.options.block_size, @@ -242,13 +241,13 @@ impl DatasetBuilder { // cloud-like DEFAULT_CLOUD_IO_PARALLELISM, download_retry_count, - ), + )), Path::from(store.1.path()), commit_handler, )), None => { let (store, path) = ObjectStore::from_uri_and_params( - self.object_store_registry.clone(), + store_registry, &self.table_uri, &self.options, ) @@ -260,11 +259,12 @@ impl DatasetBuilder { #[instrument(skip_all)] pub async fn load(mut self) -> Result { - let session = match self.session.take() { - Some(session) => session, + let session = match self.session.as_ref() { + Some(session) => session.clone(), None => Arc::new(Session::new( self.index_cache_size, self.metadata_cache_size, + Default::default(), )), }; @@ -283,7 +283,7 @@ impl DatasetBuilder { Ref::Version(v) => Some(v), Ref::Tag(t) => { let tags = Tags::new( - Arc::new(object_store.clone()), + object_store.clone(), commit_handler.clone(), base_path.clone(), ); @@ -323,7 +323,7 @@ impl DatasetBuilder { }; Dataset::checkout_manifest( - Arc::new(object_store), + object_store, base_path, table_uri, manifest, @@ -331,6 +331,7 @@ impl DatasetBuilder { session, commit_handler, location.naming_scheme, + location.e_tag, ) .await } diff --git a/rust/lance/src/dataset/cleanup.rs b/rust/lance/src/dataset/cleanup.rs index 16dd432c864..206a36b3faf 100644 --- a/rust/lance/src/dataset/cleanup.rs +++ b/rust/lance/src/dataset/cleanup.rs @@ -35,10 +35,18 @@ use chrono::{DateTime, TimeDelta, Utc}; use futures::{stream, StreamExt, TryStreamExt}; -use lance_core::{Error, Result}; +use humantime::parse_duration; +use lance_core::{ + utils::tracing::{ + AUDIT_MODE_DELETE, AUDIT_MODE_DELETE_UNVERIFIED, AUDIT_TYPE_DATA, AUDIT_TYPE_DELETION, + AUDIT_TYPE_INDEX, AUDIT_TYPE_MANIFEST, TRACE_FILE_AUDIT, + }, + Error, Result, +}; use lance_table::{ format::{Index, Manifest}, io::{ + commit::ManifestLocation, deletion::deletion_file_path, manifest::{read_manifest, read_manifest_indexes}, }, @@ -49,6 +57,7 @@ use std::{ future, sync::{Mutex, MutexGuard}, }; +use tracing::{info, instrument, Span}; use crate::{utils::temporal::utc_now, Dataset}; @@ -143,6 +152,7 @@ impl<'a> CleanupTask<'a> { self.delete_unreferenced_files(inspection).await } + #[instrument(level = "debug", skip_all)] async fn process_manifests( &'a self, tagged_versions: &HashSet, @@ -150,10 +160,10 @@ impl<'a> CleanupTask<'a> { let inspection = Mutex::new(CleanupInspection::default()); self.dataset .commit_handler - .list_manifests(&self.dataset.base, &self.dataset.object_store.inner) + .list_manifest_locations(&self.dataset.base, &self.dataset.object_store.inner) .await? - .try_for_each_concurrent(self.dataset.object_store.io_parallelism(), |path| { - self.process_manifest_file(path, &inspection, tagged_versions) + .try_for_each_concurrent(self.dataset.object_store.io_parallelism(), |location| { + self.process_manifest_file(location, &inspection, tagged_versions) }) .await?; Ok(inspection.into_inner().unwrap()) @@ -161,7 +171,7 @@ impl<'a> CleanupTask<'a> { async fn process_manifest_file( &self, - path: Path, + location: ManifestLocation, inspection: &Mutex, tagged_versions: &HashSet, ) -> Result<()> { @@ -171,7 +181,8 @@ impl<'a> CleanupTask<'a> { // ignore it then we might delete valid data files thinking they are not // referenced. - let manifest = read_manifest(&self.dataset.object_store, &path, None).await?; + let manifest = + read_manifest(&self.dataset.object_store, &location.path, location.size).await?; let dataset_version = self.dataset.version().version; // Don't delete the latest version, even if it is old. Don't delete tagged versions, @@ -180,7 +191,8 @@ impl<'a> CleanupTask<'a> { let is_latest = dataset_version <= manifest.version; let is_tagged = tagged_versions.contains(&manifest.version); let in_working_set = is_latest || manifest.timestamp() >= self.before || is_tagged; - let indexes = read_manifest_indexes(&self.dataset.object_store, &path, &manifest).await?; + let indexes = + read_manifest_indexes(&self.dataset.object_store, &location.path, &manifest).await?; let mut inspection = inspection.lock().unwrap(); @@ -191,7 +203,7 @@ impl<'a> CleanupTask<'a> { self.process_manifest(&manifest, &indexes, in_working_set, &mut inspection)?; if !in_working_set { - inspection.old_manifests.push(path.clone()); + inspection.old_manifests.push(location.path.clone()); } Ok(()) } @@ -239,6 +251,7 @@ impl<'a> CleanupTask<'a> { Ok(()) } + #[instrument(level = "debug", skip_all, fields(old_versions = inspection.old_manifests.len(), bytes_removed = tracing::field::Empty))] async fn delete_unreferenced_files( &self, inspection: CleanupInspection, @@ -279,7 +292,12 @@ impl<'a> CleanupTask<'a> { .try_fold(0, |acc, size| async move { Ok(acc + (size as u64)) }) .await; - let old_manifests_stream = stream::iter(old_manifests).map(Result::::Ok).boxed(); + let old_manifests_stream = stream::iter(old_manifests) + .map(|path| { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE, type=AUDIT_TYPE_MANIFEST, path = path.to_string()); + Ok(path) + }) + .boxed(); let all_paths_to_remove = stream::iter(vec![unreferenced_paths, old_manifests_stream]).flatten(); @@ -294,6 +312,10 @@ impl<'a> CleanupTask<'a> { let mut removal_stats = removal_stats.into_inner().unwrap(); removal_stats.old_versions = num_old_manifests as u64; removal_stats.bytes_removed += manifest_bytes_removed?; + + let span = Span::current(); + span.record("bytes_removed", removal_stats.bytes_removed); + Ok(removal_stats) } @@ -325,12 +347,15 @@ impl<'a> CleanupTask<'a> { .contains(uuid.as_ref()) { return Ok(None); - } else if !maybe_in_progress - || inspection - .verified_files - .index_uuids - .contains(uuid.as_ref()) + } else if !maybe_in_progress { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE_UNVERIFIED, type=AUDIT_TYPE_INDEX, path = path.to_string()); + return Ok(Some(path)); + } else if inspection + .verified_files + .index_uuids + .contains(uuid.as_ref()) { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE, type=AUDIT_TYPE_INDEX, path = path.to_string()); return Ok(Some(path)); } } else { @@ -346,12 +371,15 @@ impl<'a> CleanupTask<'a> { .contains(&relative_path) { Ok(None) - } else if !maybe_in_progress - || inspection - .verified_files - .data_paths - .contains(&relative_path) + } else if !maybe_in_progress { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE_UNVERIFIED, type=AUDIT_TYPE_DATA, path = path.to_string()); + Ok(Some(path)) + } else if inspection + .verified_files + .data_paths + .contains(&relative_path) { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE, type=AUDIT_TYPE_DATA, path = path.to_string()); Ok(Some(path)) } else { Ok(None) @@ -373,12 +401,15 @@ impl<'a> CleanupTask<'a> { .contains(&relative_path) { Ok(None) - } else if !maybe_in_progress - || inspection - .verified_files - .delete_paths - .contains(&relative_path) + } else if !maybe_in_progress { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE_UNVERIFIED, type=AUDIT_TYPE_DELETION, path = path.to_string()); + Ok(Some(path)) + } else if inspection + .verified_files + .delete_paths + .contains(&relative_path) { + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_DELETE, type=AUDIT_TYPE_DELETION, path = path.to_string()); Ok(Some(path)) } else { Ok(None) @@ -438,6 +469,53 @@ pub async fn cleanup_old_versions( cleanup.run().await } +/// If the dataset config has `lance.auto_cleanup` parameters set, +/// this function automatically calls `dataset.cleanup_old_versions` +/// every `lance.auto_cleanup.interval` versions. This function calls +/// `dataset.cleanup_old_versions` with `lance.auto_cleanup.older_than` +/// for `older_than` and `Some(false)` for both `delete_unverified` and +/// `error_if_tagged_old_versions`. +pub async fn auto_cleanup_hook( + dataset: &Dataset, + manifest: &Manifest, +) -> Result> { + if let Some(older_than) = manifest.config.get("lance.auto_cleanup.older_than") { + if let Some(interval) = manifest.config.get("lance.auto_cleanup.interval") { + let std_older_than = match parse_duration(older_than) { + Ok(t) => t, + Err(e) => { + return Err(Error::Cleanup { + message: format!( + "Error encountered while parsing lance.auto_cleanup.older_than as std::time::Duration: {}", + e + ), + }) + } + }; + let older_than = TimeDelta::from_std(std_older_than).unwrap_or(TimeDelta::MAX); + let interval: u64 = match interval.parse() { + Ok(i) => i, + Err(e) => { + return Err(Error::Cleanup { + message: format!( + "Error encountered while parsing lance.auto_cleanup.interval as u64: {}", + e + ), + }) + } + }; + if manifest.version % interval == 0 { + return Ok(Some( + dataset + .cleanup_old_versions(older_than, Some(false), Some(false)) + .await?, + )); + } + } + } + Ok(None) +} + fn tagged_old_versions_cleanup_error( tags: &HashMap, tagged_old_versions: &HashSet, @@ -476,7 +554,7 @@ mod tests { use lance_linalg::distance::MetricType; use lance_table::io::commit::RenameCommitHandler; use lance_testing::datagen::{some_batch, BatchGenerator, IncrementingInt32}; - use snafu::{location, Location}; + use snafu::location; use crate::{ dataset::{builder::DatasetBuilder, ReadParams, WriteMode, WriteParams}, @@ -556,7 +634,7 @@ mod tests { pub clock: MockClock<'a>, } - impl<'a> MockDatasetFixture<'a> { + impl MockDatasetFixture<'_> { fn try_new() -> Result { let tmpdir = tempdir()?; // let tmpdir_uri = to_obj_store_uri(tmpdir.path())?; @@ -931,6 +1009,96 @@ mod tests { assert_eq!(removed.old_versions, 1); } + #[tokio::test] + async fn auto_cleanup_old_versions() { + // Every n commits, all versions older than T should be deleted. + // + // We first make many commits and check that all of the versions are + // present. We then wait until the "older_than" period has elapsed and + // make many more commits. We check that, without explicitly calling + // `fixture.run_cleanup`, the old versions are automatically cleaned + // up and only the new ones remain. File counts are made after every + // commit. + let fixture = MockDatasetFixture::try_new().unwrap(); + + fixture.create_some_data().await.unwrap(); + + let dataset_config = &fixture.open().await.unwrap().manifest.config; + let cleanup_interval: usize = dataset_config + .get("lance.auto_cleanup.interval") + .unwrap() + .parse() + .unwrap(); + + let cleanup_older_than = TimeDelta::from_std( + parse_duration(dataset_config.get("lance.auto_cleanup.older_than").unwrap()).unwrap(), + ) + .unwrap(); + + // Helper function to check that the number of files is correct. + async fn check_num_files<'a>( + fixture: &'a MockDatasetFixture<'a>, + num_expected_files: usize, + ) { + let file_count = fixture.count_files().await.unwrap(); + + assert_eq!(file_count.num_data_files, num_expected_files); + assert_eq!(file_count.num_manifest_files, num_expected_files); + assert_eq!(file_count.num_tx_files, num_expected_files); + } + + // First, write many files within the "older_than" window. Check that + // no files are automatically cleaned up. + for num_expected_files in 2..2 * cleanup_interval { + fixture.overwrite_some_data().await.unwrap(); + check_num_files(&fixture, num_expected_files).await; + } + + // Fast forward so we are outside of the "older_than" window. + fixture + .clock + .set_system_time(cleanup_older_than + TimeDelta::minutes(1)); + + // Write more files and check that those outside of the "older_than" window + // are cleaned up. + for num_expected_files in 2..cleanup_interval { + fixture.overwrite_some_data().await.unwrap(); + check_num_files(&fixture, num_expected_files).await; + } + + // Overwrite auto cleanup params with custom values + let mut dataset = *(fixture.open().await.unwrap()); + let mut new_autoclean_params = HashMap::new(); + + let new_cleanup_older_than_str = "1month 2days 2h 42min 6sec"; + let new_cleanup_older_than = + TimeDelta::from_std(parse_duration(new_cleanup_older_than_str).unwrap()).unwrap(); + new_autoclean_params.insert( + "lance.auto_cleanup.older_than".to_string(), + new_cleanup_older_than_str.to_string(), + ); + + let new_cleanup_interval = 5; + new_autoclean_params.insert( + "lance.auto_cleanup.interval".to_string(), + new_cleanup_interval.to_string(), + ); + + dataset.update_config(new_autoclean_params).await.unwrap(); + + // Fast forward so we are outside of the new "older_than" window. + fixture + .clock + .set_system_time(cleanup_older_than + new_cleanup_older_than + TimeDelta::minutes(2)); + + fixture.overwrite_some_data().await.unwrap(); + + for num_expected_files in 2..new_cleanup_interval { + fixture.overwrite_some_data().await.unwrap(); + check_num_files(&fixture, num_expected_files).await; + } + } + #[tokio::test] async fn cleanup_recent_verified_files() { let fixture = MockDatasetFixture::try_new().unwrap(); @@ -1023,7 +1191,7 @@ mod tests { let before_count = fixture.count_files().await.unwrap(); // we store 2 files (index and quantized storage) for each index - assert_eq!(before_count.num_index_files, 1); + assert_eq!(before_count.num_index_files, 2); // Two user data files assert_eq!(before_count.num_data_files, 2); // Creating an index creates a new manifest so there are 3 total diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index b3cf1a10e40..75e867eab51 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -6,27 +6,30 @@ pub mod write; use std::borrow::Cow; -use std::collections::{BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::ops::Range; use std::sync::Arc; use arrow::compute::concat_batches; use arrow_array::cast::as_primitive_array; -use arrow_array::{new_null_array, RecordBatch, StructArray, UInt32Array, UInt64Array}; +use arrow_array::{ + new_null_array, RecordBatch, RecordBatchReader, StructArray, UInt32Array, UInt64Array, +}; use arrow_schema::Schema as ArrowSchema; use datafusion::logical_expr::Expr; use datafusion::scalar::ScalarValue; use futures::future::try_join_all; use futures::{join, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use lance_core::datatypes::SchemaCompareOptions; +use lance_core::datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions}; use lance_core::utils::deletion::DeletionVector; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{datatypes::Schema, Error, Result}; -use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD}; +use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD}; use lance_datafusion::utils::StreamingWriteSource; use lance_encoding::decoder::DecoderPlugins; use lance_file::reader::{read_batch, FileReader}; use lance_file::v2::reader::{CachedFileMetadata, FileReaderOptions, ReaderProjection}; +use lance_file::v2::LanceEncodingsIo; use lance_file::version::LanceFileVersion; use lance_file::{determine_file_version, v2}; use lance_io::object_store::ObjectStore; @@ -39,13 +42,14 @@ use lance_table::utils::stream::{ wrap_with_row_id_and_delete, ReadBatchFutStream, ReadBatchTask, ReadBatchTaskStream, RowIdAndDeletesConfig, }; -use snafu::{location, Location}; +use snafu::location; use self::write::FragmentCreateBuilder; use super::hash_joiner::HashJoiner; use super::rowids::load_row_id_sequence; use super::scanner::Scanner; +use super::statistics::FieldStatistics; use super::updater::Updater; use super::{schema_evolution, NewColumnTransform, WriteParams}; use crate::arrow::*; @@ -87,6 +91,7 @@ pub trait GenericFileReader: std::fmt::Debug + Send + Sync { indices: &[u32], batch_size: u32, projection: Arc, + take_priority: Option, ) -> Result; /// Return the number of rows in the file @@ -95,6 +100,9 @@ pub trait GenericFileReader: std::fmt::Debug + Send + Sync { /// Schema of the reader fn projection(&self) -> &Arc; + /// Update storage statistics (ignored by v1 reader) + fn update_storage_stats(&self, field_stats: &mut HashMap); + // Helper functions to fallback to the legacy implementation while we // slowly migrate functionality over to the generic reader @@ -216,6 +224,7 @@ impl GenericFileReader for V1Reader { indices: &[u32], _batch_size: u32, projection: Arc, + _take_priority: Option, ) -> Result { let indices_vec = indices.to_vec(); let reader = self.reader.clone(); @@ -238,6 +247,10 @@ impl GenericFileReader for V1Reader { self.reader.len() as u32 } + fn update_storage_stats(&self, _field_stats: &mut HashMap) { + // No-op for v1 files + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } @@ -265,6 +278,8 @@ mod v2_adapter { reader: Arc, projection: Arc, field_id_to_column_idx: Arc>, + default_priority: u32, + file_scheduler: FileScheduler, } impl Reader { @@ -272,11 +287,15 @@ mod v2_adapter { reader: Arc, projection: Arc, field_id_to_column_idx: Arc>, + default_priority: u32, + file_scheduler: FileScheduler, ) -> Self { Self { reader, projection, field_id_to_column_idx, + default_priority, + file_scheduler, } } } @@ -340,6 +359,7 @@ mod v2_adapter { indices: &[u32], batch_size: u32, projection: Arc, + take_priority: Option, ) -> Result { let indices = UInt32Array::from(indices.to_vec()); let projection = ReaderProjection::from_field_ids( @@ -347,8 +367,19 @@ mod v2_adapter { projection.as_ref(), self.field_id_to_column_idx.as_ref(), )?; - Ok(self - .reader + + let reader = if let Some(take_priority) = take_priority { + let op_priority = ((take_priority as u64) << 32) | self.default_priority as u64; + let scheduler = self.file_scheduler.with_priority(op_priority); + Arc::new( + self.reader + .with_scheduler(Arc::new(LanceEncodingsIo(scheduler))), + ) + } else { + self.reader.clone() + }; + + Ok(reader .read_tasks( ReadBatchParams::Indices(indices), batch_size, @@ -362,6 +393,29 @@ mod v2_adapter { .boxed()) } + fn update_storage_stats(&self, field_stats: &mut HashMap) { + let file_statistics = self.reader.file_statistics(); + let column_idx_to_field_id = self + .field_id_to_column_idx + .iter() + .map(|(field_id, column_idx)| (*column_idx, *field_id)) + .collect::>(); + + // Some fields span more than one column. We assume a column that doesn't have an + // entry in the field_id_to_column_idx map is a continuation of the previous field. + let mut current_field_id = 0; + for (column_idx, stats) in file_statistics.columns.iter().enumerate() { + if let Some(field_id) = column_idx_to_field_id.get(&(column_idx as u32)) { + current_field_id = *field_id; + } + // If the field_id is not in the map then the field may no longer be part of the + // dataset + if let Some(field_stats) = field_stats.get_mut(¤t_field_id) { + field_stats.bytes_on_disk += stats.size_bytes; + } + } + } + fn projection(&self) -> &Arc { &self.projection } @@ -454,11 +508,16 @@ impl GenericFileReader for NullReader { indices: &[u32], batch_size: u32, projection: Arc, + _take_priority: Option, ) -> Result { let num_rows = indices.len() as u64; self.read_range_tasks(0..num_rows, batch_size, projection) } + fn update_storage_stats(&self, _field_stats: &mut HashMap) { + // No-op for null reader + } + fn projection(&self) -> &Arc { &self.schema } @@ -490,6 +549,19 @@ pub struct FragReadConfig { pub with_row_id: bool, // Add the row address column pub with_row_address: bool, + /// The scan scheduler to use for reading data files. + /// + /// This should be specified if multiple readers are being used in + /// an operation + pub scan_scheduler: Option>, + /// The default scan priority to use for reading data files + /// + /// Only used if `scan_scheduler` is provided + /// + /// The overall priority for reads will be + /// + /// operation_priority: u32 | reader_priority: u32 | file_position: u64 + pub reader_priority: Option, } impl FragReadConfig { @@ -502,6 +574,16 @@ impl FragReadConfig { self.with_row_address = value; self } + + pub fn with_scan_scheduler(mut self, value: Arc) -> Self { + self.scan_scheduler = Some(value); + self + } + + pub fn with_reader_priority(mut self, value: u32) -> Self { + self.reader_priority = Some(value); + self + } } impl FileFragment { @@ -531,6 +613,21 @@ impl FileFragment { builder.write(source, Some(id as u64)).await } + /// Create a list of [`FileFragment`] from a [`StreamingWriteSource`]. + pub async fn create_fragments( + dataset_uri: &str, + source: impl StreamingWriteSource, + params: Option, + ) -> Result> { + let mut builder = FragmentCreateBuilder::new(dataset_uri); + + if let Some(params) = params.as_ref() { + builder = builder.write_params(params); + } + + builder.write_fragments(source).await + } + pub async fn create_from_file( filename: &str, dataset: &Dataset, @@ -605,6 +702,24 @@ impl FileFragment { } } + pub(crate) async fn update_storage_stats( + &self, + field_stats: &mut HashMap, + dataset_schema: &Schema, + scan_scheduler: Arc, + ) -> Result<()> { + for reader in self + .open_readers( + dataset_schema, + &FragReadConfig::default().with_scan_scheduler(scan_scheduler), + ) + .await? + { + reader.update_storage_stats(field_stats); + } + Ok(()) + } + pub fn dataset(&self) -> &Dataset { self.dataset.as_ref() } @@ -646,7 +761,7 @@ impl FileFragment { /// - `projection`: The projection schema. /// - `read_config`: Controls what columns are included in the output. /// - `scan_scheduler`: The scheduler to use for reading data files. If not supplied - /// and the data is v2 data then a new scheduler will be created + /// and the data is v2 data then a new scheduler will be created /// /// `projection` may be an empty schema only if `with_row_id` is true. In that /// case, the reader will only be generating row ids. @@ -654,9 +769,8 @@ impl FileFragment { &self, projection: &Schema, read_config: FragReadConfig, - scan_scheduler: Option<(Arc, u64)>, ) -> Result { - let open_files = self.open_readers(projection, scan_scheduler); + let open_files = self.open_readers(projection, &read_config); let deletion_vec_load = self.load_deletion_vector(&self.dataset.object_store, &self.metadata); @@ -693,7 +807,7 @@ impl FileFragment { row_id_sequence, opened_files, ArrowSchema::from(projection), - self.count_rows().await?, + self.count_rows(None).await?, num_physical_rows, )?; @@ -715,7 +829,7 @@ impl FileFragment { &self, data_file: &DataFile, projection: Option<&Schema>, - scan_scheduler: Option<(Arc, u64)>, + read_config: &FragReadConfig, ) -> Result>> { let full_schema = self.dataset.schema(); // The data file may contain fields that are not part of the dataset any longer, remove those @@ -739,9 +853,11 @@ impl FileFragment { Some(&self.dataset.session.file_metadata_cache), ) .await?; - let initialized_schema = reader - .schema() - .project_by_schema(schema_per_file.as_ref())?; + let initialized_schema = reader.schema().project_by_schema( + schema_per_file.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; let reader = V1Reader::new(reader, Arc::new(initialized_schema)); Ok(Some(Box::new(reader))) } else { @@ -751,22 +867,30 @@ impl FileFragment { Ok(None) } else { let path = self.dataset.data_dir().child(data_file.path.as_str()); - let (store_scheduler, priority_offset) = scan_scheduler.unwrap_or_else(|| { - ( - ScanScheduler::new( - self.dataset.object_store.clone(), - SchedulerConfig::max_bandwidth(&self.dataset.object_store), - ), - 0, - ) - }); + let (store_scheduler, reader_priority) = + if let Some(scan_scheduler) = read_config.scan_scheduler.as_ref() { + ( + scan_scheduler.clone(), + read_config.reader_priority.unwrap_or(0), + ) + } else { + ( + ScanScheduler::new( + self.dataset.object_store.clone(), + SchedulerConfig::max_bandwidth(&self.dataset.object_store), + ), + 0, + ) + }; let file_scheduler = store_scheduler - .open_file_with_priority(&path, priority_offset) + .open_file_with_priority(&path, reader_priority as u64) .await?; let file_metadata = self.get_file_metadata(&file_scheduler).await?; + let path = file_scheduler.reader().path().clone(); let reader = Arc::new( v2::reader::FileReader::try_open_with_file_metadata( - file_scheduler, + Arc::new(LanceEncodingsIo(file_scheduler.clone())), + path, None, Arc::::default(), file_metadata, @@ -789,7 +913,13 @@ impl FileFragment { } }), )); - let reader = v2_adapter::Reader::new(reader, schema_per_file, field_id_to_column_idx); + let reader = v2_adapter::Reader::new( + reader, + schema_per_file, + field_id_to_column_idx, + reader_priority, + file_scheduler, + ); Ok(Some(Box::new(reader))) } } @@ -797,20 +927,21 @@ impl FileFragment { async fn open_readers( &self, projection: &Schema, - scan_scheduler: Option<(Arc, u64)>, + read_config: &FragReadConfig, ) -> Result>> { let mut opened_files = vec![]; for data_file in &self.metadata.files { if let Some(reader) = self - .open_reader(data_file, Some(projection), scan_scheduler.clone()) + .open_reader(data_file, Some(projection), read_config) .await? { opened_files.push(reader); } } - // This should return immediately on modern datasets. - let num_rows = self.count_rows().await?; + // This should return immediately on modern datasets. Need to use physical_rows because + // deletions will be applied later + let num_rows = self.physical_rows().await?; // Check if there are any fields that are not in any data files let field_ids_in_files = opened_files @@ -830,15 +961,27 @@ impl FileFragment { } /// Count the rows in this fragment. - pub async fn count_rows(&self) -> Result { - let total_rows = self.physical_rows(); - - let deletion_count = self.count_deletions(); + pub async fn count_rows(&self, filter: Option) -> Result { + match filter { + Some(expr) => self + .scan() + .project(&Vec::::default()) + .unwrap() + .with_row_id() + .filter(&expr)? + .count_rows() + .await + .map(|v| v as usize), + None => { + let total_rows = self.physical_rows(); + let deletion_count = self.count_deletions(); - let (total_rows, deletion_count) = - futures::future::try_join(total_rows, deletion_count).await?; + let (total_rows, deletion_count) = + futures::future::try_join(total_rows, deletion_count).await?; - Ok(total_rows - deletion_count) + Ok(total_rows - deletion_count) + } + } } /// Get the number of rows that have been deleted in this fragment. @@ -914,7 +1057,7 @@ impl FileFragment { // Just open any file. All of them should have same size. let some_file = &self.metadata.files[0]; let reader = self - .open_reader(some_file, None, None) + .open_reader(some_file, None, &FragReadConfig::default()) .await? .ok_or_else(|| Error::Internal { message: format!( @@ -990,7 +1133,7 @@ impl FileFragment { let get_lengths = self.metadata.files.iter().map(|data_file| async move { let reader = self - .open_reader(data_file, None, None) + .open_reader(data_file, None, &FragReadConfig::default()) .await? .ok_or_else(|| { Error::corrupt_file( @@ -1208,7 +1351,6 @@ impl FileFragment { .open( projection, FragReadConfig::default().with_row_address(with_row_address), - None, ) .await?; @@ -1218,7 +1360,7 @@ impl FileFragment { reader.legacy_read_range_as_batch(range).await } else { // FIXME, change this method to streams - reader.take_as_batch(row_offsets).await + reader.take_as_batch(row_offsets, None).await } } @@ -1270,11 +1412,14 @@ impl FileFragment { let mut schema = self.dataset.schema().clone(); let mut with_row_addr = false; + let mut with_row_id = false; if let Some(columns) = columns { let mut projection = Vec::new(); for column in columns { if column.as_ref() == ROW_ADDR { with_row_addr = true; + } else if column.as_ref() == ROW_ID { + with_row_id = true; } else { projection.push(column.as_ref()); } @@ -1290,12 +1435,13 @@ impl FileFragment { } // If there is no projection, we at least need to read the row addresses - with_row_addr |= schema.fields.is_empty(); + with_row_addr |= !with_row_id && schema.fields.is_empty(); let reader = self.open( &schema, - FragReadConfig::default().with_row_address(with_row_addr), - None, + FragReadConfig::default() + .with_row_address(with_row_addr) + .with_row_id(with_row_id), ); let deletion_vector = read_deletion_file( &self.dataset.base, @@ -1309,6 +1455,66 @@ impl FileFragment { Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size) } + pub async fn merge_columns( + &mut self, + stream: impl RecordBatchReader + Send + 'static, + left_on: &str, + right_on: &str, + max_field_id: i32, + ) -> Result<(Fragment, Schema)> { + let stream = Box::new(stream); + if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR { + return Err(Error::invalid_input( + format!( + "Column {} does not exist in the left side fragment", + left_on + ), + location!(), + )); + }; + let right_schema = stream.schema(); + if right_schema.field_with_name(right_on).is_err() { + return Err(Error::invalid_input( + format!( + "Column {} does not exist in the right side fragment", + right_on + ), + location!(), + )); + }; + + for field in right_schema.fields() { + if field.name() == right_on { + // right_on is allowed to exist in the dataset, since it may be + // the same as left_on. + continue; + } + if self.schema().field(field.name()).is_some() { + return Err(Error::invalid_input( + format!( + "Column {} exists in left side fragment and right side dataset", + field.name() + ), + location!(), + )); + } + } + // Hash join + let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?); + // Final schema is union of current schema, plus the RHS schema without + // the right_on key. + let mut new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?; + new_schema.set_field_id(Some(max_field_id)); + + let new_fragment = self + .clone() + .merge(left_on, &joiner) + .await + .map(|f| f.metadata)?; + + Ok((new_fragment, new_schema)) + } + pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result { let mut updater = self.updater(Some(&[join_column]), None, None).await?; @@ -1838,12 +2044,7 @@ impl FragmentReader { let merged = if self.with_row_addr as usize + self.with_row_id as usize == self.output_schema.fields.len() { - let selected_rows = params - .slice(0, total_num_rows as usize) - .unwrap() - .to_offsets() - .unwrap() - .len(); + let selected_rows = params.to_offsets_total(total_num_rows).len(); let tasks = (0..selected_rows) .step_by(batch_size as usize) .map(move |offset| { @@ -1986,20 +2187,36 @@ impl FragmentReader { } /// Take rows from this fragment. - pub async fn take(&self, indices: &[u32], batch_size: u32) -> Result { + pub async fn take( + &self, + indices: &[u32], + batch_size: u32, + take_priority: Option, + ) -> Result { let indices_arr = UInt32Array::from(indices.to_vec()); self.new_read_impl( ReadBatchParams::Indices(indices_arr), batch_size, - move |reader| reader.take_all_tasks(indices, batch_size, reader.projection().clone()), + move |reader| { + reader.take_all_tasks( + indices, + batch_size, + reader.projection().clone(), + take_priority, + ) + }, ) } /// Take rows from this fragment, will perform a copy if the underlying reader returns multiple /// batches. May return an error if the taken rows do not fit into a single batch. - pub async fn take_as_batch(&self, indices: &[u32]) -> Result { + pub async fn take_as_batch( + &self, + indices: &[u32], + take_priority: Option, + ) -> Result { let batches = self - .take(indices, u32::MAX) + .take(indices, u32::MAX, take_priority) .await? .buffered(get_num_compute_intensive_cpus()) .try_collect::>() @@ -2017,7 +2234,6 @@ mod tests { use lance_core::ROW_ID; use lance_datagen::{array, gen, RowCount}; use lance_file::version::LanceFileVersion; - use lance_io::object_store::ObjectStoreRegistry; use pretty_assertions::assert_eq; use rstest::rstest; use tempfile::tempdir; @@ -2203,7 +2419,6 @@ mod tests { .open( fragment.schema(), FragReadConfig::default().with_row_id(with_row_id), - None, ) .await .unwrap(); @@ -2227,7 +2442,69 @@ mod tests { .open( fragment.schema(), FragReadConfig::default().with_row_id(with_row_id), - None, + ) + .await + .unwrap(); + for valid_range in [0..20, 0..10, 10..20] { + reader + .read_range(valid_range, 100) + .unwrap() + .buffered(1) + .try_collect::>() + .await + .unwrap(); + } + for invalid_range in [0..21, 21..22] { + assert!(reader.read_range(invalid_range, 100).is_err()); + } + } + } + + #[tokio::test] + async fn test_rowid_rowaddr_only() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + // Creates 400 rows in 10 fragments + let mut dataset = create_dataset(test_uri, LanceFileVersion::Legacy).await; + // Delete last 20 rows in first fragment + dataset.delete("i >= 20").await.unwrap(); + // Last fragment has 20 rows but 40 addressable rows + let fragment = &dataset.get_fragments()[0]; + assert_eq!(fragment.metadata.num_rows().unwrap(), 20); + + // Test with take_range (all rows addressable) + for (with_row_id, with_row_address) in [(false, true), (true, false), (true, true)] { + let reader = fragment + .open( + &fragment.schema().project::<&str>(&[]).unwrap(), + FragReadConfig::default() + .with_row_id(with_row_id) + .with_row_address(with_row_address), + ) + .await + .unwrap(); + for valid_range in [0..40, 20..40] { + reader + .take_range(valid_range, 100) + .unwrap() + .buffered(1) + .try_collect::>() + .await + .unwrap(); + } + for invalid_range in [0..41, 41..42] { + assert!(reader.take_range(invalid_range, 100).is_err()); + } + } + + // Test with read_range (only non-deleted rows addressable) + for (with_row_id, with_row_address) in [(false, true), (true, false), (true, true)] { + let reader = fragment + .open( + &fragment.schema().project::<&str>(&[]).unwrap(), + FragReadConfig::default() + .with_row_id(with_row_id) + .with_row_address(with_row_address), ) .await .unwrap(); @@ -2262,7 +2539,6 @@ mod tests { .open( dataset.schema(), FragReadConfig::default().with_row_id(true), - None, ) .await .unwrap(); @@ -2354,7 +2630,6 @@ mod tests { .open( dataset.schema(), FragReadConfig::default().with_row_id(true), - None, ) .await .unwrap(); @@ -2547,10 +2822,10 @@ mod tests { config_upsert_values: None, }; - let registry = Arc::new(ObjectStoreRegistry::default()); - let new_dataset = Dataset::commit(test_uri, op, None, None, None, registry, false) - .await - .unwrap(); + let new_dataset = + Dataset::commit(test_uri, op, None, None, None, Default::default(), false) + .await + .unwrap(); assert_eq!(new_dataset.count_rows(None).await.unwrap(), dataset_rows); @@ -2560,7 +2835,7 @@ mod tests { assert_eq!(fragments.len(), 5); for f in fragments { assert_eq!(f.metadata.num_rows(), Some(40)); - assert_eq!(f.count_rows().await.unwrap(), 40); + assert_eq!(f.count_rows(None).await.unwrap(), 40); assert_eq!(f.metadata().deletion_file, None); } } @@ -2576,10 +2851,18 @@ mod tests { let dataset = create_dataset(test_uri, data_storage_version).await; let fragment = dataset.get_fragments().pop().unwrap(); - assert_eq!(fragment.count_rows().await.unwrap(), 40); + assert_eq!(fragment.count_rows(None).await.unwrap(), 40); assert_eq!(fragment.physical_rows().await.unwrap(), 40); assert!(fragment.metadata.deletion_file.is_none()); + assert_eq!( + fragment + .count_rows(Some("i < 170".to_string())) + .await + .unwrap(), + 10 + ); + let fragment = fragment .delete("i >= 160 and i <= 172") .await @@ -2588,7 +2871,7 @@ mod tests { fragment.validate().await.unwrap(); - assert_eq!(fragment.count_rows().await.unwrap(), 27); + assert_eq!(fragment.count_rows(None).await.unwrap(), 27); assert_eq!(fragment.physical_rows().await.unwrap(), 40); assert!(fragment.metadata.deletion_file.is_some()); assert_eq!( @@ -2648,10 +2931,10 @@ mod tests { config_upsert_values: None, }; - let registry = Arc::new(ObjectStoreRegistry::default()); - let dataset = Dataset::commit(test_uri, op, None, None, None, registry, false) - .await - .unwrap(); + let dataset = + Dataset::commit(test_uri, op, None, None, None, Default::default(), false) + .await + .unwrap(); // We only kept the first fragment of 40 rows assert_eq!( @@ -2869,7 +3152,6 @@ mod tests { // Rearrange schema so it's `s` then `i`. let schema = updater.schema().unwrap().clone().project(&["s", "i"])?; - let registry = Arc::new(ObjectStoreRegistry::default()); let dataset = Dataset::commit( test_uri, @@ -2880,7 +3162,7 @@ mod tests { Some(dataset.manifest.version), None, None, - registry, + Default::default(), false, ) .await?; @@ -2894,9 +3176,9 @@ mod tests { .get_fragments() .first() .unwrap() - .open(dataset.schema(), FragReadConfig::default(), None) + .open(dataset.schema(), FragReadConfig::default()) .await?; - let actual_data = reader.take_as_batch(&[0, 1, 2]).await?; + let actual_data = reader.take_as_batch(&[0, 1, 2], None).await?; assert_eq!(expected_data.slice(0, 3), actual_data); let actual_data = reader @@ -2948,7 +3230,6 @@ mod tests { .open( &dataset.schema().project::<&str>(&[])?, FragReadConfig::default().with_row_id(true), - None, ) .await?; let batch = reader.legacy_read_range_as_batch(0..20).await?; @@ -2964,7 +3245,6 @@ mod tests { .open( &dataset.schema().project::<&str>(&[])?, FragReadConfig::default(), - None, ) .await; assert!(matches!(res, Err(Error::IO { .. }))); @@ -3011,14 +3291,13 @@ mod tests { let op = Operation::Append { fragments: vec![frag], }; - let object_store_registry = Arc::new(ObjectStoreRegistry::default()); let dataset = Dataset::commit( &dataset.uri, op, Some(dataset.version().version), None, None, - object_store_registry, + Default::default(), false, ) .await diff --git a/rust/lance/src/dataset/fragment/write.rs b/rust/lance/src/dataset/fragment/write.rs index 83e0fe8e21f..77bd34840b4 100644 --- a/rust/lance/src/dataset/fragment/write.rs +++ b/rust/lance/src/dataset/fragment/write.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::borrow::Cow; - use arrow_schema::Schema as ArrowSchema; use datafusion::execution::SendableRecordBatchStream; use futures::{StreamExt, TryStreamExt}; @@ -16,10 +14,12 @@ use lance_file::writer::FileWriter; use lance_io::object_store::ObjectStore; use lance_table::format::{DataFile, Fragment}; use lance_table::io::manifest::ManifestDescribing; -use snafu::{location, Location}; +use snafu::location; +use std::borrow::Cow; use uuid::Uuid; use crate::dataset::builder::DatasetBuilder; +use crate::dataset::write::do_write_fragments; use crate::dataset::{WriteMode, WriteParams, DATA_DIR}; use crate::Result; @@ -68,6 +68,15 @@ impl<'a> FragmentCreateBuilder<'a> { self.write_impl(stream, schema, id).await } + /// Write multi fragment which separated by max_rows_per_file. + pub async fn write_fragments( + &self, + source: impl StreamingWriteSource, + ) -> Result> { + let (stream, schema) = self.get_stream_and_schema(Box::new(source)).await?; + self.write_fragments_v2_impl(stream, schema).await + } + async fn write_v2_impl( &self, stream: SendableRecordBatchStream, @@ -80,7 +89,7 @@ impl<'a> FragmentCreateBuilder<'a> { Self::validate_schema(&schema, stream.schema().as_ref())?; let (object_store, base_path) = ObjectStore::from_uri_and_params( - params.object_store_registry.clone(), + params.store_registry(), self.dataset_uri, ¶ms.store_params.clone().unwrap_or_default(), ) @@ -136,6 +145,31 @@ impl<'a> FragmentCreateBuilder<'a> { Ok(fragment) } + async fn write_fragments_v2_impl( + &self, + stream: SendableRecordBatchStream, + schema: Schema, + ) -> Result> { + let params = self.write_params.map(Cow::Borrowed).unwrap_or_default(); + + Self::validate_schema(&schema, stream.schema().as_ref())?; + + let (object_store, base_path) = ObjectStore::from_uri_and_params( + params.store_registry(), + self.dataset_uri, + ¶ms.store_params.clone().unwrap_or_default(), + ) + .await?; + do_write_fragments( + object_store, + &base_path, + &schema, + stream, + params.into_owned(), + LanceFileVersion::Stable, + ) + .await + } async fn write_impl( &self, @@ -157,7 +191,7 @@ impl<'a> FragmentCreateBuilder<'a> { Self::validate_schema(&schema, stream.schema().as_ref())?; let (object_store, base_path) = ObjectStore::from_uri_and_params( - params.object_store_registry.clone(), + params.store_registry(), self.dataset_uri, ¶ms.store_params.clone().unwrap_or_default(), ) @@ -207,7 +241,15 @@ impl<'a> FragmentCreateBuilder<'a> { } async fn existing_dataset_schema(&self) -> Result> { - match DatasetBuilder::from_uri(self.dataset_uri).load().await { + let mut builder = DatasetBuilder::from_uri(self.dataset_uri); + let storage_options = self + .write_params + .and_then(|p| p.store_params.as_ref()) + .and_then(|p| p.storage_options.clone()); + if let Some(storage_options) = storage_options { + builder = builder.with_storage_options(storage_options); + } + match builder.load().await { Ok(dataset) => { // Use the schema from the dataset, because it has the correct // field ids. @@ -353,4 +395,93 @@ mod tests { assert_eq!(fragment.files[0].fields, vec![3, 1]); assert_eq!(fragment.files[0].column_indices, vec![0, 1]); } + + #[tokio::test] + async fn test_write_fragments_validation() { + // Writing with empty schema produces an error + let empty_schema = Arc::new(ArrowSchema::empty()); + let empty_reader = Box::new(RecordBatchIterator::new(vec![], empty_schema)); + let tmp_dir = tempfile::tempdir().unwrap(); + let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap()) + .write_fragments(empty_reader) + .await; + assert!(result.is_err()); + assert!( + matches!(result.as_ref().unwrap_err(), Error::InvalidInput { source, .. } + if source.to_string().contains("Cannot write with an empty schema.")), + "{:?}", + &result + ); + + // Writing empty reader produces an error + let arrow_schema = test_data().schema(); + let empty_reader = Box::new(RecordBatchIterator::new(vec![], arrow_schema.clone())); + let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap()) + .write_fragments(empty_reader) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 0); + + // Writing with incorrect schema produces an error. + let wrong_schema = arrow_schema + .as_ref() + .try_with_column(ArrowField::new("c", DataType::Utf8, false)) + .unwrap(); + let wrong_schema = Schema::try_from(&wrong_schema).unwrap(); + let result = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap()) + .schema(&wrong_schema) + .write_fragments(test_data()) + .await; + assert!(result.is_err()); + assert!( + matches!(result.as_ref().unwrap_err(), Error::SchemaMismatch { difference, .. } + if difference.contains("fields did not match")), + "{:?}", + &result + ); + } + + #[tokio::test] + async fn test_write_fragments_default_schema() { + // Infers schema and uses 0 as default field id + let data = test_data(); + let tmp_dir = tempfile::tempdir().unwrap(); + let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap()) + .write_fragments(data) + .await + .unwrap(); + + // If unspecified, the fragment id should be 0. + assert_eq!(fragments.len(), 1); + assert_eq!(fragments[0].deletion_file, None); + assert_eq!(fragments[0].files.len(), 1); + assert_eq!(fragments[0].files[0].fields, vec![0, 1]); + } + + #[tokio::test] + async fn test_write_fragments_with_options() { + // Uses provided schema. Field ids are correct in fragment metadata. + let data = test_data(); + let tmp_dir = tempfile::tempdir().unwrap(); + let writer_params = WriteParams { + max_rows_per_file: 1, + ..Default::default() + }; + let fragments = FragmentCreateBuilder::new(tmp_dir.path().to_str().unwrap()) + .write_params(&writer_params) + .write_fragments(data) + .await + .unwrap(); + + assert_eq!(fragments.len(), 3); + assert_eq!(fragments[0].deletion_file, None); + assert_eq!(fragments[0].files.len(), 1); + assert_eq!(fragments[0].files[0].column_indices, vec![0, 1]); + assert_eq!(fragments[1].deletion_file, None); + assert_eq!(fragments[1].files.len(), 1); + assert_eq!(fragments[1].files[0].column_indices, vec![0, 1]); + assert_eq!(fragments[2].deletion_file, None); + assert_eq!(fragments[2].files.len(), 1); + assert_eq!(fragments[2].files[0].column_indices, vec![0, 1]); + } } diff --git a/rust/lance/src/dataset/hash_joiner.rs b/rust/lance/src/dataset/hash_joiner.rs index 6333acb1d1e..8fdba38a4b8 100644 --- a/rust/lance/src/dataset/hash_joiner.rs +++ b/rust/lance/src/dataset/hash_joiner.rs @@ -13,7 +13,7 @@ use arrow_select::interleave::interleave; use dashmap::{DashMap, ReadOnlyView}; use futures::{StreamExt, TryStreamExt}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; -use snafu::{location, Location}; +use snafu::location; use tokio::task; use crate::datatypes::lance_supports_nulls; diff --git a/rust/lance/src/dataset/index.rs b/rust/lance/src/dataset/index.rs index 56c211d4c7d..ec0ce4f25ab 100644 --- a/rust/lance/src/dataset/index.rs +++ b/rust/lance/src/dataset/index.rs @@ -78,7 +78,7 @@ impl LanceIndexStoreExt for LanceIndexStore { fn from_dataset(dataset: &Dataset, uuid: &str) -> Self { let index_dir = dataset.indices_dir().child(uuid); Self::new( - dataset.object_store.as_ref().clone(), + dataset.object_store.clone(), index_dir, dataset.session.file_metadata_cache.clone(), ) diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index 034168c26c0..7cd4d02edd0 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -597,7 +597,7 @@ async fn reserve_fragment_ids( None, ); - let (manifest, _) = commit_transaction( + let (manifest, _, _) = commit_transaction( dataset, dataset.object_store(), dataset.commit_handler.as_ref(), @@ -902,19 +902,9 @@ pub async fn commit_compaction( None, ); - let (manifest, manifest_path) = commit_transaction( - dataset, - dataset.object_store(), - dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - dataset.manifest_naming_scheme, - ) - .await?; - - dataset.manifest = Arc::new(manifest); - dataset.manifest_file = manifest_path; + dataset + .apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(metrics) } @@ -982,11 +972,9 @@ mod tests { assert!(!single_bin.is_noop()); let big_bin = CandidateBin { - fragments: std::iter::repeat(fragment).take(8).collect(), + fragments: std::iter::repeat_n(fragment, 8).collect(), pos_range: 0..8, - candidacy: std::iter::repeat(CompactionCandidacy::CompactItself) - .take(8) - .collect(), + candidacy: std::iter::repeat_n(CompactionCandidacy::CompactItself, 8).collect(), row_counts: vec![100, 400, 200, 200, 400, 300, 300, 100], indices: vec![], // Will group into: [[100, 400], [200, 200, 400], [300, 300, 100]] @@ -1672,8 +1660,9 @@ mod tests { async fn vector_query(dataset: &Dataset) -> RecordBatch { let mut scanner = dataset.scan(); + let query = Float32Array::from(vec![0.0f32; 128]); scanner - .nearest("vec", &vec![0.0; 128].into(), 10) + .nearest("vec", &query, 10) .unwrap() .project(&["i"]) .unwrap(); diff --git a/rust/lance/src/dataset/optimize/remapping.rs b/rust/lance/src/dataset/optimize/remapping.rs index 026cbcc3560..4b09bf7b3f2 100644 --- a/rust/lance/src/dataset/optimize/remapping.rs +++ b/rust/lance/src/dataset/optimize/remapping.rs @@ -95,7 +95,7 @@ impl<'a, I: Iterator> MissingIds<'a, I> { } } -impl<'a, I: Iterator> Iterator for MissingIds<'a, I> { +impl> Iterator for MissingIds<'_, I> { type Item = u64; fn next(&mut self) -> Option { diff --git a/rust/lance/src/dataset/refs.rs b/rust/lance/src/dataset/refs.rs index a3fb07ed646..664cdcd15ae 100644 --- a/rust/lance/src/dataset/refs.rs +++ b/rust/lance/src/dataset/refs.rs @@ -117,18 +117,24 @@ impl Tags { let manifest_file = self .commit_handler - .resolve_version(&self.base, version, &self.object_store.inner) + .resolve_version_location(&self.base, version, &self.object_store.inner) .await?; - if !self.object_store().exists(&manifest_file).await? { + if !self.object_store().exists(&manifest_file.path).await? { return Err(Error::VersionNotFound { message: format!("version {} does not exist", version), }); } + let manifest_size = if let Some(size) = manifest_file.size { + size as usize + } else { + self.object_store().size(&manifest_file.path).await? + }; + let tag_contents = TagContents { version, - manifest_size: self.object_store().size(&manifest_file).await?, + manifest_size, }; self.object_store() @@ -137,6 +143,7 @@ impl Tags { serde_json::to_string_pretty(&tag_contents)?.as_bytes(), ) .await + .map(|_| ()) } pub async fn delete(&mut self, tag: &str) -> Result<()> { @@ -166,18 +173,24 @@ impl Tags { let manifest_file = self .commit_handler - .resolve_version(&self.base, version, &self.object_store.inner) + .resolve_version_location(&self.base, version, &self.object_store.inner) .await?; - if !self.object_store().exists(&manifest_file).await? { + if !self.object_store().exists(&manifest_file.path).await? { return Err(Error::VersionNotFound { message: format!("version {} does not exist", version), }); } + let manifest_size = if let Some(size) = manifest_file.size { + size as usize + } else { + self.object_store().size(&manifest_file.path).await? + }; + let tag_contents = TagContents { version, - manifest_size: self.object_store().size(&manifest_file).await?, + manifest_size, }; self.object_store() @@ -186,6 +199,7 @@ impl Tags { serde_json::to_string_pretty(&tag_contents)?.as_bytes(), ) .await + .map(|_| ()) } pub(crate) fn object_store(&self) -> &ObjectStore { diff --git a/rust/lance/src/dataset/rowids.rs b/rust/lance/src/dataset/rowids.rs index e8ab5f64e04..c3843b789ff 100644 --- a/rust/lance/src/dataset/rowids.rs +++ b/rust/lance/src/dataset/rowids.rs @@ -4,7 +4,7 @@ use super::Dataset; use crate::{Error, Result}; use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; -use snafu::{location, Location}; +use snafu::location; use std::sync::Arc; use lance_table::{ diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index cb5fe094537..bfebe36334a 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -7,15 +7,17 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use arrow::array::AsArray; use arrow_array::{Array, Float32Array, Int64Array, RecordBatch}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef, SortOptions}; use arrow_select::concat::concat_batches; use async_recursion::async_recursion; +use datafusion::common::SchemaExt; +use datafusion::functions_aggregate; use datafusion::functions_aggregate::count::count_udaf; -use datafusion::logical_expr::{lit, Expr}; +use datafusion::logical_expr::Expr; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions; use datafusion::physical_plan::projection::ProjectionExec as DFProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -26,23 +28,29 @@ use datafusion::physical_plan::{ filter::FilterExec, limit::GlobalLimitExec, repartition::RepartitionExec, - udaf::create_aggregate_expr, union::UnionExec, ExecutionPlan, SendableRecordBatchStream, }; use datafusion::scalar::ScalarValue; -use datafusion_physical_expr::{Partitioning, PhysicalExpr}; +use datafusion_expr::Operator; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::{LexOrdering, Partitioning, PhysicalExpr}; +use futures::future::BoxFuture; use futures::stream::{Stream, StreamExt}; -use futures::TryStreamExt; +use futures::{FutureExt, TryStreamExt}; use lance_arrow::floats::{coerce_float_vector, FloatType}; use lance_arrow::DataTypeExt; -use lance_core::datatypes::Field; +use lance_core::datatypes::{Field, OnMissing, Projection}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD}; -use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; +use lance_datafusion::exec::{analyze_plan, execute_plan, LanceExecutionOptions}; use lance_datafusion::projection::ProjectionPlan; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::scalar::expression::PlannerIndexExt; -use lance_index::scalar::inverted::{FTS_SCHEMA, SCORE_COL}; +use lance_index::scalar::inverted::query::{ + fill_fts_query_column, FtsQuery, FtsSearchParams, MatchQuery, +}; +use lance_index::scalar::inverted::SCORE_COL; use lance_index::scalar::{FullTextSearchQuery, ScalarIndexType}; use lance_index::vector::{Query, DIST_COL}; use lance_index::{scalar::expression::ScalarIndexExpr, DatasetIndexExt}; @@ -55,8 +63,10 @@ use tracing::{info_span, instrument, Span}; use super::Dataset; use crate::datatypes::Schema; use crate::index::scalar::detect_scalar_index_type; +use crate::index::vector::utils::{get_vector_dim, get_vector_type}; use crate::index::DatasetIndexInternalExt; -use crate::io::exec::fts::{FlatFtsExec, FtsExec}; +use crate::io::exec::fts::{BoostQueryExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec}; +use crate::io::exec::knn::MultivectorScoringExec; use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec}; use crate::io::exec::{get_physical_optimizer, LanceScanConfig}; use crate::io::exec::{ @@ -64,8 +74,9 @@ use crate::io::exec::{ LancePushdownScanExec, LanceScanExec, Planner, PreFilterSource, ScanConfig, TakeExec, }; use crate::{Error, Result}; -use snafu::{location, Location}; +use snafu::location; +pub use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts}; #[cfg(feature = "substrait")] use lance_datafusion::substrait::parse_substrait; @@ -82,6 +93,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4; lazy_static::lazy_static! { pub static ref DEFAULT_FRAGMENT_READAHEAD: Option = std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD") .map(|val| Some(val.parse().unwrap())).unwrap_or(None); + + pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH") + .map(|val| val.parse().unwrap()).unwrap_or(10); } // We want to support ~256 concurrent reads to maximize throughput on cloud storage systems @@ -183,6 +197,7 @@ impl MaterializationStyle { } /// Filter for filtering rows +#[derive(Debug)] pub enum LanceFilter { /// The filter is an SQL string Sql(String), @@ -321,6 +336,12 @@ pub struct Scanner { /// This is essentially a weak consistency search. Users can run index or optimize index /// to make the index catch up with the latest data. fast_search: bool, + + /// If true, the scanner will emit deleted rows + include_deleted_rows: bool, + + /// If set, this callback will be called after the scan with summary statistics + scan_stats_callback: Option, } fn escape_column_name(name: &str) -> String { @@ -359,6 +380,8 @@ impl Scanner { fragments: None, fast_search: false, use_scalar_index: true, + include_deleted_rows: false, + scan_stats_callback: None, } } @@ -453,6 +476,12 @@ impl Scanner { self } + /// Set the callback to be called after the scan with summary statistics + pub fn scan_stats_callback(&mut self, callback: ExecutionStatsCallback) -> &mut Self { + self.scan_stats_callback = Some(callback); + self + } + /// Set the materialization style for the scan /// /// This controls when columns are fetched from storage. The default should work @@ -503,11 +532,12 @@ impl Scanner { /// .into_stream(); /// ``` pub fn full_text_search(&mut self, query: FullTextSearchQuery) -> Result<&mut Self> { - if !query.columns.is_empty() { - for column in &query.columns { - if self.dataset.schema().field(column).is_none() { + let fields = query.columns(); + if !fields.is_empty() { + for field in fields.iter() { + if self.dataset.schema().field(field).is_none() { return Err(Error::invalid_input( - format!("Column {} not found", column), + format!("Column {} not found", field), location!(), )); } @@ -527,7 +557,7 @@ impl Scanner { Ok(self) } - pub(crate) fn filter_expr(&mut self, filter: Expr) -> &mut Self { + pub fn filter_expr(&mut self, filter: Expr) -> &mut Self { self.filter = Some(LanceFilter::Datafusion(filter)); self } @@ -538,6 +568,21 @@ impl Scanner { self } + /// Include deleted rows + /// + /// These are rows that have been deleted from the dataset but are still present in the + /// underlying storage. These rows will have the `_rowid` column set to NULL. The other columns + /// (include _rowaddr) will be set to their deleted values. + /// + /// This can be useful for generating aligned fragments or debugging + /// + /// Note: when entire fragments are deleted, the scanner will not emit any rows for that fragment + /// since the fragment is no longer present in the dataset. + pub fn include_deleted_rows(&mut self) -> &mut Self { + self.include_deleted_rows = true; + self + } + /// Set the I/O buffer size /// /// This is the amount of RAM that will be reserved for holding I/O received from @@ -630,7 +675,9 @@ impl Scanner { } /// Find k-nearest neighbor within the vector column. - pub fn nearest(&mut self, column: &str, q: &Float32Array, k: usize) -> Result<&mut Self> { + /// the query can be a Float16Array, Float32Array, Float64Array, UInt8Array, + /// or a ListArray/FixedSizeListArray of the above types. + pub fn nearest(&mut self, column: &str, q: &dyn Array, k: usize) -> Result<&mut Self> { if !self.prefilter { // We can allow fragment scan if the input to nearest is a prefilter. // The fragment scan will be performed by the prefilter. @@ -650,35 +697,82 @@ impl Scanner { )); } // make sure the field exists - let field = self - .dataset - .schema() - .field(column) - .ok_or(Error::invalid_input( - format!("Column {} not found", column), - location!(), - ))?; - let key = match field.data_type() { - DataType::FixedSizeList(dt, _) => { - if dt.data_type().is_floating() { - coerce_float_vector(q, FloatType::try_from(dt.data_type())?)? + let (vector_type, element_type) = get_vector_type(self.dataset.schema(), column)?; + let dim = get_vector_dim(self.dataset.schema(), column)?; + + let q = match q.data_type() { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + if !matches!(vector_type, DataType::List(_)) { + return Err(Error::invalid_input( + format!( + "Query is multivector but column {}({})is not multivector", + column, vector_type, + ), + location!(), + )); + } + + if let Some(list_array) = q.as_list_opt::() { + for i in 0..list_array.len() { + let vec = list_array.value(i); + if vec.len() != dim { + return Err(Error::invalid_input( + format!( + "query dim({}) doesn't match the column {} vector dim({})", + vec.len(), + column, + dim, + ), + location!(), + )); + } + } + list_array.values().clone() } else { + let fsl = q.as_fixed_size_list(); + if fsl.value_length() as usize != dim { + return Err(Error::invalid_input( + format!( + "query dim({}) doesn't match the column {} vector dim({})", + fsl.value_length(), + column, + dim, + ), + location!(), + )); + } + fsl.values().clone() + } + } + _ => { + if q.len() != dim { return Err(Error::invalid_input( format!( - "Column {} is not a vector column (type: {})", + "query dim({}) doesn't match the column {} vector dim({})", + q.len(), column, - field.data_type() + dim, ), location!(), )); } + q.slice(0, q.len()) } + }; + + let key = match element_type { + dt if dt == *q.data_type() => q, + dt if dt.is_floating() => coerce_float_vector( + q.as_any().downcast_ref::().unwrap(), + FloatType::try_from(&dt)?, + )?, _ => { return Err(Error::invalid_input( format!( - "Column {} is not a vector column (type: {})", + "Column {} has element type {} and the query vector is {}", column, - field.data_type() + element_type, + q.data_type(), ), location!(), )); @@ -687,8 +781,10 @@ impl Scanner { self.nearest = Some(Query { column: column.to_string(), - key: key.into(), + key, k, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, @@ -698,6 +794,19 @@ impl Scanner { Ok(self) } + /// Set the distance thresholds for the nearest neighbor search. + pub fn distance_range( + &mut self, + lower_bound: Option, + upper_bound: Option, + ) -> &mut Self { + if let Some(q) = self.nearest.as_mut() { + q.lower_bound = lower_bound; + q.upper_bound = upper_bound; + } + self + } + pub fn nprobs(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.nprobes = n; @@ -731,9 +840,9 @@ impl Scanner { /// and using the original vector values to re-rank the distances. /// /// * `factor` - the factor of extra elements to read. For example, if factor is 2, then - /// the search will read 2x more elements than the requested k before performing - /// the re-ranking. Note: even if the factor is 1, the results will still be - /// re-ranked without fetching additional elements. + /// the search will read 2x more elements than the requested k before performing + /// the re-ranking. Note: even if the factor is 1, the results will still be + /// re-ranked without fetching additional elements. pub fn refine(&mut self, factor: u32) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.refine_factor = Some(factor) @@ -929,19 +1038,34 @@ impl Scanner { /// Create a stream from the Scanner. #[instrument(skip_all)] - pub async fn try_into_stream(&self) -> Result { - let plan = self.create_plan().await?; - Ok(DatasetRecordBatchStream::new(execute_plan( - plan, - LanceExecutionOptions::default(), - )?)) + pub fn try_into_stream(&self) -> BoxFuture> { + // Future intentionally boxed here to avoid large futures on the stack + async move { + let plan = self.create_plan().await?; + + Ok(DatasetRecordBatchStream::new(execute_plan( + plan, + LanceExecutionOptions { + batch_size: self.batch_size, + execution_stats_callback: self.scan_stats_callback.clone(), + ..Default::default() + }, + )?)) + } + .boxed() } pub(crate) async fn try_into_dfstream( &self, - options: LanceExecutionOptions, + mut options: LanceExecutionOptions, ) -> Result { let plan = self.create_plan().await?; + + // Use the scan stats callback if the user didn't set an execution stats callback + if options.execution_stats_callback.is_none() { + options.execution_stats_callback = self.scan_stats_callback.clone(); + } + execute_plan(plan, options) } @@ -952,64 +1076,76 @@ impl Scanner { Ok(concat_batches(&schema, &batches)?) } - /// Scan and return the number of matching rows - #[instrument(skip_all)] - pub async fn count_rows(&self) -> Result { - let plan = self.create_plan().await?; - // Datafusion interprets COUNT(*) as COUNT(1) - let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); - let count_expr = create_aggregate_expr( - &count_udaf(), - &[one], - &[lit(1)], - &[], - &[], - &plan.schema(), - None, - false, - false, - )?; - let plan_schema = plan.schema(); - let count_plan = Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(Vec::new()), - vec![count_expr], - vec![None], - plan, - plan_schema, - )?); - let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; - - // A count plan will always return a single batch with a single row. - if let Some(first_batch) = stream.next().await { - let batch = first_batch?; - let array = batch - .column(0) - .as_any() - .downcast_ref::() - .ok_or(Error::io( - "Count plan did not return a UInt64Array".to_string(), + fn create_count_plan(&self) -> BoxFuture>> { + // Future intentionally boxed here to avoid large futures on the stack + async move { + if !self.projection_plan.physical_schema.fields.is_empty() { + return Err(Error::invalid_input( + "count_rows should not be called on a plan selecting columns".to_string(), location!(), - ))?; - Ok(array.value(0) as u64) - } else { - Ok(0) + )); + } + + if self.limit.is_some() || self.offset.is_some() { + log::warn!( + "count_rows called with limit or offset which could have surprising results" + ); + } + + let plan = self.create_plan().await?; + // Datafusion interprets COUNT(*) as COUNT(1) + let one = Arc::new(Literal::new(ScalarValue::UInt8(Some(1)))); + + let input_phy_exprs: &[Arc] = &[one]; + let schema = plan.schema(); + + let mut builder = AggregateExprBuilder::new(count_udaf(), input_phy_exprs.to_vec()); + builder = builder.schema(schema); + builder = builder.alias("count_rows".to_string()); + + let count_expr = builder.build()?; + + let plan_schema = plan.schema(); + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(Vec::new()), + vec![Arc::new(count_expr)], + vec![None], + plan, + plan_schema, + )?) as Arc) } + .boxed() } - /// Given a base schema and a list of desired fields figure out which fields, if any, still need loaded - fn calc_new_fields>( - &self, - base_schema: &Schema, - columns: &[S], - ) -> Result> { - let new_schema = self.dataset.schema().project(columns)?; - let new_schema = new_schema.exclude(base_schema)?; - if new_schema.fields.is_empty() { - Ok(None) - } else { - Ok(Some(new_schema)) + /// Scan and return the number of matching rows + /// + /// Note: calling [`Dataset::count_rows`] can be more efficient than calling this method + /// especially if there is no filter. + #[instrument(skip_all)] + pub fn count_rows(&self) -> BoxFuture> { + // Future intentionally boxed here to avoid large futures on the stack + async move { + let count_plan = self.create_count_plan().await?; + let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; + + // A count plan will always return a single batch with a single row. + if let Some(first_batch) = stream.next().await { + let batch = first_batch?; + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or(Error::io( + "Count plan did not return a UInt64Array".to_string(), + location!(), + ))?; + Ok(array.value(0) as u64) + } else { + Ok(0) + } } + .boxed() } // A "narrow" field is a field that is so small that we are better off reading the @@ -1041,47 +1177,46 @@ impl Scanner { let byte_width = field.data_type().byte_width_opt(); let is_cloud = self.dataset.object_store().is_cloud(); if is_cloud { - byte_width.map_or(false, |bw| bw < 1000) + byte_width.is_some_and(|bw| bw < 1000) } else { - byte_width.map_or(false, |bw| bw < 10) + byte_width.is_some_and(|bw| bw < 10) } } } } - fn calc_eager_columns(&self, filter_plan: &FilterPlan) -> Result> { - let columns = filter_plan.refine_columns(); - // If the column didn't exist in the scan output schema then we wouldn't make - // it to this point. However, there may be columns (like _rowid, _distance, etc.) - // which do not exist in the dataset schema but are added by the scan. We can ignore - // those as eager columns. - let filter_schema = self.dataset.schema().project_or_drop(&columns)?; + // If we are going to filter on `filter_plan`, then which columns are so small it is + // cheaper to read the entire column and filter in memory. + // + // Note: only add columns that we actually need to read + fn calc_eager_projection( + &self, + filter_plan: &FilterPlan, + desired_schema: &Schema, + ) -> Result { + let filter_columns = filter_plan.refine_columns(); + + let filter_schema = self + .dataset + .empty_projection() + .union_columns(filter_columns, OnMissing::Error)? + .into_schema(); if filter_schema.fields.iter().any(|f| !f.is_default_storage()) { return Err(Error::NotSupported { source: "non-default storage columns cannot be used as filters".into(), location: location!(), }); } - let physical_schema = self.projection_plan.physical_schema.clone(); - let remaining_schema = physical_schema.exclude(&filter_schema)?; - let narrow_fields = remaining_schema - .fields - .iter() - .filter(|f| self.is_early_field(f)) - .cloned() - .collect::>(); - - if narrow_fields.is_empty() { - Ok(Arc::new(filter_schema)) - } else { - let mut new_fields = filter_schema.fields; - new_fields.extend(narrow_fields); - Ok(Arc::new(Schema { - fields: new_fields, - metadata: HashMap::new(), - })) - } + Ok(self + .dataset + .empty_projection() + // Start with the desired schema + .union_schema(desired_schema) + // Subtract columns that are expensive + .subtract_predicate(|f| !self.is_early_field(f)) + // Add back columns that we need for filtering + .union_schema(&filter_schema)) } /// Create [`ExecutionPlan`] for Scan. @@ -1142,6 +1277,14 @@ impl Scanner { location: location!(), }); } + + if self.include_deleted_rows && !self.with_row_id { + return Err(Error::InvalidInput { + source: "include_deleted_rows is set but with_row_id is false".into(), + location: location!(), + }); + } + if let Some(first_blob_col) = self .projection_plan .physical_schema @@ -1206,12 +1349,18 @@ impl Scanner { } else { match (self.limit, self.offset) { (None, None) => None, - (Some(limit), None) => Some(0..limit as u64), + (Some(limit), None) => { + let num_rows = self.dataset.count_all_rows().await? as i64; + Some(0..limit.min(num_rows) as u64) + } (None, Some(offset)) => { - let num_rows = self.dataset.count_all_rows().await?; - Some(offset as u64..num_rows as u64) + let num_rows = self.dataset.count_all_rows().await? as i64; + Some(offset.min(num_rows) as u64..num_rows as u64) + } + (Some(limit), Some(offset)) => { + let num_rows = self.dataset.count_all_rows().await? as i64; + Some(offset.min(num_rows) as u64..(offset + limit).min(num_rows) as u64) } - (Some(limit), Some(offset)) => Some(offset as u64..(offset + limit) as u64), } }; let mut use_limit_node = true; @@ -1219,6 +1368,13 @@ impl Scanner { // Stage 1: source (either an (K|A)NN search, full text search or or a (full|indexed) scan) let mut plan: Arc = match (&self.nearest, &self.full_text_query) { (Some(_), None) => { + if self.include_deleted_rows { + return Err(Error::InvalidInput { + source: "Cannot include deleted rows in a nearest neighbor search".into(), + location: location!(), + }); + } + // The source is an nearest neighbor search if self.prefilter { // If we are prefiltering then the knn node will take care of the filter @@ -1233,6 +1389,13 @@ impl Scanner { } } (None, Some(query)) => { + if self.include_deleted_rows { + return Err(Error::InvalidInput { + source: "Cannot include deleted rows in an FTS search".into(), + location: location!(), + }); + } + // The source is an FTS search if self.prefilter { // If we are prefiltering then the fts node will take care of the filter @@ -1258,31 +1421,50 @@ impl Scanner { } else { self.use_stats }; - match (&filter_plan.index_query, &mut filter_plan.refine_expr) { - (Some(index_query), None) => { - self.scalar_indexed_scan( - self.projection_plan.physical_schema.as_ref(), - index_query, - ) - .await? + + if filter_plan.index_query.is_some() && self.include_deleted_rows { + return Err(Error::InvalidInput { + source: "Cannot include deleted rows in a scalar indexed scan".into(), + location: location!(), + }); + } + + match ( + filter_plan.index_query.is_some(), + filter_plan.refine_expr.is_some(), + ) { + (true, false) => { + let projection = self + .dataset + .empty_projection() + .union_schema(&self.projection_plan.physical_schema); + self.scalar_indexed_scan(projection, &filter_plan).await? } // TODO: support combined pushdown and scalar index scan - (Some(index_query), Some(_)) => { + (true, true) => { // If there is a filter then just load the eager columns and // "take" the other columns later. - let eager_schema = self.calc_eager_columns(&filter_plan)?; - self.scalar_indexed_scan(&eager_schema, index_query).await? + let eager_projection = self.calc_eager_projection( + &filter_plan, + self.projection_plan.physical_schema.as_ref(), + )?; + self.scalar_indexed_scan(eager_projection, &filter_plan) + .await? } - (None, Some(_)) if use_stats && self.batch_size.is_none() => { + (false, true) if use_stats && self.batch_size.is_none() => { self.pushdown_scan(false, filter_plan.refine_expr.take().unwrap())? } - (None, _) => { + (false, _) => { // The source is a full scan of the table let with_row_id = filter_plan.has_refine() || self.with_row_id; let eager_schema = if filter_plan.has_refine() { // If there is a filter then only load the filter columns in the // initial scan. We will `take` the remaining columns later - self.calc_eager_columns(&filter_plan)? + self.calc_eager_projection( + &filter_plan, + self.projection_plan.physical_schema.as_ref(), + )? + .into_schema_ref() } else { // If there is no filter we eagerly load everything self.projection_plan.physical_schema.clone() @@ -1296,7 +1478,7 @@ impl Scanner { self.scan( with_row_id, self.with_row_address, - false, + self.include_deleted_rows, scan_range, eager_schema, ) @@ -1312,34 +1494,29 @@ impl Scanner { }; // Stage 1.5 load columns needed for stages 2 & 3 - let mut additional_schema = None; + // Calculate the schema needed for the filter and ordering. + let mut pre_filter_projection = self.dataset.empty_projection(); + // We may need to take filter columns if we are going to refine - // an indexed scan. Otherwise, the filter was applied during the scan - // and this should be false + // an indexed scan. if filter_plan.has_refine() { - let eager_schema = self.calc_eager_columns(&filter_plan)?; - let base_schema = Schema::try_from(plan.schema().as_ref())?; - let still_to_load = eager_schema.exclude(base_schema)?; - if still_to_load.fields.is_empty() { - additional_schema = None; - } else { - additional_schema = Some(still_to_load); - } + // It's ok for some filter columns to be missing (e.g. _rowid) + pre_filter_projection = pre_filter_projection + .union_columns(filter_plan.refine_columns(), OnMissing::Ignore)?; } + + // TODO: Does it always make sense to take the ordering columns here? If there is a filter then + // maybe we wait until after the filter to take the ordering columns? Maybe it would be better to + // grab the ordering column in the initial scan (if it is eager) and if it isn't then we should + // take it after the filtering phase, if any (we already have a take there). if let Some(ordering) = &self.ordering { - additional_schema = self.calc_new_fields( - &additional_schema - .map(Ok::) - .unwrap_or_else(|| Schema::try_from(plan.schema().as_ref()))?, - &ordering - .iter() - .map(|col| &col.column_name) - .collect::>(), + pre_filter_projection = pre_filter_projection.union_columns( + ordering.iter().map(|col| &col.column_name), + OnMissing::Error, )?; } - if let Some(additional_schema) = additional_schema { - plan = self.take(plan, &additional_schema, self.batch_readahead)?; - } + + plan = self.take(plan, pre_filter_projection)?; // Stage 2: filter if let Some(refine_expr) = filter_plan.refine_expr { @@ -1353,19 +1530,13 @@ impl Scanner { // Stage 3: sort if let Some(ordering) = &self.ordering { - let order_by_schema = Arc::new( - self.dataset.schema().project( - &ordering - .iter() - .map(|col| &col.column_name) - .collect::>(), - )?, - ); - let remaining_schema = order_by_schema.exclude(plan.schema().as_ref())?; - if !remaining_schema.fields.is_empty() { - // We haven't loaded the sort column yet so take it now - plan = self.take(plan, &remaining_schema, self.batch_readahead)?; - } + let ordering_columns = ordering.iter().map(|col| &col.column_name); + let projection_with_ordering = self + .dataset + .empty_projection() + .union_columns(ordering_columns, OnMissing::Error)?; + // We haven't loaded the sort column yet so take it now + plan = self.take(plan, projection_with_ordering)?; let col_exprs = ordering .iter() .map(|col| { @@ -1378,7 +1549,7 @@ impl Scanner { }) }) .collect::>>()?; - plan = Arc::new(SortExec::new(col_exprs, plan)); + plan = Arc::new(SortExec::new(LexOrdering::new(col_exprs), plan)); } // Stage 4: limit / offset @@ -1389,12 +1560,14 @@ impl Scanner { // Stage 5: take remaining columns required for projection let physical_schema = self.scan_output_schema(&self.projection_plan.physical_schema, false)?; - let remaining_schema = physical_schema.exclude(plan.schema().as_ref())?; - if !remaining_schema.fields.is_empty() { - plan = self.take(plan, &remaining_schema, self.batch_readahead)?; - } + let physical_projection = self + .dataset + .empty_projection() + .union_schema(&physical_schema); + plan = self.take(plan, physical_projection)?; // Stage 6: physical projection -- reorder physical columns needed before final projection let output_arrow_schema = physical_schema.as_ref().into(); + if plan.schema().as_ref() != &output_arrow_schema { plan = Arc::new(project(plan, &physical_schema.as_ref().into())?); } @@ -1417,14 +1590,27 @@ impl Scanner { filter_plan: &FilterPlan, query: &FullTextSearchQuery, ) -> Result> { - let columns = if query.columns.is_empty() { - let string_columns = self.dataset.schema().fields.iter().filter_map(|f| { - if f.data_type() == DataType::Utf8 || f.data_type() == DataType::LargeUtf8 { - Some(&f.name) - } else { - None - } - }); + let columns = query.columns(); + let params = query.params().with_limit(self.limit.map(|l| l as usize)); + let query = if columns.is_empty() { + // the field is not specified, + // try to search over all indexed fields + let string_columns = + self.dataset + .schema() + .fields + .iter() + .filter_map(|f| match f.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => Some(&f.name), + DataType::List(field) | DataType::LargeList(field) => { + if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) { + Some(&f.name) + } else { + None + } + } + _ => None, + }); let mut indexed_columns = Vec::new(); for column in string_columns { @@ -1443,104 +1629,203 @@ impl Scanner { } } - indexed_columns + fill_fts_query_column(&query.query, &indexed_columns, false)? } else { - query.columns.clone() + query.query.clone() }; - if columns.is_empty() { - return Err(Error::invalid_input( - "Cannot perform full text search unless an INVERTED index has been created on at least one column".to_string(), - location!(), - )); - } - - // rewrite the query to be with the columns and limit - let query = query - .clone() - .columns(Some(columns.clone())) - .limit(self.limit); + let prefilter_source = self.prefilter_source(filter_plan).await?; + let fts_exec = self + .plan_fts(&query, ¶ms, filter_plan, &prefilter_source) + .await?; + Ok(fts_exec) + } - // load indices - let mut column_inputs = HashMap::with_capacity(columns.len()); - for column in columns { - let index = self - .dataset - .load_scalar_index_for_column(&column) - .await? - .ok_or(Error::invalid_input( - format!("Column {} has no inverted index", column), - location!(), - ))?; - let index_uuids: Vec<_> = self - .dataset - .load_indices_by_name(&index.name) - .await? - .into_iter() - .collect(); + async fn plan_fts( + &self, + query: &FtsQuery, + params: &FtsSearchParams, + filter_plan: &FilterPlan, + prefilter_source: &PreFilterSource, + ) -> Result> { + let plan: Arc = match query { + FtsQuery::Match(query) => { + self.plan_match_query(query, params, filter_plan, prefilter_source) + .await? + } + FtsQuery::Phrase(query) => Arc::new(PhraseQueryExec::new( + self.dataset.clone(), + query.clone(), + params.clone(), + prefilter_source.clone(), + )), + + FtsQuery::Boost(query) => { + // for boost query, we need to erase the limit so that we can find + // the documents that are not in the top-k results of the positive query, + // but in the final top-k results. + let unlimited_params = params.clone().with_limit(None); + let positive_exec = Box::pin(self.plan_fts( + &query.positive, + &unlimited_params, + filter_plan, + prefilter_source, + )); + let negative_exec = Box::pin(self.plan_fts( + &query.negative, + &unlimited_params, + filter_plan, + prefilter_source, + )); + let (positive_exec, negative_exec) = + futures::future::try_join(positive_exec, negative_exec).await?; + Arc::new(BoostQueryExec::new( + query.clone(), + params.clone(), + positive_exec, + negative_exec, + )) + } - let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?; - let unindexed_scan_node = if unindexed_fragments.is_empty() { - Arc::new(EmptyExec::new(FTS_SCHEMA.clone())) - } else { - let mut columns = vec![column.clone()]; - if let Some(expr) = filter_plan.full_expr.as_ref() { - let filter_columns = Planner::column_names_in_expr(expr); - columns.extend(filter_columns); + FtsQuery::MultiMatch(query) => { + let mut children = Vec::with_capacity(query.match_queries.len()); + for match_query in &query.match_queries { + let child = + self.plan_match_query(match_query, params, filter_plan, prefilter_source); + children.push(child); } - let flat_fts_scan_schema = - Arc::new(self.dataset.schema().project(&columns).unwrap()); - let mut scan_node = self.scan_fragments( - true, - false, - true, - flat_fts_scan_schema, - Arc::new(unindexed_fragments), - None, - false, - ); + let children = futures::future::try_join_all(children).await?; + + let schema = children[0].schema(); + let group_expr = vec![( + expressions::col(ROW_ID, schema.as_ref())?, + ROW_ID.to_string(), + )]; + + let fts_node = Arc::new(UnionExec::new(children)); + let fts_node = Arc::new(RepartitionExec::try_new( + fts_node, + Partitioning::RoundRobinBatch(1), + )?); + // dedup by row_id and return the max score as final score + let fts_node = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_expr), + vec![Arc::new( + AggregateExprBuilder::new( + functions_aggregate::min_max::max_udaf(), + vec![expressions::col(SCORE_COL, &schema)?], + ) + .schema(schema.clone()) + .alias(SCORE_COL) + .build()?, + )], + vec![None], + fts_node, + schema, + )?); + let sort_expr = PhysicalSortExpr { + expr: expressions::col(SCORE_COL, fts_node.schema().as_ref())?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }; - if let Some(expr) = filter_plan.full_expr.as_ref() { - // If there is a prefilter we need to manually apply it to the new data - let planner = Planner::new(scan_node.schema()); - let physical_refine_expr = planner.create_physical_expr(expr)?; - scan_node = Arc::new(FilterExec::try_new(physical_refine_expr, scan_node)?); - } + Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), fts_node) + .with_fetch(self.limit.map(|l| l as usize)), + ) + } + }; - scan_node - }; + Ok(plan) + } - column_inputs.insert(column.clone(), (index_uuids, unindexed_scan_node)); - } + async fn plan_match_query( + &self, + query: &MatchQuery, + params: &FtsSearchParams, + filter_plan: &FilterPlan, + prefilter_source: &PreFilterSource, + ) -> Result> { + let column = query + .column + .as_ref() + .ok_or(Error::invalid_input( + "the column must be specified in the query".to_string(), + location!(), + ))? + .clone(); - let indices = column_inputs - .iter() - .map(|(col, (idx, _))| (col.clone(), idx.clone())) - .collect(); - let prefilter_source = self.prefilter_source(filter_plan).await?; - let fts_plan = Arc::new(FtsExec::new( + let index = self + .dataset + .load_scalar_index_for_column(query.column.as_ref().unwrap()) + .await? + .ok_or(Error::invalid_input( + format!( + "Column {} has no inverted index", + query.column.as_ref().unwrap() + ), + location!(), + ))?; + + let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?; + let mut match_plan: Arc = Arc::new(MatchQueryExec::new( self.dataset.clone(), - indices, query.clone(), - prefilter_source, - )) as Arc; - let flat_fts_plan = Arc::new(FlatFtsExec::new(self.dataset.clone(), column_inputs, query)); - let fts_node = Arc::new(UnionExec::new(vec![fts_plan, flat_fts_plan])); - let fts_node = Arc::new(RepartitionExec::try_new( - fts_node, - Partitioning::RoundRobinBatch(1), - )?); - let sort_expr = PhysicalSortExpr { - expr: expressions::col(SCORE_COL, fts_node.schema().as_ref())?, - options: SortOptions { - descending: true, - nulls_first: false, - }, - }; + params.clone(), + prefilter_source.clone(), + )); + if !unindexed_fragments.is_empty() { + let mut columns = vec![column.clone()]; + if let Some(expr) = filter_plan.full_expr.as_ref() { + let filter_columns = Planner::column_names_in_expr(expr); + columns.extend(filter_columns); + } + let flat_fts_scan_schema = Arc::new(self.dataset.schema().project(&columns).unwrap()); + let mut scan_node = self.scan_fragments( + true, + false, + true, + flat_fts_scan_schema, + Arc::new(unindexed_fragments), + None, + false, + ); - Ok(Arc::new( - SortExec::new(vec![sort_expr], fts_node).with_fetch(self.limit.map(|l| l as usize)), - )) + if let Some(expr) = filter_plan.full_expr.as_ref() { + // If there is a prefilter we need to manually apply it to the new data + let planner = Planner::new(scan_node.schema()); + let physical_refine_expr = planner.create_physical_expr(expr)?; + scan_node = Arc::new(FilterExec::try_new(physical_refine_expr, scan_node)?); + } + + let flat_match_plan = Arc::new(FlatMatchQueryExec::new( + self.dataset.clone(), + query.clone(), + params.clone(), + scan_node, + )); + + match_plan = Arc::new(UnionExec::new(vec![match_plan, flat_match_plan])); + match_plan = Arc::new(RepartitionExec::try_new( + match_plan, + Partitioning::RoundRobinBatch(1), + )?); + let sort_expr = PhysicalSortExpr { + expr: expressions::col(SCORE_COL, match_plan.schema().as_ref())?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }; + match_plan = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), match_plan) + .with_fetch(params.limit), + ); + } + Ok(match_plan) } // ANN/KNN search execution node with optional prefilter @@ -1553,26 +1838,7 @@ impl Scanner { }; // Sanity check - let schema = self.dataset.schema(); - if let Some(field) = schema.field(&q.column) { - match field.data_type() { - DataType::FixedSizeList(subfield, _) if subfield.data_type().is_floating() => {} - _ => { - return Err(Error::invalid_input( - format!( - "Vector search error: column {} is not a vector type: expected FixedSizeList, got {}", - q.column, field.data_type(), - ), - location!(), - )); - } - } - } else { - return Err(Error::invalid_input( - format!("Vector search error: column {} not found", q.column), - location!(), - )); - } + let (vector_type, _) = get_vector_type(self.dataset.schema(), &q.column)?; let column_id = self.dataset.schema().field_id(q.column.as_str())?; let use_index = self.nearest.as_ref().map(|q| q.use_index).unwrap_or(false); @@ -1593,16 +1859,27 @@ impl Scanner { // Find all deltas with the same index name. let deltas = self.dataset.load_indices_by_name(&index.name).await?; - let ann_node = self.ann(q, &deltas, filter_plan).await?; // _distance, _rowid + let ann_node = match vector_type { + DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?, + DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?, + _ => unreachable!(), + }; let mut knn_node = if q.refine_factor.is_some() { - let with_vector = self.dataset.schema().project(&[&q.column])?; - let knn_node_with_vector = - self.take(ann_node, &with_vector, self.batch_readahead)?; + let vector_projection = self + .dataset + .empty_projection() + .union_column(&q.column, OnMissing::Error) + .unwrap(); + let knn_node_with_vector = self.take(ann_node, vector_projection)?; // TODO: now we just open an index to get its metric type. let idx = self .dataset - .open_vector_index(q.column.as_str(), &index.uuid.to_string()) + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) .await?; let mut q = q.clone(); q.metric_type = idx.metric_type(); @@ -1622,12 +1899,21 @@ impl Scanner { if let Some(refine_expr) = filter_plan.refine_expr.as_ref() { columns.extend(Planner::column_names_in_expr(refine_expr)); } - let vector_scan_projection = Arc::new(self.dataset.schema().project(&columns).unwrap()); - let mut plan = if let Some(index_query) = &filter_plan.index_query { - self.scalar_indexed_scan(&vector_scan_projection, index_query) + let vector_scan_projection = self + .dataset + .empty_projection() + .union_columns(&columns, OnMissing::Error)?; + let mut plan = if filter_plan.index_query.is_some() { + self.scalar_indexed_scan(vector_scan_projection, filter_plan) .await? } else { - self.scan(true, false, true, None, vector_scan_projection) + self.scan( + true, + false, + true, + None, + vector_scan_projection.into_schema_ref(), + ) }; if let Some(refine_expr) = &filter_plan.refine_expr { let planner = Planner::new(plan.schema()); @@ -1650,11 +1936,28 @@ impl Scanner { // Check if we've created new versions since the index was built. let unindexed_fragments = self.dataset.unindexed_fragments(&index.name).await?; if !unindexed_fragments.is_empty() { + // need to set the metric type to be the same as the index + // to make sure the distance is comparable. + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + let mut q = q.clone(); + q.metric_type = idx.metric_type(); + // If the vector column is not present, we need to take the vector column, so // that the distance value is comparable with the flat search ones. if knn_node.schema().column_with_name(&q.column).is_none() { - let with_vector = self.dataset.schema().project(&[&q.column])?; - knn_node = self.take(knn_node, &with_vector, self.batch_readahead)?; + let vector_projection = self + .dataset + .empty_projection() + .union_column(&q.column, OnMissing::Error) + .unwrap(); + knn_node = self.take(knn_node, vector_projection)?; } let mut columns = vec![q.column.clone()]; @@ -1686,13 +1989,15 @@ impl Scanner { scan_node = Arc::new(FilterExec::try_new(physical_refine_expr, scan_node)?); } // first we do flat search on just the new data - let topk_appended = self.flat_knn(scan_node, q)?; + let topk_appended = self.flat_knn(scan_node, &q)?; // To do a union, we need to make the schemas match. Right now // knn_node: _distance, _rowid, vector // topk_appended: vector, , _rowid, _distance let topk_appended = project(topk_appended, knn_node.schema().as_ref())?; - assert_eq!(topk_appended.schema(), knn_node.schema()); + assert!(topk_appended + .schema() + .equivalent_names_and_types(&knn_node.schema())); // union let unioned = UnionExec::new(vec![Arc::new(topk_appended), knn_node]); // Enforce only 1 partition. @@ -1701,7 +2006,7 @@ impl Scanner { datafusion::physical_plan::Partitioning::RoundRobinBatch(1), )?; // then we do a flat search on KNN(new data) + ANN(indexed data) - return self.flat_knn(Arc::new(unioned), q); + return self.flat_knn(Arc::new(unioned), &q); } Ok(knn_node) @@ -1737,8 +2042,8 @@ impl Scanner { // target fragments with those ids async fn scalar_indexed_scan( &self, - projection: &Schema, - index_expr: &ScalarIndexExpr, + projection: Projection, + filter_plan: &FilterPlan, ) -> Result> { // One or more scalar indices cover this data and there is a filter which is // compatible with the indices. Use that filter to perform a take instead of @@ -1749,6 +2054,12 @@ impl Scanner { (**self.dataset.fragments()).clone() }; + // If this unwrap fails we have a bug because we shouldn't be using this function unless we've already + // checked that there is an index query + let index_expr = filter_plan.index_query.as_ref().unwrap(); + + let needs_recheck = index_expr.needs_recheck(); + // Figure out which fragments are covered by ALL of the indices we are using let covered_frags = self.fragments_covered_by_index_query(index_expr).await?; let mut relevant_frags = Vec::with_capacity(fragments.len()); @@ -1767,14 +2078,43 @@ impl Scanner { Arc::new(relevant_frags), )); - // If there is more than just _rowid in projection - let needs_take = match projection.fields.len() { - 0 => false, - 1 => projection.fields[0].name != ROW_ID, - _ => true, - }; + let refine_expr = filter_plan.refine_expr.as_ref(); + + // If all we want is the row ids then we can skip the take. However, if there is a refine + // or a recheck then we still need to do a take because we need filter columns. + let needs_take = + needs_recheck || projection.has_data_fields() || filter_plan.refine_expr.is_some(); if needs_take { - plan = self.take(plan, projection, self.batch_readahead)?; + let mut take_projection = projection.clone(); + if needs_recheck { + // If we need to recheck then we need to also take the columns used for the filter + let filter_expr = index_expr.to_expr(); + let filter_cols = Planner::column_names_in_expr(&filter_expr); + take_projection = take_projection.union_columns(filter_cols, OnMissing::Error)?; + } + if let Some(refine_expr) = refine_expr { + let refine_cols = Planner::column_names_in_expr(refine_expr); + take_projection = take_projection.union_columns(refine_cols, OnMissing::Error)?; + } + plan = self.take(plan, take_projection)?; + } + + let post_take_filter = match (needs_recheck, refine_expr) { + (false, None) => None, + (true, None) => { + // If we need to recheck then we need to apply the filter to the results + Some(index_expr.to_expr()) + } + (true, Some(_)) => Some(filter_plan.full_expr.as_ref().unwrap().clone()), + (false, Some(refine_expr)) => Some(refine_expr.clone()), + }; + + if let Some(post_take_filter) = post_take_filter { + let planner = Planner::new(plan.schema()); + let optimized_filter = planner.optimize_expr(post_take_filter)?; + let physical_refine_expr = planner.create_physical_expr(&optimized_filter)?; + + plan = Arc::new(FilterExec::try_new(physical_refine_expr, plan)?); } if self.with_row_address { @@ -1794,23 +2134,21 @@ impl Scanner { // If there were no extra columns then we still need the project // because Materialize -> Take puts the row id at the left and // Scan puts the row id at the right - let filter_expr = index_expr.to_expr(); - let filter_cols = Planner::column_names_in_expr(&filter_expr); - let full_schema = self - .calc_new_fields(projection, &filter_cols)? - .map(|filter_only_schema| projection.merge(&filter_only_schema)) - .transpose()?; - let schema = full_schema.as_ref().unwrap_or(projection); - - let planner = Planner::new(Arc::new(schema.into())); - let optimized_filter = planner.optimize_expr(filter_expr)?; + let filter = filter_plan.full_expr.as_ref().unwrap(); + let filter_cols = Planner::column_names_in_expr(filter); + let scan_projection = projection.union_columns(filter_cols, OnMissing::Error)?; + + let scan_schema = scan_projection.into_schema_ref(); + let scan_arrow_schema = Arc::new(scan_schema.as_ref().into()); + let planner = Planner::new(scan_arrow_schema); + let optimized_filter = planner.optimize_expr(filter.clone())?; let physical_refine_expr = planner.create_physical_expr(&optimized_filter)?; let new_data_scan = self.scan_fragments( true, self.with_row_address, false, - Arc::new(schema.clone()), + scan_schema, missing_frags.into(), // No pushdown of limit/offset when doing scalar indexed scan None, @@ -1943,16 +2281,59 @@ impl Scanner { q.metric_type, )?); + // filter out elements out of distance range + let lower_bound_expr = q + .lower_bound + .map(|v| { + let lower_bound = expressions::lit(v); + expressions::binary( + expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + Operator::GtEq, + lower_bound, + flat_dist.schema().as_ref(), + ) + }) + .transpose()?; + let upper_bound_expr = q + .upper_bound + .map(|v| { + let upper_bound = expressions::lit(v); + expressions::binary( + expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + Operator::Lt, + upper_bound, + flat_dist.schema().as_ref(), + ) + }) + .transpose()?; + let filter_expr = match (lower_bound_expr, upper_bound_expr) { + (Some(lower), Some(upper)) => Some(expressions::binary( + lower, + Operator::And, + upper, + flat_dist.schema().as_ref(), + )?), + (Some(lower), None) => Some(lower), + (None, Some(upper)) => Some(upper), + (None, None) => None, + }; + + let knn_plan: Arc = if let Some(filter_expr) = filter_expr { + Arc::new(FilterExec::try_new(filter_expr, flat_dist)?) + } else { + flat_dist + }; + // Use DataFusion's [SortExec] for Top-K search let sort = SortExec::new( - vec![PhysicalSortExpr { - expr: expressions::col(DIST_COL, flat_dist.schema().as_ref())?, + LexOrdering::new(vec![PhysicalSortExpr { + expr: expressions::col(DIST_COL, knn_plan.schema().as_ref())?, options: SortOptions { descending: false, nulls_first: false, }, - }], - flat_dist, + }]), + knn_plan, ) .with_fetch(Some(q.k)); @@ -1964,27 +2345,94 @@ impl Scanner { Ok(Arc::new(not_nulls)) } - /// Create an Execution plan to do indexed ANN search - async fn ann( + /// Create an Execution plan to do indexed ANN search + async fn ann( + &self, + q: &Query, + index: &[Index], + filter_plan: &FilterPlan, + ) -> Result> { + let prefilter_source = self.prefilter_source(filter_plan).await?; + let inner_fanout_search = new_knn_exec(self.dataset.clone(), index, q, prefilter_source)?; + let sort_expr = PhysicalSortExpr { + expr: expressions::col(DIST_COL, inner_fanout_search.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + Ok(Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), inner_fanout_search) + .with_fetch(Some(q.k * q.refine_factor.unwrap_or(1) as usize)), + )) + } + + // Create an Execution plan to do ANN over multivectors + async fn multivec_ann( &self, q: &Query, index: &[Index], filter_plan: &FilterPlan, ) -> Result> { + // we split the query procedure into two steps: + // 1. collect the candidates by vector searching on each query vector + // 2. scoring the candidates + + let over_fetch_factor = *DEFAULT_XTR_OVERFETCH; + let prefilter_source = self.prefilter_source(filter_plan).await?; + let dim = get_vector_dim(self.dataset.schema(), &q.column)?; + + let num_queries = q.key.len() / dim; + let new_queries = (0..num_queries) + .map(|i| q.key.slice(i * dim, dim)) + .map(|query_vec| { + let mut new_query = q.clone(); + new_query.key = query_vec; + // with XTR, we don't need to refine the result with original vectors, + // but here we really need to over-fetch the candidates to reach good enough recall. + // TODO: improve the recall with WARP, expose this parameter to the users. + new_query.refine_factor = Some(over_fetch_factor); + new_query + }); + let mut ann_nodes = Vec::with_capacity(new_queries.len()); + for query in new_queries { + // this produces `nprobes * k * over_fetch_factor * num_indices` candidates + let ann_node = new_knn_exec( + self.dataset.clone(), + index, + &query, + prefilter_source.clone(), + )?; + let sort_expr = PhysicalSortExpr { + expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + let ann_node = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node) + .with_fetch(Some(q.k * over_fetch_factor as usize)), + ); + ann_nodes.push(ann_node as Arc); + } + + let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?); - let inner_fanout_search = new_knn_exec(self.dataset.clone(), index, q, prefilter_source)?; let sort_expr = PhysicalSortExpr { - expr: expressions::col(DIST_COL, inner_fanout_search.schema().as_ref())?, + expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, options: SortOptions { descending: false, nulls_first: false, }, }; - Ok(Arc::new( - SortExec::new(vec![sort_expr], inner_fanout_search) + let ann_node = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node) .with_fetch(Some(q.k * q.refine_factor.unwrap_or(1) as usize)), - )) + ); + + Ok(ann_node) } /// Create prefilter source from filter plan @@ -1993,23 +2441,18 @@ impl Scanner { &filter_plan.index_query, &filter_plan.refine_expr, self.prefilter, + filter_plan.skip_recheck, ) { - (Some(index_query), Some(refine_expr), _) => { - // The filter is only partially satisfied by the index. We need - // to do an indexed scan and then refine the results to determine - // the row ids. - let columns_in_filter = Planner::column_names_in_expr(refine_expr); - let filter_schema = Arc::new(self.dataset.schema().project(&columns_in_filter)?); - let filter_input = self - .scalar_indexed_scan(&filter_schema, index_query) + (Some(_), Some(_), _, _) | (Some(_), None, true, false) => { + // Prefilter source is covered by an index but either that index needs a recheck or there + // is a refine expression that needs to be applied to the results so we need to do a full + // filtered scan + let filtered_row_ids = self + .scalar_indexed_scan(self.dataset.empty_projection().with_row_id(), filter_plan) .await?; - let planner = Planner::new(filter_input.schema()); - let physical_refine_expr = planner.create_physical_expr(refine_expr)?; - let filtered_row_ids = - Arc::new(FilterExec::try_new(physical_refine_expr, filter_input)?); PreFilterSource::FilteredRowIds(filtered_row_ids) } // Should be index_scan -> filter - (Some(index_query), None, true) => { + (Some(index_query), None, true, true) => { // Index scan doesn't honor the fragment allowlist today. // TODO: we could filter the index scan results to only include the allowed fragments. self.ensure_not_fragment_scan()?; @@ -2023,7 +2466,7 @@ impl Scanner { )); PreFilterSource::ScalarIndexQuery(index_query) } - (None, Some(refine_expr), true) => { + (None, Some(refine_expr), true, _) => { // No indices match the filter. We need to do a full scan // of the filter columns to determine the valid row ids. let columns_in_filter = Planner::column_names_in_expr(refine_expr); @@ -2036,8 +2479,8 @@ impl Scanner { PreFilterSource::FilteredRowIds(filtered_row_ids) } // No prefilter - (None, None, true) => PreFilterSource::None, - (_, _, false) => PreFilterSource::None, + (None, None, true, _) => PreFilterSource::None, + (_, _, false, _) => PreFilterSource::None, }; Ok(prefilter_source) @@ -2047,16 +2490,20 @@ impl Scanner { fn take( &self, input: Arc, - projection: &Schema, - batch_readahead: usize, + output_projection: Projection, ) -> Result> { - let coalesced = Arc::new(CoalesceBatchesExec::new(input, self.get_batch_size())); - Ok(Arc::new(TakeExec::try_new( - self.dataset.clone(), - coalesced, - Arc::new(projection.clone()), - batch_readahead, - )?)) + let coalesced = Arc::new(CoalesceBatchesExec::new( + input.clone(), + self.get_batch_size(), + )); + if let Some(take_plan) = + TakeExec::try_new(self.dataset.clone(), coalesced, output_projection)? + { + Ok(Arc::new(take_plan)) + } else { + // No new columns needed + Ok(input) + } } /// Global offset-limit of the result of the input plan @@ -2068,6 +2515,20 @@ impl Scanner { )) } + #[instrument(level = "info", skip(self))] + pub async fn analyze_plan(&self) -> Result { + let plan = self.create_plan().await?; + + analyze_plan( + plan, + LanceExecutionOptions { + batch_size: self.batch_size, + ..Default::default() + }, + ) + .await + } + #[instrument(level = "info", skip(self))] pub async fn explain_plan(&self, verbose: bool) -> Result { let plan = self.create_plan().await?; @@ -2316,6 +2777,7 @@ mod test { use half::f16; use lance_datagen::{array, gen, BatchCount, ByteCount, Dimension, RowCount}; use lance_file::version::LanceFileVersion; + use lance_index::scalar::inverted::query::{MatchQuery, PhraseQuery}; use lance_index::scalar::InvertedIndexParams; use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::ivf::IvfBuildParams; @@ -2500,60 +2962,50 @@ mod test { #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, #[values(false, true)] stable_row_ids: bool, + #[values(false, true)] build_index: bool, ) { - for build_index in &[true, false] { - let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids) - .await - .unwrap(); - if *build_index { - test_ds.make_vector_index().await.unwrap(); - } - let dataset = &test_ds.dataset; - - let mut scan = dataset.scan(); - let key: Float32Array = (32..64).map(|v| v as f32).collect(); - scan.nearest("vec", &key, 5).unwrap(); - scan.refine(5); + let mut test_ds = TestVectorDataset::new(data_storage_version, stable_row_ids) + .await + .unwrap(); + if build_index { + test_ds.make_vector_index().await.unwrap(); + } + let dataset = &test_ds.dataset; - let results = scan - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); + let mut scan = dataset.scan(); + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + scan.nearest("vec", &key, 5).unwrap(); + scan.refine(5); - assert_eq!(results.len(), 1); - let batch = &results[0]; + let batch = scan.try_into_batch().await.unwrap(); - assert_eq!(batch.num_rows(), 5); - assert_eq!( - batch.schema().as_ref(), - &ArrowSchema::new(vec![ - ArrowField::new("i", DataType::Int32, true), - ArrowField::new("s", DataType::Utf8, true), - ArrowField::new( - "vec", - DataType::FixedSizeList( - Arc::new(ArrowField::new("item", DataType::Float32, true)), - 32, - ), - true, + assert_eq!(batch.num_rows(), 5); + assert_eq!( + batch.schema().as_ref(), + &ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new("s", DataType::Utf8, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Arc::new(ArrowField::new("item", DataType::Float32, true)), + 32, ), - ArrowField::new(DIST_COL, DataType::Float32, true), - ]) - .with_metadata([("dataset".into(), "vector".into())].into()) - ); + true, + ), + ArrowField::new(DIST_COL, DataType::Float32, true), + ]) + .with_metadata([("dataset".into(), "vector".into())].into()) + ); - let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]); - let column_i = batch.column_by_name("i").unwrap(); - let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) - .values() - .iter() - .copied() - .collect(); - assert_eq!(expected_i, actual_i); - } + let expected_i = BTreeSet::from_iter(vec![1, 81, 161, 241, 321]); + let column_i = batch.column_by_name("i").unwrap(); + let actual_i: BTreeSet = as_primitive_array::(column_i.as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_i, actual_i); } #[rstest] @@ -3249,7 +3701,7 @@ mod test { let query_key = Arc::new(Float32Array::from_iter_values((0..2).map(|x| x as f32))); let mut scan = dataset.scan(); scan.filter("filterable > 5").unwrap(); - scan.nearest("vector", &query_key, 1).unwrap(); + scan.nearest("vector", query_key.as_ref(), 1).unwrap(); scan.with_row_id(); let batches = scan @@ -3695,14 +4147,11 @@ mod test { .unwrap(); let dataset = Dataset::open(test_uri).await.unwrap(); - assert_eq!(32, dataset.scan().count_rows().await.unwrap()); + assert_eq!(32, dataset.count_rows(None).await.unwrap()); assert_eq!( 16, dataset - .scan() - .filter("`Filter_me` > 15") - .unwrap() - .count_rows() + .count_rows(Some("`Filter_me` > 15".to_string())) .await .unwrap() ); @@ -3730,7 +4179,7 @@ mod test { .unwrap(); let dataset = Dataset::open(test_uri).await.unwrap(); - assert_eq!(32, dataset.scan().count_rows().await.unwrap()); + assert_eq!(dataset.count_rows(None).await.unwrap(), 32); let mut scanner = dataset.scan(); @@ -3778,7 +4227,7 @@ mod test { .unwrap(); let dataset = Dataset::open(test_uri).await.unwrap(); - assert_eq!(32, dataset.scan().count_rows().await.unwrap()); + assert_eq!(dataset.count_rows(None).await.unwrap(), 32); let mut scanner = dataset.scan(); @@ -4301,20 +4750,30 @@ mod test { } } - /// Assert that the plan when formatted matches the expected string. - /// - /// Within expected, you can use `...` to match any number of characters. - async fn assert_plan_equals( - dataset: &Dataset, - plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>, + #[rstest] + #[tokio::test] + async fn test_index_take_batch_size() { + let fixture = ScalarIndexTestFixture::new(LanceFileVersion::Stable, false).await; + let stream = fixture + .dataset + .scan() + .filter("indexed > 0") + .unwrap() + .batch_size(16) + .try_into_stream() + .await + .unwrap(); + let batches = stream.collect::>().await; + assert_eq!(batches.len(), 1000_usize.div_ceil(16)); + } + + async fn assert_plan_node_equals( + plan_node: Arc, expected: &str, ) -> Result<()> { - let mut scan = dataset.scan(); - plan(&mut scan)?; - let exec_plan = scan.create_plan().await?; let plan_desc = format!( "{}", - datafusion::physical_plan::displayable(exec_plan.as_ref()).indent(true) + datafusion::physical_plan::displayable(plan_node.as_ref()).indent(true) ); let to_match = expected.split("...").collect::>(); @@ -4341,6 +4800,148 @@ mod test { Ok(()) } + /// Assert that the plan when formatted matches the expected string. + /// + /// Within expected, you can use `...` to match any number of characters. + async fn assert_plan_equals( + dataset: &Dataset, + plan: impl Fn(&mut Scanner) -> Result<&mut Scanner>, + expected: &str, + ) -> Result<()> { + let mut scan = dataset.scan(); + plan(&mut scan)?; + let exec_plan = scan.create_plan().await?; + assert_plan_node_equals(exec_plan, expected).await + } + + #[tokio::test] + async fn test_count_plan() { + // A count rows operation should load the minimal amount of data + let dim = 256; + let fixture = TestVectorDataset::new_with_dimension(LanceFileVersion::Stable, true, dim) + .await + .unwrap(); + + // By default, all columns are returned, this is bad for a count_rows op + let err = fixture + .dataset + .scan() + .create_count_plan() + .await + .unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. })); + + let mut scan = fixture.dataset.scan(); + scan.project(&Vec::::default()).unwrap(); + + // with_row_id needs to be specified + let err = scan.create_count_plan().await.unwrap_err(); + assert!(matches!(err, Error::InvalidInput { .. })); + + scan.with_row_id(); + + let plan = scan.create_count_plan().await.unwrap(); + + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[count_rows] + LanceScan: uri=..., projection=[], row_id=true, row_addr=false, ordered=true", + ) + .await + .unwrap(); + + scan.filter("s == ''").unwrap(); + + let plan = scan.create_count_plan().await.unwrap(); + + assert_plan_node_equals( + plan, + "AggregateExec: mode=Single, gby=[], aggr=[count_rows] + ProjectionExec: expr=[_rowid@1 as _rowid] + FilterExec: s@0 = + LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=true", + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_inexact_scalar_index_plans() { + let data = gen() + .col("ngram", array::rand_utf8(ByteCount::from(5), false)) + .col("exact", array::rand_type(&DataType::UInt32)) + .col("no_index", array::rand_type(&DataType::UInt32)) + .into_reader_rows(RowCount::from(1000), BatchCount::from(5)); + + let mut dataset = Dataset::write(data, "memory://test", None).await.unwrap(); + dataset + .create_index( + &["ngram"], + IndexType::NGram, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + dataset + .create_index( + &["exact"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + // Simple in-exact filter + assert_plan_equals( + &dataset, + |scanner| scanner.filter("contains(ngram, 'test string')"), + "ProjectionExec: expr=[ngram@1 as ngram, exact@2 as exact, no_index@3 as no_index] + FilterExec: contains(ngram@1, test string) + Take: columns=\"_rowid, (ngram), (exact), (no_index)\" + CoalesceBatchesExec: target_batch_size=8192 + MaterializeIndex: query=contains(ngram, Utf8(\"test string\"))", + ) + .await + .unwrap(); + + // Combined with exact filter + // + // TODO: The FilterExec _should_ be just contains(ngram, 'test string') + assert_plan_equals( + &dataset, + |scanner| scanner.filter("contains(ngram, 'test string') and exact < 50"), + "ProjectionExec: expr=[ngram@1 as ngram, exact@2 as exact, no_index@3 as no_index] + FilterExec: contains(ngram@1, test string) AND exact@2 < 50 + Take: columns=\"_rowid, (ngram), (exact), (no_index)\" + CoalesceBatchesExec: target_batch_size=8192 + MaterializeIndex: query=AND(contains(ngram, Utf8(\"test string\")),exact < 50)", + ) + .await + .unwrap(); + + // All three filters + // + // TODO: Maybe an optimizer rule to combine the filters? Not a big deal + assert_plan_equals( + &dataset, + |scanner| { + scanner.filter("contains(ngram, 'test string') and exact < 50 AND no_index > 100") + }, + "ProjectionExec: expr=[ngram@1 as ngram, exact@2 as exact, no_index@3 as no_index] + FilterExec: no_index@3 > 100 + FilterExec: contains(ngram@1, test string) AND exact@2 < 50 AND no_index@3 > 100 + Take: columns=\"_rowid, (ngram), (exact), (no_index)\" + CoalesceBatchesExec: target_batch_size=8192 + MaterializeIndex: query=AND(contains(ngram, Utf8(\"test string\")),exact < 50)", + ) + .await + .unwrap(); + } + #[rstest] #[tokio::test] async fn test_late_materialization( @@ -4526,8 +5127,11 @@ mod test { #[values(false, true)] stable_row_id: bool, ) -> Result<()> { // Create a vector dataset + + use lance_index::scalar::inverted::query::BoostQuery; + let dim = 256; let mut dataset = - TestVectorDataset::new_with_dimension(data_storage_version, stable_row_id, 256).await?; + TestVectorDataset::new_with_dimension(data_storage_version, stable_row_id, dim).await?; let lance_schema = dataset.dataset.schema(); // Scans @@ -4561,11 +5165,11 @@ mod test { assert_plan_equals( &dataset.dataset, |scan| scan.use_stats(false).filter("s IS NOT NULL"), - "ProjectionExec: expr=[i@1 as i, s@0 as s, vec@3 as vec] - Take: columns=\"s, i, _rowid, (vec)\" + "ProjectionExec: expr=[i@0 as i, s@1 as s, vec@3 as vec] + Take: columns=\"i, s, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 - FilterExec: s@0 IS NOT NULL - LanceScan: uri..., projection=[s, i], row_id=true, row_addr=false, ordered=true", + FilterExec: s@1 IS NOT NULL + LanceScan: uri..., projection=[i, s], row_id=true, row_addr=false, ordered=true", ) .await?; @@ -4577,9 +5181,9 @@ mod test { .materialization_style(MaterializationStyle::AllEarly) .filter("s IS NOT NULL") }, - "ProjectionExec: expr=[i@1 as i, s@0 as s, vec@2 as vec] - FilterExec: s@0 IS NOT NULL - LanceScan: uri..., projection=[s, i, vec], row_id=true, row_addr=false, ordered=true", + "ProjectionExec: expr=[i@0 as i, s@1 as s, vec@2 as vec] + FilterExec: s@1 IS NOT NULL + LanceScan: uri..., projection=[i, s, vec], row_id=true, row_addr=false, ordered=true", ) .await?; @@ -4624,7 +5228,7 @@ mod test { // KNN // --------------------------------------------------------------------- - let q: Float32Array = (32..64).map(|v| v as f32).collect(); + let q: Float32Array = (32..32 + dim).map(|v| v as f32).collect(); assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 5), @@ -4638,6 +5242,23 @@ mod test { ) .await?; + // KNN + Limit (arguably the user, or us, should fold the limit into the KNN but we don't today) + // --------------------------------------------------------------------- + let q: Float32Array = (32..32 + dim).map(|v| v as f32).collect(); + assert_plan_equals( + &dataset.dataset, + |scan| scan.nearest("vec", &q, 5)?.limit(Some(1), None), + "ProjectionExec: expr=[i@3 as i, s@4 as s, vec@0 as vec, _distance@2 as _distance] + Take: columns=\"vec, _rowid, _distance, (i), (s)\" + CoalesceBatchesExec: target_batch_size=8192 + GlobalLimitExec: skip=0, fetch=1 + FilterExec: _distance@2 IS NOT NULL + SortExec: TopK(fetch=5), expr=... + KNNVectorDistance: metric=l2 + LanceScan: uri=..., projection=[vec], row_id=true, row_addr=false, ordered=false", + ) + .await?; + // ANN // --------------------------------------------------------------------- dataset.make_vector_index().await?; @@ -4980,9 +5601,9 @@ mod test { Take: columns=\"_rowid, (s)\" CoalesceBatchesExec: target_batch_size=8192 MaterializeIndex: query=i > 10 - ProjectionExec: expr=[_rowid@2 as _rowid, s@0 as s] - FilterExec: i@1 > 10 - LanceScan: uri=..., projection=[s, i], row_id=true, row_addr=false, ordered=false", + ProjectionExec: expr=[_rowid@2 as _rowid, s@1 as s] + FilterExec: i@0 > 10 + LanceScan: uri=..., projection=[i, s], row_id=true, row_addr=false, ordered=false", ) .await?; @@ -5040,9 +5661,9 @@ mod test { Take: columns=\"_rowid, (s)\" CoalesceBatchesExec: target_batch_size=8192 MaterializeIndex: query=i > 10 - ProjectionExec: expr=[_rowid@2 as _rowid, s@0 as s] - FilterExec: i@1 > 10 - LanceScan: uri=..., projection=[s, i], row_id=true, row_addr=false, ordered=false", + ProjectionExec: expr=[_rowid@2 as _rowid, s@1 as s] + FilterExec: i@0 > 10 + LanceScan: uri=..., projection=[i, s], row_id=true, row_addr=false, ordered=false", ) .await?; @@ -5060,12 +5681,45 @@ mod test { r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid] Take: columns="_rowid, _score, (s)" CoalesceBatchesExec: target_batch_size=8192 - SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 - UnionExec - Fts: query=hello - FlatFts: query=hello - EmptyExec"#, + MatchQuery: query=hello"#, + ) + .await?; + + // Phrase query + assert_plan_equals( + &dataset.dataset, + |scan| { + let query = PhraseQuery::new("hello world".to_owned()); + scan.project(&["s"])? + .with_row_id() + .full_text_search(FullTextSearchQuery::new_query(query.into())) + }, + r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid] + Take: columns="_rowid, _score, (s)" + CoalesceBatchesExec: target_batch_size=8192 + PhraseQuery: query=hello world"#, + ) + .await?; + + // Boost query + assert_plan_equals( + &dataset.dataset, + |scan| { + let positive = + MatchQuery::new("hello".to_owned()).with_column(Some("s".to_owned())); + let negative = + MatchQuery::new("world".to_owned()).with_column(Some("s".to_owned())); + let query = BoostQuery::new(positive.into(), negative.into(), Some(1.0)); + scan.project(&["s"])? + .with_row_id() + .full_text_search(FullTextSearchQuery::new_query(query.into())) + }, + r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid] + Take: columns="_rowid, _score, (s)" + CoalesceBatchesExec: target_batch_size=8192 + BoostQuery: negative_boost=1 + MatchQuery: query=hello + MatchQuery: query=world"#, ) .await?; @@ -5083,13 +5737,8 @@ mod test { r#"ProjectionExec: expr=[s@2 as s, _score@1 as _score, _rowid@0 as _rowid] Take: columns="_rowid, _score, (s)" CoalesceBatchesExec: target_batch_size=8192 - SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false] - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 - UnionExec - Fts: query=hello - ScalarIndexQuery: query=i > 10 - FlatFts: query=hello - EmptyExec"#, + MatchQuery: query=hello + ScalarIndexQuery: query=i > 10"#, ) .await?; @@ -5108,8 +5757,8 @@ mod test { SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 UnionExec - Fts: query=hello - FlatFts: query=hello + MatchQuery: query=hello + FlatMatchQuery: query=hello LanceScan: uri=..., projection=[s], row_id=true, row_addr=false, ordered=false"#, ) .await?; @@ -5130,9 +5779,9 @@ mod test { SortExec: expr=[_score@1 DESC NULLS LAST], preserve_partitioning=[false] RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=2 UnionExec - Fts: query=hello + MatchQuery: query=hello ScalarIndexQuery: query=i > 10 - FlatFts: query=hello + FlatMatchQuery: query=hello FilterExec: i@1 > 10 LanceScan: uri=..., projection=[s, i], row_id=true, row_addr=false, ordered=false"#, ) diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs index 6e9993435a5..7b4d7f3fdad 100644 --- a/rust/lance/src/dataset/schema_evolution.rs +++ b/rust/lance/src/dataset/schema_evolution.rs @@ -3,7 +3,6 @@ use std::{collections::HashSet, sync::Arc}; -use crate::io::commit::commit_transaction; use crate::{io::exec::Planner, Error, Result}; use arrow::compute::CastOptions; use arrow_array::{RecordBatch, RecordBatchReader}; @@ -14,7 +13,7 @@ use lance_arrow::SchemaExt; use lance_core::datatypes::{Field, Schema}; use lance_datafusion::utils::StreamingWriteSource; use lance_table::format::Fragment; -use snafu::{location, Location}; +use snafu::location; use super::fragment::FileFragment; use super::{ @@ -22,6 +21,12 @@ use super::{ Dataset, }; +mod optimize; + +use optimize::{ + ChainedNewColumnTransformOptimizer, NewColumnTransformOptimizer, SqlToAllNullsOptimizer, +}; + #[derive(Debug, Clone, PartialEq)] pub struct BatchInfo { pub fragment_id: u32, @@ -60,6 +65,8 @@ pub enum NewColumnTransform { Stream(SendableRecordBatchStream), /// An iterator of RecordBatches that define new columns. Reader(Box), + /// Add new columns that are initially all null + AllNulls(Arc), } /// Definition of a change to a column in a dataset @@ -146,6 +153,14 @@ pub(super) async fn add_columns_to_fragments( Ok(()) }; + // Optimize the transforms + let mut optimizer = ChainedNewColumnTransformOptimizer::new(vec![]); + // ALlNull transform can not performed on legacy files + if !dataset.is_legacy_storage() { + optimizer.add_optimizer(Box::new(SqlToAllNullsOptimizer::new())); + } + let transforms = optimizer.optimize(dataset, transforms)?; + let (output_schema, fragments) = match transforms { NewColumnTransform::BatchUDF(udf) => { check_names(udf.output_schema.as_ref())?; @@ -238,6 +253,36 @@ pub(super) async fn add_columns_to_fragments( let fragments = add_columns_from_stream(fragments, stream, None, batch_size).await?; Ok((output_schema, fragments)) } + NewColumnTransform::AllNulls(output_schema) => { + check_names(output_schema.as_ref())?; + + // Check that the schema is compatible considering all the new columns must be nullable + let schema = Schema::try_from(output_schema.as_ref())?; + if !schema.all_fields_nullable() { + return Err(Error::InvalidInput { + source: "All-null columns must be nullable.".into(), + location: location!(), + }); + } + + let fragments = fragments + .iter() + .map(|f| f.metadata.clone()) + .collect::>(); + + // Check if any of the fragment's files are using the legacy dataset version if so, we + // can't add all-null columns as a metadata-only operation. The reason is because we + // use the NullReader for fragments that have missing columns and we can't mix legacy + // and non-legacy readers when reading the fragment. + if dataset.is_legacy_storage() { + return Err(Error::NotSupported { + source: "Cannot add all-null columns to legacy dataset version.".into(), + location: location!(), + }); + } + + Ok((output_schema, fragments)) + } }?; let mut schema = dataset.schema().merge(output_schema.as_ref())?; @@ -269,19 +314,9 @@ pub(super) async fn add_columns( /*blob_op= */ None, None, ); - let (new_manifest, new_path) = commit_transaction( - dataset, - &dataset.object_store, - dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - dataset.manifest_naming_scheme, - ) - .await?; - - dataset.manifest = Arc::new(new_manifest); - dataset.manifest_file = new_path; + dataset + .apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -591,19 +626,9 @@ pub(super) async fn alter_columns( // TODO: adjust the indices here for the new schema - let (manifest, manifest_path) = commit_transaction( - dataset, - &dataset.object_store, - dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - dataset.manifest_naming_scheme, - ) - .await?; - - dataset.manifest = Arc::new(manifest); - dataset.manifest_file = manifest_path; + dataset + .apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -653,19 +678,9 @@ pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Res None, ); - let (manifest, manifest_path) = commit_transaction( - dataset, - &dataset.object_store, - dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - dataset.manifest_naming_scheme, - ) - .await?; - - dataset.manifest = Arc::new(manifest); - dataset.manifest_file = manifest_path; + dataset + .apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -1046,6 +1061,131 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_add_column_all_nulls() -> Result<()> { + let num_rows = 100; + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "id", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..num_rows))], + )?; + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + + let test_dir = tempfile::tempdir()?; + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 50, + max_rows_per_group: 25, + data_storage_version: Some(LanceFileVersion::Stable), + ..Default::default() + }), + ) + .await?; + dataset.validate().await?; + + dataset + .add_columns( + NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![ArrowField::new( + "nulls", + DataType::Int32, + true, + )]))), + None, + None, + ) + .await?; + + let data = dataset.scan().try_into_batch().await?; + let expected_schema = ArrowSchema::new(vec![ + ArrowField::new("id", DataType::Int32, false), + ArrowField::new("nulls", DataType::Int32, true), + ]); + assert_eq!(data.schema().as_ref(), &expected_schema); + assert_eq!(data.num_rows(), num_rows as usize); + + // check that can't add non-nullable columns + let err = + dataset + .add_columns( + NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![ + ArrowField::new("non_nulls", DataType::Int32, false), + ]))), + None, + None, + ) + .await + .unwrap_err(); + assert!(err + .to_string() + .contains("All-null columns must be nullable.")); + + let data = dataset.scan().try_into_batch().await?; + let expected_schema = ArrowSchema::new(vec![ + ArrowField::new("id", DataType::Int32, false), + ArrowField::new("nulls", DataType::Int32, true), + ]); + assert_eq!(data.schema().as_ref(), &expected_schema); + assert_eq!(data.num_rows(), num_rows as usize); + + Ok(()) + } + + #[tokio::test] + async fn test_add_column_all_nulls_legacy() -> Result<()> { + let num_rows = 100; + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "id", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..num_rows))], + )?; + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + + let test_dir = tempfile::tempdir()?; + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 50, + max_rows_per_group: 25, + data_storage_version: Some(LanceFileVersion::Legacy), + ..Default::default() + }), + ) + .await?; + dataset.validate().await?; + + let err = + dataset + .add_columns( + NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![ + ArrowField::new("nulls", DataType::Int32, true), + ]))), + None, + None, + ) + .await + .unwrap_err(); + assert!(err + .to_string() + .contains("Cannot add all-null columns to legacy dataset version")); + + Ok(()) + } + #[rstest] #[tokio::test] async fn test_rename_columns( @@ -1576,4 +1716,115 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_new_column_sql_to_all_nulls_transform_optimizer() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter(0..100))], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let test_dir = tempfile::tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 50, + max_rows_per_group: 25, + data_storage_version: Some(LanceFileVersion::Stable), + ..Default::default() + }), + ) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + let manifest_before = dataset.manifest.clone(); + + // Add all null column + dataset + .add_columns( + NewColumnTransform::SqlExpressions(vec![( + "b".to_string(), + "CAST(NULL AS int)".to_string(), + )]), + None, + None, + ) + .await + .unwrap(); + let manifest_after = dataset.manifest.clone(); + + // Check that this is a metadata-only operation (the fragments don't change) + assert_eq!(&manifest_before.fragments, &manifest_after.fragments); + + // check that the new field was added to the schema + let expected_schema = ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, true), + ]); + assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema); + } + + #[tokio::test] + async fn test_new_column_sql_to_all_nulls_transform_optimizer_legacy() { + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "a", + DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter(0..100))], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let test_dir = tempfile::tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 50, + max_rows_per_group: 25, + data_storage_version: Some(LanceFileVersion::Legacy), + ..Default::default() + }), + ) + .await + .unwrap(); + dataset.validate().await.unwrap(); + + // Add all null column ... + // This is basically a smoke test to ensure we don't try to use the all-nulls + // transform optimizer where it's not supported, and then blow up when we try + // to apply the transform + dataset + .add_columns( + NewColumnTransform::SqlExpressions(vec![( + "b".to_string(), + "CAST(NULL AS int)".to_string(), + )]), + None, + None, + ) + .await + .unwrap(); + + // check that the new field was added to the schema + let expected_schema = ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, true), + ]); + assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema); + } } diff --git a/rust/lance/src/dataset/schema_evolution/optimize.rs b/rust/lance/src/dataset/schema_evolution/optimize.rs new file mode 100644 index 00000000000..540ae65f20f --- /dev/null +++ b/rust/lance/src/dataset/schema_evolution/optimize.rs @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use arrow_schema::{DataType, Field, Schema}; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; +use lance_datafusion::planner::Planner; + +use crate::error::Result; +use crate::Dataset; + +use super::NewColumnTransform; + +/// Optimizes a `NewColumnTransform` into +pub(super) trait NewColumnTransformOptimizer: Send + Sync { + /// Optimize the passed `NewColumnTransform` to a more efficient form. + fn optimize( + &self, + dataset: &Dataset, + transform: NewColumnTransform, + ) -> Result; +} + +/// A `NewColumnTransformOptimizer` that chains multiple `NewColumnTransformOptimizer`s together. +pub(super) struct ChainedNewColumnTransformOptimizer { + optimizers: Vec>, +} + +impl ChainedNewColumnTransformOptimizer { + pub(super) fn new(optimizers: Vec>) -> Self { + Self { optimizers } + } + + pub(super) fn add_optimizer(&mut self, optimizer: Box) { + self.optimizers.push(optimizer); + } +} + +/// A `NewColumnTransformOptimizer` that chains multiple `NewColumnTransformOptimizer`s together. +impl NewColumnTransformOptimizer for ChainedNewColumnTransformOptimizer { + fn optimize( + &self, + dataset: &Dataset, + transform: NewColumnTransform, + ) -> Result { + let mut transform = transform; + for optimizer in &self.optimizers { + transform = optimizer.optimize(dataset, transform)?; + } + Ok(transform) + } +} + +/// Optimizes a `NewColumnTransform` that is a SQL expression to a `NewColumnTransform::AllNulls` if +/// the SQL expression is "NULL". For example +/// `NewColumnTransform::SqlExpression(vec![("new_col", "CAST(NULL AS int)"])` +/// would be optimized to +/// `NewColumnTransform::AllNulls(Schema::new(vec![Field::new("new_col", DataType::Int)]))`. +/// +pub(super) struct SqlToAllNullsOptimizer; + +impl SqlToAllNullsOptimizer { + pub(super) fn new() -> Self { + Self + } + + fn is_all_null(&self, expr: &Expr) -> AllNullsResult { + match expr { + Expr::Cast(cast) => { + if matches!(cast.expr.as_ref(), Expr::Literal(ScalarValue::Null)) { + let data_type = cast.data_type.clone(); + AllNullsResult::AllNulls(data_type) + } else { + AllNullsResult::NotAllNulls + } + } + _ => AllNullsResult::NotAllNulls, + } + } +} + +enum AllNullsResult { + AllNulls(DataType), + NotAllNulls, +} + +impl NewColumnTransformOptimizer for SqlToAllNullsOptimizer { + fn optimize( + &self, + dataset: &Dataset, + transform: NewColumnTransform, + ) -> Result { + match &transform { + NewColumnTransform::SqlExpressions(expressions) => { + let arrow_schema = Arc::new(Schema::from(dataset.schema())); + let planner = Planner::new(arrow_schema); + let mut all_null_schema_fields = vec![]; + for (name, expr) in expressions { + let expr = planner.parse_expr(expr)?; + if let AllNullsResult::AllNulls(data_type) = self.is_all_null(&expr) { + let field = Field::new(name, data_type, true); + all_null_schema_fields.push(field); + } else { + return Ok(transform); + } + } + + let all_null_schema = Schema::new(all_null_schema_fields); + Ok(NewColumnTransform::AllNulls(Arc::new(all_null_schema))) + } + _ => Ok(transform), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use arrow_array::RecordBatchIterator; + + #[tokio::test] + async fn test_sql_to_all_null_transform() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let empty_reader = RecordBatchIterator::new(vec![], schema.clone()); + let dataset = Arc::new( + Dataset::write(empty_reader, "memory://", None) + .await + .unwrap(), + ); + + let original = NewColumnTransform::SqlExpressions(vec![ + ("new_col1".to_string(), "CAST(NULL AS int)".to_string()), + ("new_col2".to_string(), "CAST(NULL AS bigint)".to_string()), + ]); + + let optimizer = SqlToAllNullsOptimizer::new(); + let result = optimizer.optimize(&dataset, original).unwrap(); + + assert!(matches!(result, NewColumnTransform::AllNulls(_))); + if let NewColumnTransform::AllNulls(schema) = result { + assert_eq!(schema.fields().len(), 2); + assert_eq!(schema.field(0).name(), "new_col1"); + assert_eq!(schema.field(0).data_type(), &DataType::Int32); + assert!(schema.field(0).is_nullable()); + assert_eq!(schema.field(1).name(), "new_col2"); + assert_eq!(schema.field(1).data_type(), &DataType::Int64); + assert!(schema.field(1).is_nullable()); + } + } +} diff --git a/rust/lance/src/dataset/statistics.rs b/rust/lance/src/dataset/statistics.rs new file mode 100644 index 00000000000..e2dfa34e353 --- /dev/null +++ b/rust/lance/src/dataset/statistics.rs @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Module for statistics related to the dataset. + +use std::{collections::HashMap, future::Future, sync::Arc}; + +use lance_core::Result; +use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; + +use super::{fragment::FileFragment, Dataset}; + +/// Statistics about a single field in the dataset +pub struct FieldStatistics { + /// Id of the field + pub id: u32, + /// Amount of data in the field (after compression, if any) + /// + /// This will be 0 if the data storage version is less than 2 + pub bytes_on_disk: u64, +} + +/// Statistics about the data in the dataset +pub struct DataStatistics { + /// Statistics about each field in the dataset + pub fields: Vec, +} + +pub trait DatasetStatisticsExt { + /// Get statistics about the data in the dataset + fn calculate_data_stats( + self: &Arc, + ) -> impl Future> + Send; +} + +impl DatasetStatisticsExt for Dataset { + async fn calculate_data_stats(self: &Arc) -> Result { + let field_ids = self.schema().field_ids(); + let mut field_stats: HashMap = + HashMap::from_iter(field_ids.iter().map(|id| { + ( + *id as u32, + FieldStatistics { + id: *id as u32, + bytes_on_disk: 0, + }, + ) + })); + if !self.is_legacy_storage() { + let scan_scheduler = ScanScheduler::new( + self.object_store.clone(), + SchedulerConfig::max_bandwidth(self.object_store.as_ref()), + ); + for fragment in self.fragments().as_ref() { + let file_fragment = FileFragment::new(self.clone(), fragment.clone()); + file_fragment + .update_storage_stats(&mut field_stats, self.schema(), scan_scheduler.clone()) + .await?; + } + } + let field_stats = field_ids + .into_iter() + .map(|id| field_stats.remove(&(id as u32)).unwrap()) + .collect(); + Ok(DataStatistics { + fields: field_stats, + }) + } +} diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs index c390bbd45c9..da7130084da 100644 --- a/rust/lance/src/dataset/take.rs +++ b/rust/lance/src/dataset/take.rs @@ -6,9 +6,10 @@ use std::{collections::BTreeMap, ops::Range, pin::Pin, sync::Arc}; use crate::dataset::fragment::FragReadConfig; use crate::dataset::rowids::get_row_id_index; use crate::{Error, Result}; -use arrow::{array::as_struct_array, compute::concat_batches, datatypes::UInt64Type}; +use arrow::{compute::concat_batches, datatypes::UInt64Type}; use arrow_array::cast::AsArray; -use arrow_array::{RecordBatch, StructArray, UInt64Array}; +use arrow_array::{Array, RecordBatch, StructArray, UInt64Array}; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, NullBuffer}; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use datafusion::error::DataFusionError; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -16,9 +17,10 @@ use futures::{Future, Stream, StreamExt, TryStreamExt}; use lance_arrow::RecordBatchExt; use lance_core::datatypes::Schema; use lance_core::utils::address::RowAddress; +use lance_core::utils::deletion::OffsetMapper; use lance_core::ROW_ADDR; use lance_datafusion::projection::ProjectionPlan; -use snafu::{location, Location}; +use snafu::location; use super::ProjectionRequest; use super::{fragment::FileFragment, scanner::DatasetRecordBatchStream, Dataset}; @@ -45,29 +47,45 @@ pub async fn take( let mut frag_iter = fragments.iter(); let mut cur_frag = frag_iter.next(); let mut cur_frag_rows = if let Some(cur_frag) = cur_frag { - cur_frag.count_rows().await? as u64 + cur_frag.count_rows(None).await? as u64 } else { 0 }; + let mut offset_mapper = if let Some(cur_frag) = cur_frag { + let deletion_vector = cur_frag.get_deletion_vector().await?; + deletion_vector.map(OffsetMapper::new) + } else { + None + }; let mut frag_offset = 0; - let mut addrs = Vec::with_capacity(sorted_offsets.len()); + let mut addrs: Vec = Vec::with_capacity(sorted_offsets.len()); for sorted_offset in sorted_offsets.into_iter() { while cur_frag.is_some() && sorted_offset >= frag_offset + cur_frag_rows { frag_offset += cur_frag_rows; cur_frag = frag_iter.next(); cur_frag_rows = if let Some(cur_frag) = cur_frag { - cur_frag.count_rows().await? as u64 + cur_frag.count_rows(None).await? as u64 } else { 0 }; + offset_mapper = if let Some(cur_frag) = cur_frag { + let deletion_vector = cur_frag.get_deletion_vector().await?; + deletion_vector.map(OffsetMapper::new) + } else { + None + }; } let Some(cur_frag) = cur_frag else { addrs.push(RowAddress::TOMBSTONE_ROW); continue; }; - let row_addr = - RowAddress::new_from_parts(cur_frag.id() as u32, (sorted_offset - frag_offset) as u32); + + let mut local_offset = (sorted_offset - frag_offset) as u32; + if let Some(offset_mapper) = &mut offset_mapper { + local_offset = offset_mapper.map_offset(local_offset); + }; + let row_addr = RowAddress::new_from_parts(cur_frag.id() as u32, local_offset); addrs.push(u64::from(row_addr)); } @@ -132,7 +150,7 @@ async fn do_take_rows( })?; let reader = fragment - .open(&projection.physical_schema, FragReadConfig::default(), None) + .open(&projection.physical_schema, FragReadConfig::default()) .await?; reader.legacy_read_range_as_batch(range).await } else if row_addr_stats.sorted { @@ -266,9 +284,13 @@ async fn do_take_rows( // Remove the rowaddr column. let keep_indices = (0..one_batch.num_columns() - 1).collect::>(); let one_batch = one_batch.project(&keep_indices)?; + + // There's a bug in arrow_select::take::take, that it doesn't handle empty struct correctly, + // so we need to handle it manually here. + // TODO: remove this once the bug is fixed. let struct_arr: StructArray = one_batch.into(); - let reordered = arrow_select::take::take(&struct_arr, &remapping_index, None)?; - Ok(as_struct_array(&reordered).into()) + let reordered = take_struct_array(&struct_arr, &remapping_index)?; + Ok(reordered.into()) }?; let batch = projection.project_batch(batch).await?; @@ -536,6 +558,42 @@ impl TakeBuilder { } } +fn take_struct_array(array: &StructArray, indices: &UInt64Array) -> Result { + let nulls = array.nulls().map(|nulls| { + let is_valid = indices.iter().map(|index| { + if let Some(index) = index { + nulls.is_valid(index.to_usize().unwrap()) + } else { + false + } + }); + NullBuffer::new(BooleanBuffer::new( + Buffer::from_iter(is_valid), + 0, + indices.len(), + )) + }); + + if array.fields().is_empty() { + return Ok(StructArray::new_empty_fields(indices.len(), nulls)); + } + + let arrays = array + .columns() + .iter() + .map(|array| { + let array = match array.data_type() { + arrow::datatypes::DataType::Struct(_) => { + Arc::new(take_struct_array(array.as_struct(), indices)?) + } + _ => arrow_select::take::take(array, indices, None)?, + }; + Ok(array) + }) + .collect::>>()?; + Ok(StructArray::new(array.fields().clone(), arrays, nulls)) +} + #[cfg(test)] mod test { use arrow_array::{Int32Array, RecordBatchIterator, StringArray}; @@ -626,6 +684,55 @@ mod test { ); } + #[tokio::test] + async fn test_take_with_deletion() { + let data = test_batch(0..120); + let write_params = WriteParams { + max_rows_per_file: 40, + max_rows_per_group: 10, + ..Default::default() + }; + let batches = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let mut dataset = Dataset::write(batches, "memory://", Some(write_params)) + .await + .unwrap(); + + dataset.delete("i in (40, 77, 78, 79)").await.unwrap(); + + let projection = Schema::try_from(data.schema().as_ref()).unwrap(); + let values = dataset + .take( + &[ + 0, // 0 + 39, // 39 + 40, // 41 + 75, // 76 + 76, // 80 + 77, // 81 + 115, // 119 + ], + projection, + ) + .await + .unwrap(); + + assert_eq!( + RecordBatch::try_new( + data.schema(), + vec![ + Arc::new(Int32Array::from_iter_values([0, 39, 41, 76, 80, 81, 119])), + Arc::new(StringArray::from_iter_values( + [0, 39, 41, 76, 80, 81, 119] + .iter() + .map(|v| format!("str-{v}")) + )), + ], + ) + .unwrap(), + values + ); + } + #[rstest] #[tokio::test] async fn test_take_with_projection( diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index 558c0e9ba3d..7aab0cec806 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -22,21 +22,28 @@ //! a conflict. Some operations have additional conditions that must be met for //! them to be compatible. //! -//! | | Append | Delete / Update | Overwrite/Create | Create Index | Rewrite | Merge | Project | UpdateConfig | -//! |------------------|--------|-----------------|------------------|--------------|---------|-------|---------|-------------| -//! | Append | ✅ | ✅ | ⌠| ✅ | ✅ | ⌠| ⌠| ✅ | -//! | Delete / Update | ✅ | (1) | ⌠| ✅ | (1) | ⌠| ⌠| ✅ | -//! | Overwrite/Create | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | (2) | -//! | Create index | ✅ | ✅ | ⌠| ✅ | ✅ | ✅ | ✅ | ✅ | -//! | Rewrite | ✅ | (1) | ⌠| ⌠| (1) | ⌠| ⌠| ✅ | -//! | Merge | ⌠| ⌠| ⌠| ⌠| ✅ | ⌠| ⌠| ✅ | -//! | Project | ✅ | ✅ | ⌠| ⌠| ✅ | ⌠| ✅ | ✅ | -//! | UpdateConfig | ✅ | ✅ | (2) | ✅ | ✅ | ✅ | ✅ | (2) | +//! NOTE/TODO(rmeng): DataReplacement conflict resolution is not fully implemented //! -//! (1) Delete, update, and rewrite are compatible with each other and themselves only if +//! | | Append | Delete / Update | Overwrite/Create | Create Index | Rewrite | Merge | Project | UpdateConfig | DataReplacement | +//! |------------------|--------|-----------------|------------------|--------------|---------|-------|---------|--------------|-----------------| +//! | Append | ✅ | ✅ | ⌠| ✅ | ✅ | ⌠| ⌠| ✅ | ✅ +//! | Delete / Update | ✅ | 1ï¸âƒ£ | ⌠| ✅ | 1ï¸âƒ£ | ⌠| ⌠| ✅ | ✅ +//! | Overwrite/Create | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 2ï¸âƒ£ | ✅ +//! | Create index | ✅ | ✅ | ⌠| ✅ | ✅ | ✅ | ✅ | ✅ | 3ï¸âƒ£ +//! | Rewrite | ✅ | 1ï¸âƒ£ | ⌠| ⌠| 1ï¸âƒ£ | ⌠| ⌠| ✅ | 3ï¸âƒ£ +//! | Merge | ⌠| ⌠| ⌠| ⌠| ✅ | ⌠| ⌠| ✅ | ✅ +//! | Project | ✅ | ✅ | ⌠| ⌠| ✅ | ⌠| ✅ | ✅ | ✅ +//! | UpdateConfig | ✅ | ✅ | 2ï¸âƒ£ | ✅ | ✅ | ✅ | ✅ | 2ï¸âƒ£ | ✅ +//! | DataReplacement | ✅ | ✅ | ⌠| 3ï¸âƒ£ | 1ï¸âƒ£ | ✅ | 3ï¸âƒ£ | ✅ | 3ï¸âƒ£ +//! +//! 1ï¸âƒ£ Delete, update, and rewrite are compatible with each other and themselves only if //! they affect distinct fragments. Otherwise, they conflict. -//! (2) Operations that mutate the config conflict if one of the operations upserts a key -//! that if referenced by another concurrent operation. +//! 2ï¸âƒ£ Operations that mutate the config conflict if one of the operations upserts a key +//! that if referenced by another concurrent operation or if both operations modify the schema +//! metadata or the same field metadata. +//! 3ï¸âƒ£ DataReplacement on a column without index is compatible with any operation AS LONG AS +//! the operation does not modify the region of the column being replaced. +//! use std::{ collections::{HashMap, HashSet}, @@ -50,7 +57,7 @@ use lance_io::object_store::ObjectStore; use lance_table::{ format::{ pb::{self, IndexMetadata}, - DataStorageFormat, Fragment, Index, Manifest, RowIdMeta, + DataFile, DataStorageFormat, Fragment, Index, Manifest, RowIdMeta, }, io::{ commit::CommitHandler, @@ -60,7 +67,7 @@ use lance_table::{ }; use object_store::path::Path; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use uuid::Uuid; use super::ManifestWriteConfig; @@ -94,6 +101,9 @@ pub enum BlobsOperation { Updated(u64), } +#[derive(Debug, Clone, DeepSizeOf)] +pub struct DataReplacementGroup(pub u64, pub DataFile); + /// An operation on a dataset. #[derive(Debug, Clone, DeepSizeOf)] pub enum Operation { @@ -135,6 +145,25 @@ pub enum Operation { /// Indices that have been updated with the new row addresses rewritten_indices: Vec, }, + /// Replace data in a column in the dataset with a new data. This is used for + /// null column population where we replace an entirely null column with a + /// new column that has data. + /// + /// This operation will only allow replacing files that contain the same schema + /// e.g. if the original files contains column A, B, C and the new files contains + /// only column A, B then the operation is not allowed. As we would need to split + /// the original files into two files, one with column A, B and the other with column C. + /// + /// Corollary to the above: the operation will also not allow replacing files unless the + /// affected columns all have the same datafile layout across the fragments being replaced. + /// + /// e.g. if fragments being replaced contains files with different schema layouts on + /// the column being replaced, the operation is not allowed. + /// say frag_1: [A] [B, C] and frag_2: [A, B] [C] and we are trying to replace column A + /// with a new column A the operation is not allowed. + DataReplacement { + replacements: Vec, + }, /// Merge a new column in Merge { fragments: Vec, @@ -165,9 +194,30 @@ pub enum Operation { UpdateConfig { upsert_values: Option>, delete_keys: Option>, + schema_metadata: Option>, + field_metadata: Option>>, }, } +impl std::fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Append { .. } => write!(f, "Append"), + Self::Delete { .. } => write!(f, "Delete"), + Self::Overwrite { .. } => write!(f, "Overwrite"), + Self::CreateIndex { .. } => write!(f, "CreateIndex"), + Self::Rewrite { .. } => write!(f, "Rewrite"), + Self::Merge { .. } => write!(f, "Merge"), + Self::Restore { .. } => write!(f, "Restore"), + Self::ReserveFragments { .. } => write!(f, "ReserveFragments"), + Self::Update { .. } => write!(f, "Update"), + Self::Project { .. } => write!(f, "Project"), + Self::UpdateConfig { .. } => write!(f, "UpdateConfig"), + Self::DataReplacement { .. } => write!(f, "DataReplacement"), + } + } +} + #[derive(Debug, Clone)] pub struct RewrittenIndex { pub old_id: Uuid, @@ -226,6 +276,7 @@ impl Operation { .map(|f| f.id) .chain(removed_fragment_ids.iter().copied()), ), + Self::DataReplacement { replacements } => Box::new(replacements.iter().map(|r| r.0)), } } @@ -268,6 +319,38 @@ impl Operation { other_ids.any(|id| self_ids.contains(&id)) } + fn modifies_same_metadata(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::UpdateConfig { + schema_metadata, + field_metadata, + .. + }, + Self::UpdateConfig { + schema_metadata: other_schema_metadata, + field_metadata: other_field_metadata, + .. + }, + ) => { + if schema_metadata.is_some() && other_schema_metadata.is_some() { + return true; + } + if let Some(field_metadata) = field_metadata { + if let Some(other_field_metadata) = other_field_metadata { + for field in field_metadata.keys() { + if other_field_metadata.contains_key(field) { + return true; + } + } + } + } + false + } + _ => false, + } + } + /// Check whether another operation upserts a key that is referenced by another operation fn upsert_key_conflict(&self, other: &Self) -> bool { let self_upsert_keys = self.get_upsert_config_keys(); @@ -297,10 +380,29 @@ impl Operation { Self::Update { .. } => "Update", Self::Project { .. } => "Project", Self::UpdateConfig { .. } => "UpdateConfig", + Self::DataReplacement { .. } => "DataReplacement", } } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ConflictResult { + /// The operation is compatible with the other operation + /// + /// For example, two operations that modify different fragments are compatible. + Compatible, + /// The operation is not compatible with the other operation + /// + /// For example, an Overwrite with a change in schema and an Append are + /// not compatible. + NotCompatible, + /// The operation is not compatible, but the current operation can be + /// retried on top of the others changes. + /// + /// For example, two operations that modify the same fragment. + Retryable, +} + impl Transaction { pub fn new( read_version: u64, @@ -320,98 +422,226 @@ impl Transaction { /// Returns true if the transaction cannot be committed if the other /// transaction is committed first. - pub fn conflicts_with(&self, other: &Self) -> bool { + pub fn conflicts_with(&self, other: &Self) -> ConflictResult { + use ConflictResult::*; // This assumes IsolationLevel is Snapshot Isolation, which is more // permissive than Serializable. In particular, it allows a Delete // transaction to succeed after a concurrent Append, even if the Append // added rows that would be deleted. match &self.operation { Operation::Append { .. } => match &other.operation { - // Append is compatible with anything that doesn't change the schema - Operation::Append { .. } => false, - Operation::Rewrite { .. } => false, - Operation::CreateIndex { .. } => false, - Operation::Delete { .. } | Operation::Update { .. } => false, - Operation::ReserveFragments { .. } => false, - Operation::Project { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::Append { .. } + | Operation::Rewrite { .. } + | Operation::CreateIndex { .. } + | Operation::Delete { .. } + | Operation::Update { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::Merge { .. } + | Operation::UpdateConfig { .. } + | Operation::DataReplacement { .. } => Compatible, + // Append is not compatible with any operation that completely + // overwrites the schema. + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::Rewrite { .. } => match &other.operation { // Rewrite is only compatible with operations that don't touch - // existing fragments. - // TODO: it could also be compatible with operations that update - // fragments we don't touch. - Operation::Append { .. } => false, - Operation::ReserveFragments { .. } => false, + // existing fragments or update fragments we don't touch. + Operation::Append { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::UpdateConfig { .. } => Compatible, Operation::Delete { .. } | Operation::Rewrite { .. } | Operation::Update { .. } => { // As long as they rewrite disjoint fragments they shouldn't conflict. - self.operation.modifies_same_ids(&other.operation) + if self.operation.modifies_same_ids(&other.operation) { + Retryable + } else { + Compatible + } + } + Operation::DataReplacement { .. } | Operation::Merge { .. } => { + // TODO(rmeng): check that the fragments being replaced are not part of the groups + Retryable } - Operation::Project { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::CreateIndex { new_indices, .. } => { + let mut affected_ids = HashSet::new(); + for index in new_indices { + if let Some(frag_bitmap) = &index.fragment_bitmap { + affected_ids.extend(frag_bitmap.iter()); + } else { + return Retryable; + } + } + if self + .operation + .modified_fragment_ids() + .any(|id| affected_ids.contains(&(id as u32))) + { + Retryable + } else { + Compatible + } + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, // Restore always succeeds - Operation::Restore { .. } => false, + Operation::Restore { .. } => Compatible, // ReserveFragments is compatible with anything that doesn't reset the // max fragment id. - Operation::ReserveFragments { .. } => matches!( - &other.operation, - Operation::Overwrite { .. } | Operation::Restore { .. } - ), - Operation::CreateIndex { .. } => match &other.operation { - Operation::Append { .. } => false, + Operation::ReserveFragments { .. } => match &other.operation { + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, + _ => Compatible, + }, + Operation::CreateIndex { new_indices, .. } => match &other.operation { + Operation::Append { .. } => Compatible, // Indices are identified by UUIDs, so they shouldn't conflict. - Operation::CreateIndex { .. } => false, + Operation::CreateIndex { .. } => Compatible, // Although some of the rows we indexed may have been deleted / moved, // row ids are still valid, so we allow this optimistically. - Operation::Delete { .. } | Operation::Update { .. } => false, - // Merge & reserve don't change row ids, so this should be fine. - Operation::Merge { .. } => false, - Operation::ReserveFragments { .. } => false, - // Rewrite likely changed many of the row ids, so our index is - // likely useless. It should be rebuilt. - // TODO: we could be smarter here and only invalidate the index - // if the rewrite changed more than X% of row ids. - Operation::Rewrite { .. } => true, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::Delete { .. } | Operation::Update { .. } => Compatible, + // Merge, reserve, and project don't change row ids, so this should be fine. + Operation::Merge { .. } => Compatible, + Operation::ReserveFragments { .. } => Compatible, + Operation::Project { .. } => Compatible, + // Should be compatible with rewrite if it didn't move the rows + // we indexed. If it did, we could retry. + // TODO: this will change with stable row ids. + Operation::Rewrite { .. } => { + let mut affected_ids = HashSet::new(); + for index in new_indices { + if let Some(frag_bitmap) = &index.fragment_bitmap { + affected_ids.extend(frag_bitmap.iter()); + } else { + return Retryable; + } + } + if other + .operation + .modified_fragment_ids() + .any(|id| affected_ids.contains(&(id as u32))) + { + Retryable + } else { + Compatible + } + } + Operation::UpdateConfig { .. } => Compatible, + Operation::DataReplacement { .. } => { + // TODO(rmeng): check that the new indices isn't on the column being replaced + Retryable + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, Operation::Delete { .. } | Operation::Update { .. } => match &other.operation { - Operation::CreateIndex { .. } => false, - Operation::ReserveFragments { .. } => false, - Operation::Delete { .. } | Operation::Rewrite { .. } | Operation::Update { .. } => { + Operation::CreateIndex { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } + | Operation::Append { .. } + | Operation::UpdateConfig { .. } => Compatible, + Operation::Delete { .. } + | Operation::Rewrite { .. } + | Operation::Update { .. } + | Operation::DataReplacement { .. } => { // If we update the same fragments, we conflict. - self.operation.modifies_same_ids(&other.operation) + if self.operation.modifies_same_ids(&other.operation) { + Retryable + } else { + Compatible + } + } + Operation::Merge { .. } => Retryable, + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, + }, + Operation::Overwrite { .. } => match &other.operation { + // Overwrite only conflicts with another operation modifying the same update config + Operation::Overwrite { .. } | Operation::UpdateConfig { .. } + if self.operation.upsert_key_conflict(&other.operation) => + { + NotCompatible } - Operation::Project { .. } => false, - Operation::Append { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + _ => Compatible, }, - Operation::Overwrite { .. } | Operation::UpdateConfig { .. } => { - match &other.operation { - Operation::Overwrite { .. } | Operation::UpdateConfig { .. } => { - self.operation.upsert_key_conflict(&other.operation) + Operation::UpdateConfig { + schema_metadata, + field_metadata, + .. + } => match &other.operation { + Operation::Overwrite { .. } => { + // Updates to schema metadata or field metadata conflict with any kind + // of overwrite. + if schema_metadata.is_some() + || field_metadata.is_some() + || self.operation.upsert_key_conflict(&other.operation) + { + NotCompatible + } else { + Compatible } - _ => false, } - } + Operation::UpdateConfig { .. } => { + if self.operation.upsert_key_conflict(&other.operation) + | self.operation.modifies_same_metadata(&other.operation) + { + NotCompatible + } else { + Compatible + } + } + _ => Compatible, + }, // Merge changes the schema, but preserves row ids, so the only operations // it's compatible with is CreateIndex, ReserveFragments, SetMetadata and DeleteMetadata. - Operation::Merge { .. } => !matches!( - &other.operation, + Operation::Merge { .. } => match &other.operation { Operation::CreateIndex { .. } - | Operation::ReserveFragments { .. } - | Operation::UpdateConfig { .. } - ), + | Operation::ReserveFragments { .. } + | Operation::UpdateConfig { .. } => Compatible, + Operation::Update { .. } + | Operation::Append { .. } + | Operation::Delete { .. } + | Operation::Rewrite { .. } + | Operation::Merge { .. } + | Operation::DataReplacement { .. } => Retryable, + Operation::Overwrite { .. } + | Operation::Restore { .. } + | Operation::Project { .. } => NotCompatible, + }, Operation::Project { .. } => match &other.operation { // Project is compatible with anything that doesn't change the schema - Operation::CreateIndex { .. } => false, - Operation::Overwrite { .. } => false, - Operation::UpdateConfig { .. } => false, - _ => true, + Operation::Append { .. } + | Operation::Update { .. } + | Operation::Delete { .. } + | Operation::UpdateConfig { .. } + | Operation::CreateIndex { .. } + | Operation::DataReplacement { .. } + | Operation::Rewrite { .. } + | Operation::ReserveFragments { .. } => Compatible, + Operation::Merge { .. } | Operation::Project { .. } => { + // Need to recompute the schema + Retryable + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, + }, + Operation::DataReplacement { .. } => match &other.operation { + Operation::Append { .. } + | Operation::Delete { .. } + | Operation::Update { .. } + | Operation::Merge { .. } + | Operation::UpdateConfig { .. } + | Operation::ReserveFragments { .. } + | Operation::Project { .. } => Compatible, + Operation::CreateIndex { .. } => { + // TODO(rmeng): check that the new indices isn't on the column being replaced + NotCompatible + } + Operation::Rewrite { .. } => { + // TODO(rmeng): check that the fragments being replaced are not part of the groups + NotCompatible + } + Operation::DataReplacement { .. } => { + // TODO(rmeng): check cell conflicts + NotCompatible + } + Operation::Overwrite { .. } | Operation::Restore { .. } => NotCompatible, }, } } @@ -656,7 +886,7 @@ impl Transaction { }); final_indices.extend(new_indices.clone()); } - Operation::ReserveFragments { .. } => { + Operation::ReserveFragments { .. } | Operation::UpdateConfig { .. } => { final_fragments.extend(maybe_existing_fragments?.clone()); } Operation::Merge { ref fragments, .. } => { @@ -690,7 +920,110 @@ impl Transaction { Operation::Restore { .. } => { unreachable!() } - Operation::UpdateConfig { .. } => {} + Operation::DataReplacement { replacements } => { + log::warn!("Building manifest with DataReplacement operation. This operation is not stable yet, please use with caution."); + + let (old_fragment_ids, new_datafiles): (Vec<&u64>, Vec<&DataFile>) = replacements + .iter() + .map(|DataReplacementGroup(fragment_id, new_file)| (fragment_id, new_file)) + .unzip(); + + // 1. make sure the new files all have the same fields / or empty + // NOTE: arguably this requirement could be relaxed in the future + // for the sake of simplicity, we require the new files to have the same fields + if new_datafiles + .iter() + .map(|f| f.fields.clone()) + .collect::>() + .len() + > 1 + { + let field_info = new_datafiles + .iter() + .enumerate() + .map(|(id, f)| (id, f.fields.clone())) + .fold("".to_string(), |acc, (id, fields)| { + format!("{}File {}: {:?}\n", acc, id, fields) + }); + + return Err(Error::invalid_input( + format!( + "All new data files must have the same fields, but found different fields:\n{field_info}" + ), + location!(), + )); + } + + let existing_fragments = maybe_existing_fragments?; + + // 2. check that the fragments being modified have isomorphic layouts along the columns being replaced + // 3. add modified fragments to final_fragments + for (frag_id, new_file) in old_fragment_ids.iter().zip(new_datafiles) { + let frag = existing_fragments + .iter() + .find(|f| f.id == **frag_id) + .ok_or_else(|| { + Error::invalid_input( + "Fragment being replaced not found in existing fragments", + location!(), + ) + })?; + let mut new_frag = frag.clone(); + + // TODO(rmeng): check new file and fragment are the same length + + let mut columns_covered = HashSet::new(); + for file in &mut new_frag.files { + if file.fields == new_file.fields + && file.file_major_version == new_file.file_major_version + && file.file_minor_version == new_file.file_minor_version + { + // assign the new file path to the fragment + file.path = new_file.path.clone(); + } + columns_covered.extend(file.fields.iter()); + } + // SPECIAL CASE: if the column(s) being replaced are not covered by the fragment + // Then it means it's a all-NULL column that is being replaced with real data + // just add it to the final fragments + if columns_covered.is_disjoint(&new_file.fields.iter().collect()) { + new_frag.add_file( + new_file.path.clone(), + new_file.fields.clone(), + new_file.column_indices.clone(), + &LanceFileVersion::try_from_major_minor( + new_file.file_major_version, + new_file.file_minor_version, + ) + .expect("Expected valid file version"), + ); + } + + // Nothing changed in the current fragment, which is not expected -- error out + if &new_frag == frag { + return Err(Error::invalid_input( + "Expected to modify the fragment but no changes were made. This means the new data files does not align with any exiting datafiles. Please check if the schema of the new data files matches the schema of the old data files including the file major and minor versions", + location!(), + )); + } + final_fragments.push(new_frag); + } + + let fragments_changed = old_fragment_ids + .iter() + .cloned() + .cloned() + .collect::>(); + + // 4. push fragments that didn't change back to final_fragments + let unmodified_fragments = existing_fragments + .iter() + .filter(|f| !fragments_changed.contains(&f.id)) + .cloned() + .collect::>(); + + final_fragments.extend(unmodified_fragments); + } }; // If a fragment was reserved then it may not belong at the end of the fragments list. @@ -748,6 +1081,8 @@ impl Transaction { Operation::UpdateConfig { upsert_values, delete_keys, + schema_metadata, + field_metadata, } => { // Delete is handled first. If the same key is referenced by upsert and // delete, then upserted key-value pair will remain. @@ -763,6 +1098,14 @@ impl Transaction { if let Some(upsert_values) = upsert_values { manifest.update_config(upsert_values.clone()); } + if let Some(schema_metadata) = schema_metadata { + manifest.update_schema_metadata(schema_metadata.clone()); + } + if let Some(field_metadata) = field_metadata { + for (field_id, metadata) in field_metadata { + manifest.update_field_metadata(*field_id as i32, metadata.clone()); + } + } } _ => {} } @@ -936,6 +1279,34 @@ impl Transaction { } } +impl From<&DataReplacementGroup> for pb::transaction::DataReplacementGroup { + fn from(DataReplacementGroup(fragment_id, new_file): &DataReplacementGroup) -> Self { + Self { + fragment_id: *fragment_id, + new_file: Some(new_file.into()), + } + } +} + +/// Convert a protobug DataReplacementGroup to a rust native DataReplacementGroup +/// this is unfortunately TryFrom instead of From because of the Option in the pb::DataReplacementGroup +impl TryFrom for DataReplacementGroup { + type Error = Error; + + fn try_from(message: pb::transaction::DataReplacementGroup) -> Result { + Ok(Self( + message.fragment_id, + message + .new_file + .ok_or(Error::invalid_input( + "DataReplacementGroup must have a new_file", + location!(), + ))? + .try_into()?, + )) + } +} + impl TryFrom for Transaction { type Error = Error; @@ -1068,6 +1439,8 @@ impl TryFrom for Transaction { Some(pb::transaction::Operation::UpdateConfig(pb::transaction::UpdateConfig { upsert_values, delete_keys, + schema_metadata, + field_metadata, })) => { let upsert_values = match upsert_values.len() { 0 => None, @@ -1077,11 +1450,36 @@ impl TryFrom for Transaction { 0 => None, _ => Some(delete_keys), }; + let schema_metadata = match schema_metadata.len() { + 0 => None, + _ => Some(schema_metadata), + }; + let field_metadata = match field_metadata.len() { + 0 => None, + _ => Some( + field_metadata + .into_iter() + .map(|(field_id, field_meta_update)| { + (field_id, field_meta_update.metadata) + }) + .collect(), + ), + }; Operation::UpdateConfig { upsert_values, delete_keys, + schema_metadata, + field_metadata, } } + Some(pb::transaction::Operation::DataReplacement( + pb::transaction::DataReplacement { replacements }, + )) => Operation::DataReplacement { + replacements: replacements + .into_iter() + .map(DataReplacementGroup::try_from) + .collect::>>()?, + }, None => { return Err(Error::Internal { message: "Transaction message did not contain an operation".to_string(), @@ -1275,10 +1673,37 @@ impl From<&Transaction> for pb::Transaction { Operation::UpdateConfig { upsert_values, delete_keys, + schema_metadata, + field_metadata, } => pb::transaction::Operation::UpdateConfig(pb::transaction::UpdateConfig { upsert_values: upsert_values.clone().unwrap_or(Default::default()), delete_keys: delete_keys.clone().unwrap_or(Default::default()), + schema_metadata: schema_metadata.clone().unwrap_or(Default::default()), + field_metadata: field_metadata + .as_ref() + .map(|field_metadata| { + field_metadata + .iter() + .map(|(field_id, metadata)| { + ( + *field_id, + pb::transaction::update_config::FieldMetadataUpdate { + metadata: metadata.clone(), + }, + ) + }) + .collect() + }) + .unwrap_or(Default::default()), }), + Operation::DataReplacement { replacements } => { + pb::transaction::Operation::DataReplacement(pb::transaction::DataReplacement { + replacements: replacements + .iter() + .map(pb::transaction::DataReplacementGroup::from) + .collect(), + }) + } }; let blob_operation = value.blobs_op.as_ref().map(|op| match op { @@ -1427,6 +1852,8 @@ mod tests { #[test] fn test_conflicts() { + use ConflictResult::*; + let index0 = Index { uuid: uuid::Uuid::new_v4(), name: "test".to_string(), @@ -1483,6 +1910,14 @@ mod tests { "value".to_string(), )])), delete_keys: Some(vec!["remove-key".to_string()]), + schema_metadata: Some(HashMap::from_iter(vec![( + "schema-key".to_string(), + "schema-value".to_string(), + )])), + field_metadata: Some(HashMap::from_iter(vec![( + 0, + HashMap::from_iter(vec![("field-key".to_string(), "field-value".to_string())]), + )])), }, ]; let other_transactions = other_operations @@ -1497,7 +1932,17 @@ mod tests { Operation::Append { fragments: vec![fragment0.clone()], }, - [false, false, false, true, true, false, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( Operation::Delete { @@ -1506,7 +1951,17 @@ mod tests { deleted_fragment_ids: vec![], predicate: "x > 2".to_string(), }, - [false, false, false, true, true, false, false, true, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Retryable, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Delete { @@ -1515,7 +1970,17 @@ mod tests { deleted_fragment_ids: vec![], predicate: "x > 2".to_string(), }, - [false, false, true, true, true, true, false, true, false], + [ + Compatible, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Overwrite { @@ -1525,9 +1990,7 @@ mod tests { }, // No conflicts: overwrite can always happen since it doesn't // depend on previous state of the table. - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( Operation::CreateIndex { @@ -1535,7 +1998,17 @@ mod tests { removed_indices: vec![index0], }, // Will only conflict with operations that modify row ids. - [false, false, false, false, true, true, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( // Rewrite that affects different fragments @@ -1546,7 +2019,17 @@ mod tests { }], rewritten_indices: Vec::new(), }, - [false, true, false, true, true, false, false, true, false], + [ + Compatible, // append + Retryable, // create index + Compatible, // delete + Retryable, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( // Rewrite that affects the same fragments @@ -1557,7 +2040,17 @@ mod tests { }], rewritten_indices: Vec::new(), }, - [false, true, true, true, true, true, false, true, false], + [ + Compatible, // append + Retryable, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::Merge { @@ -1565,12 +2058,32 @@ mod tests { schema: Schema::default(), }, // Merge conflicts with everything except CreateIndex and ReserveFragments. - [true, false, true, true, true, true, false, true, false], + [ + Retryable, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( Operation::ReserveFragments { num_fragments: 2 }, // ReserveFragments only conflicts with Overwrite and Restore. - [false, false, false, false, true, false, false, false, false], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ( Operation::Update { @@ -1579,7 +2092,17 @@ mod tests { removed_fragment_ids: vec![], new_fragments: vec![fragment2], }, - [false, false, true, true, true, true, false, true, false], + [ + Compatible, // append + Compatible, // create index + Retryable, // delete + Retryable, // merge + NotCompatible, // overwrite + Retryable, // rewrite + Compatible, // reserve + Retryable, // update + Compatible, // update config + ], ), ( // Update config that should not conflict with anything @@ -1589,10 +2112,10 @@ mod tests { "new-value".to_string(), )])), delete_keys: None, + schema_metadata: None, + field_metadata: None, }, - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( // Update config that conflicts with key being upserted by other UpdateConfig operation @@ -1602,8 +2125,20 @@ mod tests { "new-value".to_string(), )])), delete_keys: None, + schema_metadata: None, + field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Update config that conflicts with key being deleted by other UpdateConfig operation @@ -1613,26 +2148,127 @@ mod tests { "new-value".to_string(), )])), delete_keys: None, + schema_metadata: None, + field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], ), ( // Delete config keys currently being deleted by other UpdateConfig operation Operation::UpdateConfig { upsert_values: None, delete_keys: Some(vec!["remove-key".to_string()]), + schema_metadata: None, + field_metadata: None, }, - [ - false, false, false, false, false, false, false, false, false, - ], + [Compatible; 9], ), ( // Delete config keys currently being upserted by other UpdateConfig operation Operation::UpdateConfig { upsert_values: None, delete_keys: Some(vec!["lance.test".to_string()]), + schema_metadata: None, + field_metadata: None, + }, + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + Compatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], + ), + ( + // Changing schema metadata conflicts with another update changing schema + // metadata or with an overwrite + Operation::UpdateConfig { + upsert_values: None, + delete_keys: None, + schema_metadata: Some(HashMap::from_iter(vec![( + "schema-key".to_string(), + "new-value".to_string(), + )])), + field_metadata: None, }, - [false, false, false, false, false, false, false, false, true], + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], + ), + ( + // Changing field metadata conflicts with another update changing same field + // metadata or overwrite + Operation::UpdateConfig { + upsert_values: None, + delete_keys: None, + schema_metadata: None, + field_metadata: Some(HashMap::from_iter(vec![( + 0, + HashMap::from_iter(vec![( + "field_key".to_string(), + "field_value".to_string(), + )]), + )])), + }, + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + NotCompatible, // update config + ], + ), + ( + // Updates to different field metadata are allowed + Operation::UpdateConfig { + upsert_values: None, + delete_keys: None, + schema_metadata: None, + field_metadata: Some(HashMap::from_iter(vec![( + 1, + HashMap::from_iter(vec![( + "field_key".to_string(), + "field_value".to_string(), + )]), + )])), + }, + [ + Compatible, // append + Compatible, // create index + Compatible, // delete + Compatible, // merge + NotCompatible, // overwrite + Compatible, // rewrite + Compatible, // reserve + Compatible, // update + Compatible, // update config + ], ), ]; @@ -1642,13 +2278,9 @@ mod tests { assert_eq!( transaction.conflicts_with(other), *expected_conflict, - "Transaction {:?} should {} with {:?}", + "Transaction {:?} should {:?} with {:?}", transaction, - if *expected_conflict { - "conflict" - } else { - "not conflict" - }, + expected_conflict, other ); } diff --git a/rust/lance/src/dataset/updater.rs b/rust/lance/src/dataset/updater.rs index f12b201de88..750cfb6eec3 100644 --- a/rust/lance/src/dataset/updater.rs +++ b/rust/lance/src/dataset/updater.rs @@ -3,11 +3,12 @@ use arrow_array::{RecordBatch, UInt32Array}; use futures::StreamExt; +use lance_core::datatypes::{OnMissing, OnTypeMismatch}; use lance_core::utils::deletion::DeletionVector; use lance_core::{datatypes::Schema, Error, Result}; use lance_table::format::Fragment; use lance_table::utils::stream::ReadBatchFutStream; -use snafu::{location, Location}; +use snafu::location; use super::fragment::FragmentReader; use super::scanner::get_default_batch_size; @@ -182,12 +183,11 @@ impl Updater { final_schema.set_field_id(Some(self.fragment.dataset().manifest.max_field_id())); self.final_schema = Some(final_schema); self.final_schema.as_ref().unwrap().validate()?; - self.write_schema = Some( - self.final_schema - .as_ref() - .unwrap() - .project_by_schema(output_schema.as_ref())?, - ); + self.write_schema = Some(self.final_schema.as_ref().unwrap().project_by_schema( + output_schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?); } self.writer = Some( diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 600176fae64..5f90691c898 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -4,11 +4,18 @@ use std::sync::Arc; use arrow_array::RecordBatch; +use chrono::TimeDelta; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::SendableRecordBatchStream; -use futures::{StreamExt, TryStreamExt}; -use lance_core::datatypes::{NullabilityComparison, SchemaCompareOptions, StorageClass}; +use futures::{Stream, StreamExt, TryStreamExt}; +use lance_core::datatypes::{ + NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions, StorageClass, +}; +use lance_core::error::LanceOptionExt; +use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DATA, TRACE_FILE_AUDIT}; use lance_core::{datatypes::Schema, Error, Result}; use lance_datafusion::chunker::{break_stream, chunk_stream}; +use lance_datafusion::spill::{create_replay_spill, SpillReceiver, SpillSender}; use lance_datafusion::utils::StreamingWriteSource; use lance_file::v2; use lance_file::v2::writer::FileWriterOptions; @@ -19,8 +26,8 @@ use lance_table::format::{DataFile, Fragment}; use lance_table::io::commit::{commit_handler_from_url, CommitHandler}; use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; -use snafu::{location, Location}; -use tracing::instrument; +use snafu::location; +use tracing::{info, instrument}; use uuid::Uuid; use crate::session::Session; @@ -108,6 +115,22 @@ impl TryFrom<&str> for WriteMode { } } +/// Auto cleanup parameters +#[derive(Debug, Clone)] +pub struct AutoCleanupParams { + pub interval: usize, + pub older_than: TimeDelta, +} + +impl Default for AutoCleanupParams { + fn default() -> Self { + Self { + interval: 20, + older_than: TimeDelta::days(14), + } + } +} + /// Dataset Write Parameters #[derive(Debug, Clone)] pub struct WriteParams { @@ -167,9 +190,15 @@ pub struct WriteParams { /// Default is False. pub enable_v2_manifest_paths: bool, - pub object_store_registry: Arc, - pub session: Option>, + + /// If Some and this is a new dataset, old dataset versions will be + /// automatically cleaned up according to the parameters set out in + /// `AutoCleanupParams`. This parameter has no effect on existing datasets. + /// To add autocleaning to an existing dataset, use Dataset::update_config + /// to set lance.auto_cleanup.interval and lance.auto_cleanup.older_than. + /// Both parameters must be set to invoke autocleaning. + pub auto_cleanup: Option, } impl Default for WriteParams { @@ -187,8 +216,8 @@ impl Default for WriteParams { data_storage_version: None, enable_move_stable_row_ids: false, enable_v2_manifest_paths: false, - object_store_registry: Arc::new(ObjectStoreRegistry::default()), session: None, + auto_cleanup: Some(AutoCleanupParams::default()), } } } @@ -206,6 +235,13 @@ impl WriteParams { pub fn storage_version_or_default(&self) -> LanceFileVersion { self.data_storage_version.unwrap_or_default() } + + pub fn store_registry(&self) -> Arc { + self.session + .as_ref() + .map(|s| s.store_registry()) + .unwrap_or_default() + } } /// Writes the given data to the dataset and returns fragments. @@ -270,6 +306,7 @@ pub async fn do_write_fragments( || writer.as_mut().unwrap().tell().await? >= params.max_bytes_per_file as u64 { let (num_rows, data_file) = writer.take().unwrap().finish().await?; + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, type=AUDIT_TYPE_DATA, path = &data_file.path); debug_assert_eq!(num_rows, num_rows_in_current_file); params.progress.complete(fragments.last().unwrap()).await?; let last_fragment = fragments.last_mut().unwrap(); @@ -282,6 +319,7 @@ pub async fn do_write_fragments( // Complete the final writer if let Some(mut writer) = writer.take() { let (num_rows, data_file) = writer.finish().await?; + info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, type=AUDIT_TYPE_DATA, path = &data_file.path); let last_fragment = fragments.last_mut().unwrap(); last_fragment.physical_rows = Some(num_rows as usize); last_fragment.files.push(data_file); @@ -335,7 +373,11 @@ pub async fn write_fragments_internal( }, )?; // Project from the dataset schema, because it has the correct field ids. - let write_schema = dataset.schema().project_by_schema(&schema)?; + let write_schema = dataset.schema().project_by_schema( + &schema, + OnMissing::Error, + OnTypeMismatch::Error, + )?; // Use the storage version from the dataset, ignoring any version from the user. let data_storage_version = dataset .manifest() @@ -362,7 +404,11 @@ pub async fn write_fragments_internal( (schema, params.storage_version_or_default()) }; - let data_schema = schema.project_by_schema(data.schema().as_ref())?; + let data_schema = schema.project_by_schema( + data.schema().as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; let (data, blob_data) = data.extract_blob_stream(&data_schema); @@ -374,7 +420,6 @@ pub async fn write_fragments_internal( enable_move_stable_row_ids: true, // This shouldn't really matter since all commits are detached enable_v2_manifest_paths: true, - object_store_registry: params.object_store_registry.clone(), max_bytes_per_file: params.max_bytes_per_file, max_rows_per_file: params.max_rows_per_file, ..Default::default() @@ -578,6 +623,7 @@ async fn resolve_commit_handler( ) -> Result> { match commit_handler { None => { + #[allow(deprecated)] if store_options .as_ref() .map(|opts| opts.object_store.is_some()) @@ -601,6 +647,112 @@ async fn resolve_commit_handler( } } +/// Create an iterator of record batch streams from the given source. +/// +/// If `enable_retries` is true, then the source will be saved either in memory +/// or spilled to disk to allow replaying the source in case of a failure. The +/// source will be kept in memory if either (1) the size hint shows that +/// there is only one batch or (2) the stream contains less than 100MB of +/// data. Otherwise, the source will be spilled to a temporary file on disk. +/// +/// This is used to support retries on write operations. +async fn new_source_iter( + source: SendableRecordBatchStream, + enable_retries: bool, +) -> Result + Send + 'static>> { + if enable_retries { + let schema = source.schema(); + + // If size hint shows there is only one batch, spilling has no benefit, just keep that + // in memory. (This is a pretty common case.) + let size_hint = source.size_hint(); + if size_hint.0 == 1 && size_hint.1 == Some(1) { + let batches: Vec = source.try_collect().await?; + Ok(Box::new(std::iter::repeat_with(move || { + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::iter(batches.clone().into_iter().map(Ok)), + )) as SendableRecordBatchStream + }))) + } else { + // Allow buffering up to 100MB in memory before spilling to disk. + Ok(Box::new( + SpillStreamIter::try_new(source, 100 * 1024 * 1024).await?, + )) + } + } else { + Ok(Box::new(std::iter::once(source))) + } +} + +struct SpillStreamIter { + receiver: SpillReceiver, + #[allow(dead_code)] // Exists to keep the SpillSender alive + sender_handle: tokio::task::JoinHandle, + // This temp dir is used to store the spilled data. It is kept alive by + // this struct. When this struct is dropped, the Drop implementation of + // tempfile::TempDir will delete the temp dir. + #[allow(dead_code)] // Exists to keep the temp dir alive + tmp_dir: tempfile::TempDir, +} + +impl SpillStreamIter { + pub async fn try_new( + mut source: SendableRecordBatchStream, + memory_limit: usize, + ) -> Result { + let tmp_dir = tokio::task::spawn_blocking(|| { + tempfile::tempdir().map_err(|e| Error::InvalidInput { + source: format!("Failed to create temp dir: {}", e).into(), + location: location!(), + }) + }) + .await + .ok() + .expect_ok()??; + + let tmp_path = tmp_dir.path().join("spill.arrows"); + let (mut sender, receiver) = create_replay_spill(tmp_path, source.schema(), memory_limit); + + let sender_handle = tokio::task::spawn(async move { + while let Some(res) = source.next().await { + match res { + Ok(batch) => match sender.write(batch).await { + Ok(_) => {} + Err(e) => { + sender.send_error(e); + break; + } + }, + Err(e) => { + sender.send_error(e); + break; + } + } + } + + if let Err(err) = sender.finish().await { + sender.send_error(err); + } + sender + }); + + Ok(Self { + receiver, + tmp_dir, + sender_handle, + }) + } +} + +impl Iterator for SpillStreamIter { + type Item = SendableRecordBatchStream; + + fn next(&mut self) -> Option { + Some(self.receiver.read()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance/src/dataset/write/commit.rs b/rust/lance/src/dataset/write/commit.rs index 4b7fb4cea79..b80addd9f3f 100644 --- a/rust/lance/src/dataset/write/commit.rs +++ b/rust/lance/src/dataset/write/commit.rs @@ -4,12 +4,12 @@ use std::sync::Arc; use lance_file::version::LanceFileVersion; -use lance_io::object_store::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry}; +use lance_io::object_store::{ObjectStore, ObjectStoreParams}; use lance_table::{ format::{is_detached_version, DataStorageFormat}, io::commit::{CommitConfig, CommitHandler, ManifestNamingScheme}, }; -use snafu::{location, Location}; +use snafu::location; use crate::{ dataset::{ @@ -36,7 +36,6 @@ pub struct CommitBuilder<'a> { storage_format: Option, commit_handler: Option>, store_params: Option, - object_store_registry: Arc, object_store: Option>, session: Option>, detached: bool, @@ -52,7 +51,6 @@ impl<'a> CommitBuilder<'a> { storage_format: None, commit_handler: None, store_params: None, - object_store_registry: Default::default(), object_store: None, session: None, detached: false, @@ -104,17 +102,6 @@ impl<'a> CommitBuilder<'a> { self } - /// Pass an object store registry to use. - /// - /// If an object store is passed, this registry will be ignored. - pub fn with_object_store_registry( - mut self, - object_store_registry: Arc, - ) -> Self { - self.object_store_registry = object_store_registry; - self - } - /// Pass a session to use for the dataset. /// /// If a session is not passed, but a dataset is used as the destination, @@ -162,6 +149,11 @@ impl<'a> CommitBuilder<'a> { } pub async fn execute(self, transaction: Transaction) -> Result { + let session = self + .session + .or_else(|| self.dest.dataset().map(|ds| ds.session.clone())) + .unwrap_or_default(); + let (object_store, base_path, commit_handler) = match &self.dest { WriteDestination::Dataset(dataset) => ( dataset.object_store.clone(), @@ -170,12 +162,12 @@ impl<'a> CommitBuilder<'a> { ), WriteDestination::Uri(uri) => { let (object_store, base_path) = ObjectStore::from_uri_and_params( - self.object_store_registry.clone(), + session.store_registry(), uri, &self.store_params.clone().unwrap_or_default(), ) .await?; - let mut object_store = Arc::new(object_store); + let mut object_store = object_store; let commit_handler = if self.commit_handler.is_some() && self.object_store.is_some() { self.commit_handler.as_ref().unwrap().clone() @@ -190,11 +182,6 @@ impl<'a> CommitBuilder<'a> { } }; - let session = self - .session - .or_else(|| self.dest.dataset().map(|ds| ds.session.clone())) - .unwrap_or_default(); - let dest = match &self.dest { WriteDestination::Dataset(dataset) => WriteDestination::Dataset(dataset.clone()), WriteDestination::Uri(uri) => { @@ -203,7 +190,6 @@ impl<'a> CommitBuilder<'a> { .with_read_params(ReadParams { store_options: self.store_params.clone(), commit_handler: self.commit_handler.clone(), - object_store_registry: self.object_store_registry.clone(), ..Default::default() }) .with_session(session.clone()); @@ -273,7 +259,7 @@ impl<'a> CommitBuilder<'a> { ..Default::default() }; - let (manifest, manifest_file) = if let Some(dataset) = dest.dataset() { + let (manifest, manifest_file, manifest_e_tag) = if let Some(dataset) = dest.dataset() { if self.detached { if matches!(manifest_naming_scheme, ManifestNamingScheme::V1) { return Err(Error::NotSupported { @@ -332,6 +318,7 @@ impl<'a> CommitBuilder<'a> { manifest: Arc::new(manifest), manifest_file, session, + manifest_e_tag, ..dataset.as_ref().clone() }), WriteDestination::Uri(uri) => Ok(Dataset { @@ -344,6 +331,7 @@ impl<'a> CommitBuilder<'a> { commit_handler, tags, manifest_naming_scheme, + manifest_e_tag, }), } } @@ -421,11 +409,7 @@ pub struct BatchCommitResult { mod tests { use arrow::array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; - use lance_table::{ - format::{DataFile, Fragment}, - io::commit::RenameCommitHandler, - }; - use url::Url; + use lance_table::format::{DataFile, Fragment}; use crate::dataset::{InsertBuilder, WriteParams}; @@ -464,6 +448,7 @@ mod tests { // Need to use in-memory for accurate IOPS tracking. use crate::utils::test::IoTrackingStore; + let session = Arc::new(Session::default()); // Create new dataset let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( "i", @@ -475,17 +460,15 @@ mod tests { vec![Arc::new(Int32Array::from_iter_values(0..10_i32))], ) .unwrap(); - let memory_store = Arc::new(object_store::memory::InMemory::new()); let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); let store_params = ObjectStoreParams { object_store_wrapper: Some(io_stats_wrapper), - object_store: Some((memory_store.clone(), Url::parse("memory://test").unwrap())), ..Default::default() }; let dataset = InsertBuilder::new("memory://test") .with_params(&WriteParams { store_params: Some(store_params.clone()), - commit_handler: Some(Arc::new(RenameCommitHandler)), + session: Some(session.clone()), ..Default::default() }) .execute(vec![batch]) @@ -523,17 +506,15 @@ mod tests { // resolution. let (reads, writes) = get_new_iops(); assert_eq!(reads, 1, "i = {}", i); - // Should see 3 IOPs: + // Should see 2 IOPs: // 1. Write the transaction files - // 2. Write the manifest - // 3. Atomically rename the manifest - assert_eq!(writes, 3, "i = {}", i); + // 2. Write (conditional put) the manifest + assert_eq!(writes, 2, "i = {}", i); } // Commit transaction with URI and session let new_ds = CommitBuilder::new("memory://test") .with_store_params(store_params.clone()) - .with_commit_handler(Arc::new(RenameCommitHandler)) .with_session(dataset.session.clone()) .execute(sample_transaction(1)) .await @@ -544,12 +525,14 @@ mod tests { // are needed. let (reads, writes) = get_new_iops(); assert_eq!(reads, 3); - assert_eq!(writes, 3); + assert_eq!(writes, 2); - // Commit transaction with URI and no session + // Commit transaction with URI and new session. Re-use the store + // registry so we see the same store. + let new_session = Arc::new(Session::new(0, 0, session.store_registry())); let new_ds = CommitBuilder::new("memory://test") .with_store_params(store_params) - .with_commit_handler(Arc::new(RenameCommitHandler)) + .with_session(new_session) .execute(sample_transaction(1)) .await .unwrap(); @@ -557,7 +540,7 @@ mod tests { // Now we have to load all previous transactions. let (reads, writes) = get_new_iops(); assert!(reads > 20); - assert_eq!(writes, 3); + assert_eq!(writes, 2); } #[tokio::test] diff --git a/rust/lance/src/dataset/write/insert.rs b/rust/lance/src/dataset/write/insert.rs index 89bc008e28b..a73ecd9e0ea 100644 --- a/rust/lance/src/dataset/write/insert.rs +++ b/rust/lance/src/dataset/write/insert.rs @@ -1,11 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow_array::RecordBatch; use arrow_array::RecordBatchIterator; use datafusion::execution::SendableRecordBatchStream; +use humantime::format_duration; use lance_core::datatypes::NullabilityComparison; use lance_core::datatypes::Schema; use lance_core::datatypes::SchemaCompareOptions; @@ -15,7 +17,7 @@ use lance_io::object_store::ObjectStore; use lance_table::feature_flags::can_write_dataset; use lance_table::io::commit::CommitHandler; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use crate::dataset::builder::DatasetBuilder; use crate::dataset::transaction::Operation; @@ -119,7 +121,6 @@ impl<'a> InsertBuilder<'a> { let mut commit_builder = CommitBuilder::new(context.dest.clone()) .use_move_stable_row_ids(context.params.enable_move_stable_row_ids) .with_storage_format(context.storage_version) - .with_object_store_registry(context.params.object_store_registry.clone()) .enable_v2_manifest_paths(context.params.enable_v2_manifest_paths) .with_commit_handler(context.commit_handler.clone()) .with_object_store(context.object_store.clone()); @@ -203,7 +204,44 @@ impl<'a> InsertBuilder<'a> { context: &WriteContext<'_>, ) -> Result { let operation = match context.params.mode { - WriteMode::Create | WriteMode::Overwrite => Operation::Overwrite { + WriteMode::Create => { + // Fetch auto_cleanup params from context + let config_upsert_values = match context.params.auto_cleanup.as_ref() { + Some(auto_cleanup_params) => { + let mut upsert_values = HashMap::new(); + + upsert_values.insert( + String::from("lance.auto_cleanup.interval"), + auto_cleanup_params.interval.to_string(), + ); + + match auto_cleanup_params.older_than.to_std() { + Ok(d) => { + upsert_values.insert( + String::from("lance.auto_cleanup.older_than"), + format_duration(d).to_string(), + ); + } + Err(e) => { + return Err(Error::InvalidInput { + source: e.into(), + location: location!(), + }) + } + }; + + Some(upsert_values) + } + None => None, + }; + Operation::Overwrite { + // Use the full schema, not the written schema + schema, + fragments: written_frags.default.0, + config_upsert_values, + } + } + WriteMode::Overwrite => Operation::Overwrite { // Use the full schema, not the written schema schema, fragments: written_frags.default.0, @@ -321,8 +359,13 @@ impl<'a> InsertBuilder<'a> { dataset.commit_handler.clone(), ), WriteDestination::Uri(uri) => { + let registry = params + .session + .as_ref() + .map(|s| s.store_registry()) + .unwrap_or_else(|| Arc::new(Default::default())); let (object_store, base_path) = ObjectStore::from_uri_and_params( - params.object_store_registry.clone(), + registry, uri, ¶ms.store_params.clone().unwrap_or_default(), ) @@ -333,7 +376,7 @@ impl<'a> InsertBuilder<'a> { ¶ms.store_params, ) .await?; - (Arc::new(object_store), base_path, commit_handler) + (object_store, base_path, commit_handler) } }; let dest = match &self.dest { @@ -343,7 +386,6 @@ impl<'a> InsertBuilder<'a> { let builder = DatasetBuilder::from_uri(uri).with_read_params(ReadParams { store_options: params.store_params.clone(), commit_handler: params.commit_handler.clone(), - object_store_registry: params.object_store_registry.clone(), ..Default::default() }); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 16af301c8d9..57ac706794f 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -22,7 +22,8 @@ use std::{ }; use arrow_array::{ - cast::AsArray, types::UInt64Type, BooleanArray, RecordBatch, StructArray, UInt64Array, + cast::AsArray, types::UInt64Type, BooleanArray, RecordBatch, RecordBatchIterator, StructArray, + UInt64Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::{ @@ -30,19 +31,24 @@ use datafusion::{ context::{SessionConfig, SessionContext}, memory_pool::MemoryConsumer, }, - logical_expr::{Expr, JoinType}, + logical_expr::{self, Expr, JoinType}, physical_plan::{ joins::{HashJoinExec, PartitionMode}, + projection::ProjectionExec, repartition::RepartitionExec, stream::RecordBatchStreamAdapter, union::UnionExec, ColumnarValue, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream, }, + prelude::DataFrame, scalar::ScalarValue, }; use lance_arrow::{interleave_batches, RecordBatchExt, SchemaExt}; -use lance_datafusion::{chunker::chunk_stream, dataframe::DataFrameExt, exec::get_session_context}; +use lance_datafusion::{ + chunker::chunk_stream, dataframe::DataFrameExt, exec::get_session_context, + utils::reader_to_stream, +}; use datafusion_physical_expr::expressions::Column; use futures::{ @@ -50,9 +56,9 @@ use futures::{ Stream, StreamExt, TryStreamExt, }; use lance_core::{ - datatypes::SchemaCompareOptions, + datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions}, error::{box_error, InvalidInputSnafu}, - utils::{futures::Capacity, tokio::get_num_compute_intensive_cpus}, + utils::{backoff::Backoff, futures::Capacity, tokio::get_num_compute_intensive_cpus}, Error, Result, ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD, }; use lance_datafusion::{ @@ -64,7 +70,7 @@ use lance_index::DatasetIndexExt; use lance_table::format::{Fragment, Index}; use log::info; use roaring::RoaringTreemap; -use snafu::{location, Location, ResultExt}; +use snafu::{location, ResultExt}; use tokio::task::JoinSet; use crate::{ @@ -75,17 +81,13 @@ use crate::{ write::open_writer, }, index::DatasetIndexInternalExt, - io::{ - commit::commit_transaction, - exec::{ - project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner, - TakeExec, - }, + io::exec::{ + project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner, TakeExec, }, Dataset, }; -use super::{write_fragments_internal, WriteParams}; +use super::{write_fragments_internal, CommitBuilder, WriteParams}; // "update if" expressions typically compare fields from the source table to the target table. // These tables have the same schema and so filter expressions need to differentiate. To do that @@ -150,11 +152,15 @@ impl WhenNotMatchedBySource { let expr = planner .parse_filter(expr) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; let expr = planner .optimize_expr(expr) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; Ok(Self::DeleteIf(expr)) } } @@ -183,11 +189,15 @@ impl WhenMatched { let expr = planner .parse_filter(expr) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; let expr = planner .optimize_expr(expr) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; Ok(Self::UpdateIf(expr)) } } @@ -214,10 +224,12 @@ struct MergeInsertParams { insert_not_matched: bool, // Controls whether data that is not matched by the source is deleted or not delete_not_matched_by_source: WhenNotMatchedBySource, + conflict_retries: u32, } /// A MergeInsertJob inserts new rows, deletes old rows, and updates existing rows all as /// part of a single transaction. +#[derive(Clone)] pub struct MergeInsertJob { // The column to merge the new data into dataset: Arc, @@ -289,6 +301,7 @@ impl MergeInsertBuilder { when_matched: WhenMatched::DoNothing, insert_not_matched: true, delete_not_matched_by_source: WhenNotMatchedBySource::Keep, + conflict_retries: 10, }, }) } @@ -318,6 +331,18 @@ impl MergeInsertBuilder { self } + /// Set number of times to retry the operation if there is contention. + /// + /// If this is set > 0, then the operation will keep a copy of the input data + /// either in memory or on disk (depending on the size of the data) and will + /// retry the operation if there is contention. + /// + /// Default is 10. + pub fn conflict_retries(&mut self, retries: u32) -> &mut Self { + self.params.conflict_retries = retries; + self + } + /// Crate a merge insert job pub fn try_build(&mut self) -> Result { if !self.params.insert_not_matched @@ -447,12 +472,13 @@ impl MergeInsertJob { } // 4 - Take the mapped row ids - let mut target = Arc::new(TakeExec::try_new( - self.dataset.clone(), - index_mapper, - Arc::new(self.dataset.schema().project_by_schema(schema.as_ref())?), - get_num_compute_intensive_cpus(), - )?) as Arc; + let projection = self + .dataset + .empty_projection() + .union_arrow_schema(schema.as_ref(), OnMissing::Error)?; + let mut target = + Arc::new(TakeExec::try_new(self.dataset.clone(), index_mapper, projection)?.unwrap()) + as Arc; // 5 - Take puts the row id and row addr at the beginning. A full scan (used when there is // no scalar index) puts the row id and addr at the end. We need to match these up so @@ -498,9 +524,16 @@ impl MergeInsertJob { )?); } + // We need to prefix the fields in the target with target_ so that we don't have any duplicate + // field names (DF doesn't support this as of version 44) + target = Self::prefix_columns_phys(target, "target_"); + // 6 - Finally, join the input (source table) with the taken data (target table) let source_key = Column::new_with_schema(&index_column, shared_input.schema().as_ref())?; - let target_key = Column::new_with_schema(&index_column, target.schema().as_ref())?; + let target_key = Column::new_with_schema( + &format!("target_{}", index_column), + target.schema().as_ref(), + )?; let joined = Arc::new( HashJoinExec::try_new( shared_input, @@ -523,6 +556,38 @@ impl MergeInsertJob { ) } + fn prefix_columns(df: DataFrame, prefix: &str) -> DataFrame { + let schema = df.schema(); + let columns = schema + .fields() + .iter() + .map(|f| { + // Need to "quote" the column name so it gets interpreted case-sensitively + logical_expr::col(format!("\"{}\"", f.name())).alias(format!( + "{}{}", + prefix, + f.name() + )) + }) + .collect::>(); + df.select(columns).unwrap() + } + + fn prefix_columns_phys(inp: Arc, prefix: &str) -> Arc { + let schema = inp.schema(); + let exprs = schema + .fields() + .iter() + .enumerate() + .map(|(idx, f)| { + let col = Arc::new(Column::new(f.name(), idx)) as Arc; + let new_name = format!("{}{}", prefix, f.name()); + (col, new_name) + }) + .collect::>(); + Arc::new(ProjectionExec::try_new(exprs, inp).unwrap()) + } + // If the join keys are not indexed then we need to do a full scan of the table async fn create_full_table_joined_stream( &self, @@ -538,12 +603,21 @@ impl MergeInsertJob { .iter() .map(|c| c.as_str()) .collect::>(); // vector of strings of col names to join + let target_cols = self + .params + .on + .iter() + .map(|c| format!("target_{}", c)) + .collect::>(); + let target_cols = target_cols.iter().map(|s| s.as_str()).collect::>(); match self.check_compatible_schema(&schema)? { SchemaComparison::FullCompatible => { let existing = session_ctx.read_lance(self.dataset.clone(), true, false)?; + // We need to rename the columns from the target table so that they don't conflict with the source table + let existing = Self::prefix_columns(existing, "target_"); let joined = - new_data.join(existing, JoinType::Full, &join_cols, &join_cols, None)?; // full join + new_data.join(existing, JoinType::Full, &join_cols, &target_cols, None)?; // full join Ok(joined.execute_stream().await?) } SchemaComparison::Subschema => { @@ -555,14 +629,27 @@ impl MergeInsertJob { .chain([ROW_ID, ROW_ADDR]) .collect::>(); let projected = existing.select_columns(&columns)?; + // We need to rename the columns from the target table so that they don't conflict with the source table + let projected = Self::prefix_columns(projected, "target_"); // We aren't supporting inserts or deletes right now, so we can use inner join - let joined = - new_data.join(projected, JoinType::Inner, &join_cols, &join_cols, None)?; + let join_type = if self.params.insert_not_matched { + JoinType::Left + } else { + JoinType::Inner + }; + let joined = new_data.join(projected, join_type, &join_cols, &target_cols, None)?; Ok(joined.execute_stream().await?) } } } + /// Join the source and target data streams + /// + /// If there is a scalar index on the join key, we can use it to do an indexed join. Otherwise we need to do + /// a full outer join. + /// + /// Datafusion doesn't allow duplicate column names so during this join we rename the columns from target and + /// prefix them with _target. async fn create_joined_stream( &self, source: SendableRecordBatchStream, @@ -588,46 +675,30 @@ impl MergeInsertJob { async fn update_fragments( dataset: Arc, source: SendableRecordBatchStream, - ) -> Result> { + ) -> Result<(Vec, Vec)> { // Expected source schema: _rowaddr, updated_cols* use datafusion::logical_expr::{col, lit}; - let session_ctx = get_session_context(LanceExecutionOptions { + let session_ctx = get_session_context(&LanceExecutionOptions { use_spilling: true, + target_partition: Some(get_num_compute_intensive_cpus().min(8)), ..Default::default() }); let mut group_stream = session_ctx .read_one_shot(source)? - .sort(vec![col(ROW_ADDR).sort(true, true)])? .with_column("_fragment_id", col(ROW_ADDR) >> lit(32))? + .sort(vec![col(ROW_ADDR).sort(true, true)])? .group_by_stream(&["_fragment_id"]) .await?; // Can update the fragments in parallel. let updated_fragments = Arc::new(Mutex::new(Vec::new())); + let new_fragments = Arc::new(Mutex::new(Vec::new())); let mut tasks = JoinSet::new(); - let task_limit = get_num_compute_intensive_cpus(); + let task_limit = dataset.object_store().io_parallelism(); let mut reservation = MemoryConsumer::new("MergeInsert").register(session_ctx.task_ctx().memory_pool()); - while let Some((frag_id, batches)) = group_stream.next().await.transpose()? { - let Some(ScalarValue::UInt64(Some(frag_id))) = frag_id.first() else { - return Err(Error::Internal { - message: format!("Got non-fragment id from merge result: {:?}", frag_id), - location: location!(), - }); - }; - let frag_id = *frag_id; - let fragment = - dataset - .get_fragment(frag_id as usize) - .ok_or_else(|| Error::Internal { - message: format!( - "Got non-existent fragment id from merge result: {}", - frag_id - ), - location: location!(), - })?; - let metadata = fragment.metadata.clone(); + while let Some((frag_id, batches)) = group_stream.next().await.transpose()? { async fn handle_fragment( dataset: Arc, fragment: FileFragment, @@ -638,7 +709,11 @@ impl MergeInsertJob { ) -> Result { // batches still have _rowaddr let write_schema = batches[0].schema().as_ref().without_column(ROW_ADDR); - let write_schema = dataset.local_schema().project_by_schema(&write_schema)?; + let write_schema = dataset.local_schema().project_by_schema( + &write_schema, + OnMissing::Error, + OnTypeMismatch::Error, + )?; let updated_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum(); if Some(updated_rows) == metadata.physical_rows { @@ -677,7 +752,6 @@ impl MergeInsertJob { .open( dataset.schema(), FragReadConfig::default().with_row_address(true), - None, ) .await?; let batch_size = reader.legacy_num_rows_in_batch(0).unwrap(); @@ -787,6 +861,48 @@ impl MergeInsertJob { Ok(reservation_size) } + async fn handle_new_fragments( + dataset: Arc, + batches: Vec, + new_fragments: Arc>>, + reservation_size: usize, + ) -> Result { + // Batches still have _rowaddr (used elsewhere to merge with existing data) + // We need to remove it before writing to Lance files. + let num_fields = batches[0].schema().fields().len(); + let mut projection = Vec::with_capacity(num_fields - 1); + for (i, field) in batches[0].schema().fields().iter().enumerate() { + if field.name() != ROW_ADDR { + projection.push(i); + } + } + let write_schema = Arc::new(batches[0].schema().project(&projection).unwrap()); + + let batches = batches + .into_iter() + .map(move |batch| batch.project(&projection)); + let reader = RecordBatchIterator::new(batches, write_schema.clone()); + let stream = reader_to_stream(Box::new(reader)); + + let write_schema = dataset.schema().project_by_schema( + write_schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; + + let fragments = write_fragments_internal( + Some(dataset.as_ref()), + dataset.object_store.clone(), + &dataset.base, + write_schema, + stream, + Default::default(), // TODO: support write params. + ) + .await?; + + new_fragments.lock().unwrap().extend(fragments.default.0); + Ok(reservation_size) + } // We shouldn't need much more memory beyond what is already in the batches. let mut memory_size = batches .iter() @@ -813,15 +929,47 @@ impl MergeInsertJob { } } - let fut = handle_fragment( - dataset.clone(), - fragment, - metadata, - batches, - updated_fragments.clone(), - memory_size, - ); - tasks.spawn(fut); + match frag_id.first() { + Some(ScalarValue::UInt64(Some(frag_id))) => { + let frag_id = *frag_id; + let fragment = + dataset + .get_fragment(frag_id as usize) + .ok_or_else(|| Error::Internal { + message: format!( + "Got non-existent fragment id from merge result: {}", + frag_id + ), + location: location!(), + })?; + let metadata = fragment.metadata.clone(); + + let fut = handle_fragment( + dataset.clone(), + fragment, + metadata, + batches, + updated_fragments.clone(), + memory_size, + ); + tasks.spawn(fut); + } + Some(ScalarValue::Null | ScalarValue::UInt64(None)) => { + let fut = handle_new_fragments( + dataset.clone(), + batches, + new_fragments.clone(), + memory_size, + ); + tasks.spawn(fut); + } + _ => { + return Err(Error::Internal { + message: format!("Got non-fragment id from merge result: {:?}", frag_id), + location: location!(), + }); + } + }; } while let Some(res) = tasks.join_next().await { @@ -847,7 +995,12 @@ impl MergeInsertJob { } } - Ok(updated_fragments) + let new_fragments = Arc::try_unwrap(new_fragments) + .unwrap() + .into_inner() + .unwrap(); + + Ok((updated_fragments, new_fragments)) } /// Executes the merge insert job @@ -855,16 +1008,69 @@ impl MergeInsertJob { /// This will take in the source, merge it with the existing target data, and insert new /// rows, update existing rows, and delete existing rows pub async fn execute( - self, + mut self, source: SendableRecordBatchStream, ) -> Result<(Arc, MergeStats)> { - let schema = source.schema(); + let mut source_iter = + super::new_source_iter(source, self.params.conflict_retries > 0).await?; + + let mut dataset_ref = self.dataset.clone(); + let max_retries = self.params.conflict_retries; + let mut backoff = Backoff::default(); + while backoff.attempt() <= max_retries { + let ds = dataset_ref.clone(); + let (transaction, stats) = self + .clone() + .execute_uncommitted_impl(source_iter.next().unwrap()) + .await?; + match CommitBuilder::new(ds).execute(transaction).await { + Ok(ds) => return Ok((Arc::new(ds), stats)), + Err(Error::RetryableCommitConflict { .. }) => { + tokio::time::sleep(backoff.next_backoff()).await; + let mut ds = dataset_ref.as_ref().clone(); + ds.checkout_latest().await?; + dataset_ref = Arc::new(ds); + self.dataset = dataset_ref.clone(); + continue; + } + Err(e) => return Err(e), + }; + } + Err(Error::TooMuchWriteContention { + message: format!("Attempted {} retries.", max_retries), + location: location!(), + }) + } + + /// Execute the merge insert job without committing the changes. + /// + /// Use [`CommitBuilder`] to commit the returned transaction. + pub async fn execute_uncommitted( + self, + source: impl StreamingWriteSource, + ) -> Result<(Transaction, MergeStats)> { + let stream = source.into_stream(); + self.execute_uncommitted_impl(stream).await + } - let full_schema = Schema::from(self.dataset.local_schema()); - let is_full_schema = &full_schema == schema.as_ref(); + async fn execute_uncommitted_impl( + self, + source: SendableRecordBatchStream, + ) -> Result<(Transaction, MergeStats)> { + // Erase metadata on source / dataset schemas to avoid comparing metadata + let schema = lance_core::datatypes::Schema::try_from(source.schema().as_ref())?; + let full_schema = self.dataset.local_schema(); + let is_full_schema = full_schema.compare_with_options( + &schema, + &SchemaCompareOptions { + compare_metadata: false, + ..Default::default() + }, + ); + let source_schema = source.schema(); let joined = self.create_joined_stream(source).await?; - let merger = Merger::try_new(self.params.clone(), schema.clone(), !is_full_schema)?; + let merger = Merger::try_new(self.params.clone(), source_schema, !is_full_schema)?; let merge_statistics = merger.merge_stats.clone(); let deleted_rows = merger.deleted_rows.clone(); let merger_schema = merger.output_schema().clone(); @@ -873,14 +1079,7 @@ impl MergeInsertJob { .try_flatten(); let stream = RecordBatchStreamAdapter::new(merger_schema, stream); - let committed_ds = if !is_full_schema { - if self.params.insert_not_matched { - return Err(Error::NotSupported { - source: "The merge insert operation is configured to not insert new rows, but the source data has a different schema than the target data".into(), - location: location!(), - }); - } - + let operation = if !is_full_schema { if !matches!( self.params.delete_not_matched_by_source, WhenNotMatchedBySource::Keep @@ -891,10 +1090,14 @@ impl MergeInsertJob { // We will have a different commit path here too, as we are modifying // fragments rather than writing new ones - let updated_fragments = + let (updated_fragments, new_fragments) = Self::update_fragments(self.dataset.clone(), Box::pin(stream)).await?; - Self::commit(self.dataset, Vec::new(), updated_fragments, Vec::new()).await? + Operation::Update { + removed_fragment_ids: Vec::new(), + updated_fragments, + new_fragments, + } } else { let written = write_fragments_internal( Some(&self.dataset), @@ -916,13 +1119,11 @@ impl MergeInsertJob { Self::apply_deletions(&self.dataset, &removed_row_ids).await?; // Commit updated and new fragments - Self::commit( - self.dataset, + Operation::Update { removed_fragment_ids, - old_fragments, + updated_fragments: old_fragments, new_fragments, - ) - .await? + } }; let stats = Arc::into_inner(merge_statistics) @@ -930,7 +1131,14 @@ impl MergeInsertJob { .into_inner() .unwrap(); - Ok((committed_ds, stats)) + let transaction = Transaction::new( + self.dataset.manifest.version, + operation, + /*blobs_op=*/ None, + None, + ); + + Ok((transaction, stats)) } // Delete a batch of rows by id, returns the fragments modified and the fragments removed @@ -979,43 +1187,6 @@ impl MergeInsertJob { Ok((updated_fragments, removed_fragments)) } - - // Commit the operation - async fn commit( - dataset: Arc, - removed_fragment_ids: Vec, - updated_fragments: Vec, - new_fragments: Vec, - ) -> Result> { - let operation = Operation::Update { - removed_fragment_ids, - updated_fragments, - new_fragments, - }; - let transaction = Transaction::new( - dataset.manifest.version, - operation, - /*blobs_op=*/ None, - None, - ); - - let (manifest, manifest_path) = commit_transaction( - dataset.as_ref(), - dataset.object_store(), - dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - dataset.manifest_naming_scheme, - ) - .await?; - - let mut dataset = dataset.as_ref().clone(); - dataset.manifest = Arc::new(manifest); - dataset.manifest_file = manifest_path; - - Ok(Arc::new(dataset)) - } } /// Merger will store these statistics as it runs (for each batch) @@ -1249,10 +1420,14 @@ impl Merger { } if self.params.insert_not_matched { let not_matched = arrow::compute::filter_record_batch(&batch, &left_only)?; - let not_matched = not_matched.project(&left_cols)?; + let left_cols_with_id = left_cols + .into_iter() + .chain(row_addr_col) + .collect::>(); + let not_matched = not_matched.project(&left_cols_with_id)?; // See comment above explaining this schema replacement let not_matched = RecordBatch::try_new( - self.schema.clone(), + self.output_schema.clone(), Vec::from_iter(not_matched.columns().iter().cloned()), )?; @@ -1307,12 +1482,14 @@ mod tests { }; use arrow_select::concat::concat_batches; use datafusion::common::Column; + use futures::future::try_join_all; use lance_datafusion::utils::reader_to_stream; use lance_datagen::{array, BatchCount, RowCount, Seed}; use lance_index::{scalar::ScalarIndexParams, IndexType}; use tempfile::tempdir; + use tokio::sync::{Barrier, Notify}; - use crate::dataset::{WriteMode, WriteParams}; + use crate::dataset::{builder::DatasetBuilder, InsertBuilder, WriteMode, WriteParams}; use super::*; @@ -1683,10 +1860,7 @@ mod tests { // Check that the data is as expected let updated = ds - .scan() - .filter("value = 9999999") - .unwrap() - .count_rows() + .count_rows(Some("value = 9999999".to_string())) .await .unwrap(); assert_eq!(updated, 2048); @@ -1745,9 +1919,10 @@ mod tests { .col("other", array::rand_utf8(4.into(), false)) .col("value", array::step::()) .col("key", array::rand_pseudo_uuid_hex()); - let batch = data.into_batch_rows(RowCount::from(1024)).unwrap(); + let batch = data.into_batch_rows(RowCount::from(1024 + 2)).unwrap(); let batch1 = batch.slice(0, 512); let batch2 = batch.slice(512, 512); + let batch3 = batch.slice(1024, 2); let schema = batch.schema(); let reader = Box::new(RecordBatchIterator::new( @@ -1770,7 +1945,7 @@ mod tests { .unwrap(); } - // Another two batches, not in the scalar index (if there is one) + // Another two files, not in the scalar index (if there is one) let reader = Box::new(RecordBatchIterator::new( [Ok(batch2.clone())], batch2.schema(), @@ -1781,14 +1956,16 @@ mod tests { // New data with only a subset of columns let update_schema = Arc::new(schema.project(&[2, 1]).unwrap()); - // Full second file and part of third file. + // Full second file and part of third file. Also two more new rows. let indices: Int64Array = (256..512).chain(600..612).chain([712, 715]).collect(); let keys = arrow::compute::take(batch["key"].as_ref(), &indices, None).unwrap(); + let keys = arrow::compute::concat(&[&keys, &batch3["key"]]).unwrap(); + let num_rows = keys.len(); let new_data = RecordBatch::try_new( update_schema, vec![ keys, - Arc::new((1000..(1000 + indices.len() as u32)).collect::()), + Arc::new((1024..(1024 + num_rows as u32)).collect::()), ], ) .unwrap(); @@ -1825,30 +2002,6 @@ mod tests { ); } - #[tokio::test] - async fn test_insert_not_supported() { - let Fixtures { ds, new_data } = setup(false).await; - - let reader = Box::new(RecordBatchIterator::new( - [Ok(new_data.clone())], - new_data.schema(), - )); - - // Should reject when_not_matched_insert_all as not yet supported - let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) - .unwrap() - .when_not_matched(WhenNotMatched::InsertAll) - .when_matched(WhenMatched::UpdateAll) - .try_build() - .unwrap(); - let res = job.execute_reader(reader).await; - assert!(matches!( - res, - Err(Error::NotSupported { source, .. }) - if source.to_string().contains("The merge insert operation is configured to not insert new rows, but the source data has a different schema than the target data") - )); - } - #[tokio::test] async fn test_errors_on_bad_schema() { let Fixtures { ds, new_data } = setup(false).await; @@ -1884,7 +2037,10 @@ mod tests { #[rstest] #[tokio::test] - async fn test_merge_insert_subcols(#[values(false, true)] scalar_index: bool) { + async fn test_merge_insert_subcols( + #[values(false, true)] scalar_index: bool, + #[values(false, true)] insert: bool, + ) { let Fixtures { ds, new_data } = setup(scalar_index).await; let reader = Box::new(RecordBatchIterator::new( [Ok(new_data.clone())], @@ -1898,7 +2054,11 @@ mod tests { let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) .unwrap() .when_matched(WhenMatched::UpdateAll) - .when_not_matched(WhenNotMatched::DoNothing) + .when_not_matched(if insert { + WhenNotMatched::InsertAll + } else { + WhenNotMatched::DoNothing + }) .try_build() .unwrap(); @@ -1912,9 +2072,13 @@ mod tests { .collect::>(); assert_eq!( fragments_before.iter().map(|f| f.id).collect::>(), - fragments_after.iter().map(|f| f.id).collect::>() + fragments_after + .iter() + .take(fragments_before.len()) + .map(|f| f.id) + .collect::>() ); - // Only the second fragment should be different. + // Only the second and third fragment should be different. assert_eq!(fragments_before[0], fragments_after[0]); assert_ne!(fragments_before[1], fragments_after[1]); assert_ne!(fragments_before[2], fragments_after[2]); @@ -1931,8 +2095,15 @@ mod tests { has_added_files(&fragments_after[1]); has_added_files(&fragments_after[2]); - assert_eq!(stats.num_inserted_rows, 0); - assert_eq!(stats.num_updated_rows, new_data.num_rows() as u64); + if insert { + assert_eq!(fragments_after.len(), 5); + assert_eq!(stats.num_inserted_rows, 2); + } else { + assert_eq!(fragments_after.len(), 4); + assert_eq!(stats.num_inserted_rows, 0); + } + + assert_eq!(stats.num_updated_rows, (new_data.num_rows() - 2) as u64); assert_eq!(stats.num_deleted_rows, 0); let data = ds @@ -1941,7 +2112,7 @@ mod tests { .try_into_batch() .await .unwrap(); - assert_eq!(data.num_rows(), 1024); + assert_eq!(data.num_rows(), if insert { 1024 + 2 } else { 1024 }); assert_eq!(data.num_columns(), 3); let values = data @@ -1950,9 +2121,191 @@ mod tests { .downcast_ref::() .unwrap(); assert_eq!(values.value(0), 0); - assert_eq!(values.value(256), 1_000); + assert_eq!(values.value(256), 1024); assert_eq!(values.value(512), 512); - assert_eq!(values.value(715), 1_000 + new_data.num_rows() as u32 - 1); + assert_eq!(values.value(715), 1024 + new_data.num_rows() as u32 - 3); + if insert { + assert_eq!(values.value(1024), 1024 + new_data.num_rows() as u32 - 2); + } + } + } + + #[tokio::test] + async fn test_merge_insert_concurrency() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])); + let num_rows = 10; + let initial_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..num_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + num_rows as usize, + ))), + ], + ) + .unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let mut dataset = InsertBuilder::new(test_uri) + .execute(vec![initial_data]) + .await + .unwrap(); + + // do 10 merge inserts in parallel. Each will open the dataset, signal + // they have opened, and then wait for a signal to proceed. Once the signal + // is received, they will do a merge insert and close the dataset. + + let barrier = Arc::new(Barrier::new(10)); + let mut handles = Vec::new(); + for i in 0..10 { + let uri_ref = test_uri.to_string(); + let schema_ref = schema.clone(); + let barrier_ref = barrier.clone(); + let handle = tokio::task::spawn(async move { + let dataset = DatasetBuilder::from_uri(&uri_ref).load().await.unwrap(); + let dataset = Arc::new(dataset); + + let new_data = RecordBatch::try_new( + schema_ref.clone(), + vec![ + Arc::new(UInt32Array::from(vec![i])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let source = Box::new(RecordBatchIterator::new([Ok(new_data)], schema_ref.clone())); + + let job = MergeInsertBuilder::try_new(dataset, vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap(); + barrier_ref.wait().await; + + job.execute_reader(source).await.unwrap(); + }); + handles.push(handle); } + + try_join_all(handles).await.unwrap(); + + dataset.checkout_latest().await.unwrap(); + let batches = dataset.scan().try_into_batch().await.unwrap(); + + let values = batches["value"].as_primitive::(); + assert!( + values.values().iter().all(|&v| v == 1), + "All values should be 1 after merge insert. Got: {:?}", + values + ); + } + + #[tokio::test] + async fn test_merge_insert_large_concurrent() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])); + let num_rows = 10; + let initial_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..num_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + num_rows as usize, + ))), + ], + ) + .unwrap(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let dataset = InsertBuilder::new(test_uri) + .execute(vec![initial_data]) + .await + .unwrap(); + let dataset = Arc::new(dataset); + + // Start one merge insert, but don't commit it yet. + let new_data1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let (transaction1, _stats) = + MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_uncommitted(RecordBatchIterator::new( + vec![Ok(new_data1)], + schema.clone(), + )) + .await + .unwrap(); + + // Setup a "large" merge insert, with many batches + let new_data2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..1000)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n(2, 1000))), + ], + ) + .unwrap(); + let notify = Arc::new(Notify::new()); + let source = RecordBatchIterator::new( + (0..10) + .map(|i| { + let batch = new_data2.slice(i * 100, 100); + if i == 9 { + notify.notify_one(); + } + Ok(batch) + }) + .collect::>(), + schema.clone(), + ); + let dataset2 = DatasetBuilder::from_uri(test_uri).load().await.unwrap(); + let job = MergeInsertBuilder::try_new(Arc::new(dataset2), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_reader(source); + let task = tokio::task::spawn(job); + + // Right as the large merge insert has finished reading the last batch, + // we will commit the first merge insert. This should trigger a conflict, + // but we should resolve it automatically. + notify.notified().await; + let mut dataset = CommitBuilder::new(dataset) + .execute(transaction1) + .await + .unwrap(); + + task.await.unwrap().unwrap(); + dataset.checkout_latest().await.unwrap(); + + let batches = dataset.scan().try_into_batch().await.unwrap(); + let values = batches["value"].as_primitive::(); + assert!( + values.values().iter().all(|&v| v == 2), + "All values should be 1 after merge insert. Got: {:?}", + values + ); } } diff --git a/rust/lance/src/dataset/write/update.rs b/rust/lance/src/dataset/write/update.rs index b15a3289670..a620d8805de 100644 --- a/rust/lance/src/dataset/write/update.rs +++ b/rust/lance/src/dataset/write/update.rs @@ -22,10 +22,9 @@ use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_datafusion::expr::safe_coerce_scalar; use lance_table::format::Fragment; use roaring::RoaringTreemap; -use snafu::{location, Location, ResultExt}; +use snafu::{location, ResultExt}; use crate::dataset::transaction::{Operation, Transaction}; -use crate::io::commit::commit_transaction; use crate::{io::exec::Planner, Dataset}; use crate::{Error, Result}; @@ -69,13 +68,14 @@ impl UpdateBuilder { let expr = planner .parse_filter(filter) .map_err(box_error) - .context(InvalidInputSnafu)?; - self.condition = Some( - planner - .optimize_expr(expr) - .map_err(box_error) - .context(InvalidInputSnafu)?, - ); + .context(InvalidInputSnafu { + location: location!(), + })?; + self.condition = Some(planner.optimize_expr(expr).map_err(box_error).context( + InvalidInputSnafu { + location: location!(), + }, + )?); Ok(self) } @@ -113,7 +113,9 @@ impl UpdateBuilder { let mut expr = planner .parse_expr(value) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; // Cast expression to the column's data type if necessary. let dest_type = field.data_type(); @@ -121,7 +123,9 @@ impl UpdateBuilder { let src_type = expr .get_type(&df_schema) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; if dest_type != src_type { expr = match expr { // TODO: remove this branch once DataFusion supports casting List to FSL @@ -140,7 +144,9 @@ impl UpdateBuilder { _ => expr .cast_to(&dest_type, &df_schema) .map_err(box_error) - .context(InvalidInputSnafu)?, + .context(InvalidInputSnafu { + location: location!(), + })?, }; } @@ -150,7 +156,9 @@ impl UpdateBuilder { let expr = planner .optimize_expr(expr) .map_err(box_error) - .context(InvalidInputSnafu)?; + .context(InvalidInputSnafu { + location: location!(), + })?; self.updates.insert(column.as_ref().to_string(), expr); Ok(self) @@ -373,20 +381,10 @@ impl UpdateJob { None, ); - let (manifest, manifest_path) = commit_transaction( - self.dataset.as_ref(), - self.dataset.object_store(), - self.dataset.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.dataset.manifest_naming_scheme, - ) - .await?; - let mut dataset = self.dataset.as_ref().clone(); - dataset.manifest = Arc::new(manifest); - dataset.manifest_file = manifest_path; + dataset + .apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(Arc::new(dataset)) } @@ -420,9 +418,9 @@ mod tests { schema.clone(), vec![ Arc::new(Int64Array::from_iter_values(0..30)), - Arc::new(StringArray::from_iter_values( - std::iter::repeat("foo").take(30), - )), + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "foo", 30, + ))), ], ) .unwrap(); diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index ace9906d5cc..431c9528490 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -5,23 +5,29 @@ //! use std::collections::{HashMap, HashSet}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow_schema::DataType; +use arrow_schema::{DataType, Schema}; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use futures::{stream, StreamExt, TryStreamExt}; use itertools::Itertools; +use lance_core::utils::parse::str_is_truthy; +use lance_core::utils::tracing::{IO_TYPE_OPEN_SCALAR, IO_TYPE_OPEN_VECTOR, TRACE_IO_EVENTS}; use lance_file::reader::FileReader; use lance_file::v2; use lance_file::v2::reader::FileReaderOptions; +use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; use lance_index::optimize::OptimizeOptions; use lance_index::pb::index::Implementation; use lance_index::scalar::expression::{ IndexInformationProvider, LabelListQueryParser, SargableQueryParser, ScalarQueryParser, + TextQueryParser, }; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::{InvertedIndexParams, ScalarIndex, ScalarIndexType}; -use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::sq::ScalarQuantizer; @@ -44,10 +50,11 @@ use lance_table::io::manifest::read_manifest_indexes; use roaring::RoaringBitmap; use scalar::{build_inverted_index, detect_scalar_index_type, inverted_index_details}; use serde_json::json; -use snafu::{location, Location}; -use tracing::instrument; +use snafu::location; +use tracing::{info, instrument}; use uuid::Uuid; use vector::ivf::v2::IVFIndex; +use vector::utils::get_vector_type; pub(crate) mod append; pub(crate) mod cache; @@ -60,13 +67,23 @@ pub use crate::index::prefilter::{FilterLoader, PreFilter}; use crate::dataset::transaction::{Operation, Transaction}; use crate::index::vector::remap_vector_index; -use crate::io::commit::commit_transaction; use crate::{dataset::Dataset, Error, Result}; use self::append::merge_indices; use self::scalar::build_scalar_index; use self::vector::{build_vector_index, VectorIndexParams, LANCE_VECTOR_INDEX}; +// Whether to auto-migrate a dataset when we encounter corruption. +fn auto_migrate_corruption() -> bool { + static LANCE_AUTO_MIGRATION: OnceLock = OnceLock::new(); + *LANCE_AUTO_MIGRATION.get_or_init(|| { + std::env::var("LANCE_AUTO_MIGRATION") + .ok() + .map(|s| str_is_truthy(&s)) + .unwrap_or(true) + }) +} + /// Builds index. #[async_trait] pub trait IndexBuilder { @@ -106,21 +123,15 @@ pub(crate) async fn remap_index( let new_id = Uuid::new_v4(); let generic = dataset - .open_generic_index(&field.name, &index_id.to_string()) + .open_generic_index(&field.name, &index_id.to_string(), &NoOpMetricsCollector) .await?; match generic.index_type() { it if it.is_scalar() => { - let new_store = match it { - IndexType::Scalar | IndexType::BTree => { - LanceIndexStore::from_dataset(dataset, &new_id.to_string()) - .with_legacy_format(true) - } - _ => LanceIndexStore::from_dataset(dataset, &new_id.to_string()), - }; + let new_store = LanceIndexStore::from_dataset(dataset, &new_id.to_string()); let scalar_index = dataset - .open_scalar_index(&field.name, &index_id.to_string()) + .open_scalar_index(&field.name, &index_id.to_string(), &NoOpMetricsCollector) .await?; scalar_index.remap(row_id_map, &new_store).await?; } @@ -230,7 +241,11 @@ impl DatasetIndexExt for Dataset { let index_id = Uuid::new_v4(); let index_details: prost_types::Any = match (index_type, params.index_name()) { ( - IndexType::Bitmap | IndexType::BTree | IndexType::Inverted | IndexType::LabelList, + IndexType::Bitmap + | IndexType::BTree + | IndexType::Inverted + | IndexType::NGram + | IndexType::LabelList, LANCE_SCALAR_INDEX, ) => { let params = ScalarIndexParams::new(index_type.try_into()?); @@ -270,8 +285,15 @@ impl DatasetIndexExt for Dataset { location: location!(), })?; - build_vector_index(self, column, &index_name, &index_id.to_string(), vec_params) - .await?; + // this is a large future so move it to heap + Box::pin(build_vector_index( + self, + column, + &index_name, + &index_id.to_string(), + vec_params, + )) + .await?; vector_index_details() } // Can't use if let Some(...) here because it's not stable yet. @@ -328,19 +350,50 @@ impl DatasetIndexExt for Dataset { None, ); - let (new_manifest, manifest_path) = commit_transaction( - self, - self.object_store(), - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; + + Ok(()) + } + + async fn drop_index(&mut self, name: &str) -> Result<()> { + let indices = self.load_indices_by_name(name).await?; + if indices.is_empty() { + return Err(Error::IndexNotFound { + identity: format!("name={}", name), + location: location!(), + }); + } + + let transaction = Transaction::new( + self.manifest.version, + Operation::CreateIndex { + new_indices: vec![], + removed_indices: indices.clone(), + }, + /*blobs_op= */ None, + None, + ); + + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; + + Ok(()) + } + + async fn prewarm_index(&self, name: &str) -> Result<()> { + let indices = self.load_indices_by_name(name).await?; + if indices.is_empty() { + return Err(Error::IndexNotFound { + identity: format!("name={}", name), + location: location!(), + }); + } - self.manifest = Arc::new(new_manifest); - self.manifest_file = manifest_path; + let index = self + .open_generic_index(name, &indices[0].uuid.to_string(), &NoOpMetricsCollector) + .await?; + index.prewarm().await?; Ok(()) } @@ -404,19 +457,8 @@ impl DatasetIndexExt for Dataset { None, ); - let (new_manifest, new_path) = commit_transaction( - self, - self.object_store(), - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; - - self.manifest = Arc::new(new_manifest); - self.manifest_file = new_path; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; Ok(()) } @@ -452,7 +494,7 @@ impl DatasetIndexExt for Dataset { .filter(|idx| { indices_to_optimize .as_ref() - .map_or(true, |names| names.contains(&idx.name)) + .is_none_or(|names| names.contains(&idx.name)) }) .map(|idx| (idx.name.clone(), idx)) .into_group_map(); @@ -503,19 +545,9 @@ impl DatasetIndexExt for Dataset { None, ); - let (new_manifest, manifest_path) = commit_transaction( - self, - self.object_store(), - self.commit_handler.as_ref(), - &transaction, - &Default::default(), - &Default::default(), - self.manifest_naming_scheme, - ) - .await?; + self.apply_commit(transaction, &Default::default(), &Default::default()) + .await?; - self.manifest = Arc::new(new_manifest); - self.manifest_file = manifest_path; Ok(()) } @@ -539,7 +571,10 @@ impl DatasetIndexExt for Dataset { // Open all delta indices let indices = stream::iter(metadatas.iter()) - .then(|m| async move { self.open_generic_index(column, &m.uuid.to_string()).await }) + .then(|m| async move { + self.open_generic_index(column, &m.uuid.to_string(), &NoOpMetricsCollector) + .await + }) .try_collect::>() .await?; @@ -552,23 +587,71 @@ impl DatasetIndexExt for Dataset { let index_type = indices[0].index_type().to_string(); let indexed_fragments_per_delta = self.indexed_fragments(index_name).await?; - let num_indexed_rows_per_delta = self.indexed_fragments(index_name).await? - .iter() - .map(|frags| { - frags.iter().map(|f| f.num_rows().expect("Fragment should have row counts, please upgrade lance and trigger a single right to fix this")).sum::() - }) - .collect::>(); - let num_indexed_fragments = indexed_fragments_per_delta - .clone() - .into_iter() - .flatten() - .map(|f| f.id) - .collect::>() - .len(); + let res = indexed_fragments_per_delta + .iter() + .map(|frags| { + let mut sum = 0; + for frag in frags.iter() { + sum += frag.num_rows().ok_or_else(|| Error::Internal { + message: "Fragment should have row counts, please upgrade lance and \ + trigger a single write to fix this" + .to_string(), + location: location!(), + })?; + } + Ok(sum) + }) + .collect::>>(); + + async fn migrate_and_recompute(ds: &Dataset, index_name: &str) -> Result { + let mut ds = ds.clone(); + log::warn!( + "Detecting out-dated fragment metadata, migrating dataset. \ + To disable migration, set LANCE_AUTO_MIGRATION=false" + ); + ds.delete("false").await.map_err(|err| { + Error::Execution { + message: format!("Failed to migrate dataset while calculating index statistics. \ + To disable migration, set LANCE_AUTO_MIGRATION=false. Original error: {}", err), + location: location!(), + } + })?; + ds.index_statistics(index_name).await + } + + let num_indexed_rows_per_delta = match res { + Ok(rows) => rows, + Err(Error::Internal { message, .. }) + if auto_migrate_corruption() && message.contains("trigger a single write") => + { + return migrate_and_recompute(self, index_name).await; + } + Err(e) => return Err(e), + }; + + let mut fragment_ids = HashSet::new(); + for frags in indexed_fragments_per_delta.iter() { + for frag in frags.iter() { + if !fragment_ids.insert(frag.id) { + if auto_migrate_corruption() { + return migrate_and_recompute(self, index_name).await; + } else { + return Err(Error::Internal { + message: + "Overlap in indexed fragments. Please upgrade to lance >= 0.23.0 \ + and trigger a single write to fix this" + .to_string(), + location: location!(), + }); + } + } + } + } + let num_indexed_fragments = fragment_ids.len(); let num_unindexed_fragments = self.fragments().len() - num_indexed_fragments; - let num_indexed_rows = num_indexed_rows_per_delta.iter().last().unwrap(); + let num_indexed_rows: usize = num_indexed_rows_per_delta.iter().cloned().sum(); let num_unindexed_rows = self.count_rows(None).await? - num_indexed_rows; let stats = json!({ @@ -588,6 +671,50 @@ impl DatasetIndexExt for Dataset { location: location!(), }) } + + async fn read_index_partition( + &self, + index_name: &str, + partition_id: usize, + with_vector: bool, + ) -> Result { + let indices = self.load_indices_by_name(index_name).await?; + if indices.is_empty() { + return Err(Error::IndexNotFound { + identity: format!("name={}", index_name), + location: location!(), + }); + } + let column = self.schema().field_by_id(indices[0].fields[0]).unwrap(); + + let mut schema: Option> = None; + let mut partition_streams = Vec::with_capacity(indices.len()); + for index in indices { + let index = self + .open_vector_index(&column.name, &index.uuid.to_string(), &NoOpMetricsCollector) + .await?; + + let stream = index + .partition_reader(partition_id, with_vector, &NoOpMetricsCollector) + .await?; + if schema.is_none() { + schema = Some(stream.schema()); + } + partition_streams.push(stream); + } + + match schema { + Some(schema) => { + let merged = stream::select_all(partition_streams); + let stream = RecordBatchStreamAdapter::new(schema, merged); + Ok(Box::pin(stream)) + } + None => Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::new(Schema::empty()), + stream::empty(), + ))), + } + } } /// A trait for internal dataset utilities @@ -596,11 +723,26 @@ impl DatasetIndexExt for Dataset { #[async_trait] pub trait DatasetIndexInternalExt: DatasetIndexExt { /// Opens an index (scalar or vector) as a generic index - async fn open_generic_index(&self, column: &str, uuid: &str) -> Result>; + async fn open_generic_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result>; /// Opens the requested scalar index - async fn open_scalar_index(&self, column: &str, uuid: &str) -> Result>; + async fn open_scalar_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result>; /// Opens the requested vector index - async fn open_vector_index(&self, column: &str, uuid: &str) -> Result>; + async fn open_vector_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result>; /// Loads information about all the available scalar indices on the dataset async fn scalar_index_info(&self) -> Result; @@ -613,7 +755,12 @@ pub trait DatasetIndexInternalExt: DatasetIndexExt { #[async_trait] impl DatasetIndexInternalExt for Dataset { - async fn open_generic_index(&self, column: &str, uuid: &str) -> Result> { + async fn open_generic_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result> { // Checking for cache existence is cheap so we just check both scalar and vector caches if let Some(index) = self.session.index_cache.get_scalar(uuid) { return Ok(index.as_index()); @@ -633,15 +780,20 @@ impl DatasetIndexInternalExt for Dataset { let index_dir = self.indices_dir().child(uuid); let index_file = index_dir.child(INDEX_FILE_NAME); if self.object_store.exists(&index_file).await? { - let index = self.open_vector_index(column, uuid).await?; + let index = self.open_vector_index(column, uuid, metrics).await?; Ok(index.as_index()) } else { - let index = self.open_scalar_index(column, uuid).await?; + let index = self.open_scalar_index(column, uuid, metrics).await?; Ok(index.as_index()) } } - async fn open_scalar_index(&self, column: &str, uuid: &str) -> Result> { + async fn open_scalar_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result> { if let Some(index) = self.session.index_cache.get_scalar(uuid) { return Ok(index); } @@ -652,11 +804,20 @@ impl DatasetIndexInternalExt for Dataset { })?; let index = crate::index::scalar::open_scalar_index(self, column, &index_meta).await?; + + info!(target: TRACE_IO_EVENTS, index_uuid=uuid, type=IO_TYPE_OPEN_SCALAR, index_type=index.index_type().to_string()); + metrics.record_index_load(); + self.session.index_cache.insert_scalar(uuid, index.clone()); Ok(index) } - async fn open_vector_index(&self, column: &str, uuid: &str) -> Result> { + async fn open_vector_index( + &self, + column: &str, + uuid: &str, + metrics: &dyn MetricsCollector, + ) -> Result> { if let Some(index) = self.session.index_cache.get_vector(uuid) { log::debug!("Found vector index in cache uuid: {}", uuid); return Ok(index); @@ -673,18 +834,13 @@ impl DatasetIndexInternalExt for Dataset { // TODO: we need to change the legacy IVF_PQ to be in lance format let index = match (major_version, minor_version) { (0, 1) | (0, 0) => { + info!(target: TRACE_IO_EVENTS, index_uuid=uuid, type=IO_TYPE_OPEN_VECTOR, version="0.1", index_type="IVF_PQ"); let proto = open_index_proto(reader.as_ref()).await?; match &proto.implementation { Some(Implementation::VectorIndex(vector_index)) => { let dataset = Arc::new(self.clone()); - crate::index::vector::open_vector_index( - dataset, - column, - uuid, - vector_index, - reader, - ) - .await + crate::index::vector::open_vector_index(dataset, uuid, vector_index, reader) + .await } None => Err(Error::Internal { message: "Index proto was missing implementation field".into(), @@ -694,6 +850,7 @@ impl DatasetIndexInternalExt for Dataset { } (0, 2) => { + info!(target: TRACE_IO_EVENTS, index_uuid=uuid, type=IO_TYPE_OPEN_VECTOR, version="0.2", index_type="IVF_PQ"); let reader = FileReader::try_new_self_described_from_reader( reader.clone(), Some(&self.session.file_metadata_cache), @@ -737,16 +894,12 @@ impl DatasetIndexInternalExt for Dataset { location: location!(), })?; - let value_type = if let DataType::FixedSizeList(df, _) = field.data_type() { - Result::Ok(df.data_type().to_owned()) - } else { - return Err(Error::Index { - message: format!("Column {} is not a vector column", column), - location: location!(), - }); - }?; + let (_, element_type) = get_vector_type(self.schema(), column)?; + + info!(target: TRACE_IO_EVENTS, index_uuid=uuid, type=IO_TYPE_OPEN_VECTOR, version="0.3", index_type=index_metadata.index_type); + match index_metadata.index_type.as_str() { - "IVF_FLAT" => match value_type { + "IVF_FLAT" => match element_type { DataType::Float16 | DataType::Float32 | DataType::Float64 => { let ivf = IVFIndex::::try_new( self.object_store.clone(), @@ -757,6 +910,16 @@ impl DatasetIndexInternalExt for Dataset { .await?; Ok(Arc::new(ivf) as Arc) } + DataType::UInt8 => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } _ => Err(Error::Index { message: format!( "the field type {} is not supported for FLAT index", @@ -813,6 +976,7 @@ impl DatasetIndexInternalExt for Dataset { }), }; let index = index?; + metrics.record_index_load(); self.session.index_cache.insert_vector(uuid, index.clone()); Ok(index) } @@ -850,7 +1014,15 @@ impl DatasetIndexInternalExt for Dataset { if matches!(index_type, ScalarIndexType::Inverted) { continue; } - Box::::default() as Box + match index_type { + ScalarIndexType::BTree | ScalarIndexType::Bitmap => { + Box::::default() as Box + } + ScalarIndexType::NGram => { + Box::::default() as Box + } + _ => continue, + } } _ => Box::::default() as Box, }; @@ -904,6 +1076,8 @@ impl DatasetIndexInternalExt for Dataset { #[cfg(test)] mod tests { use crate::dataset::builder::DatasetBuilder; + use crate::dataset::optimize::{compact_files, CompactionOptions}; + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use super::*; @@ -979,19 +1153,79 @@ mod tests { .is_err()); } - #[tokio::test] - async fn test_count_index_rows() { - let test_dir = tempdir().unwrap(); + fn sample_vector_field() -> Field { let dimensions = 16; let column_name = "vec"; - let field = Field::new( + Field::new( column_name, DataType::FixedSizeList( Arc::new(Field::new("item", DataType::Float32, true)), dimensions, ), false, - ); + ) + } + + #[tokio::test] + async fn test_drop_index() { + let test_dir = tempdir().unwrap(); + let schema = Schema::new(vec![ + sample_vector_field(), + Field::new("ints", DataType::Int32, false), + ]); + let mut dataset = lance_datagen::rand(&schema) + .into_dataset( + test_dir.path().to_str().unwrap(), + FragmentCount::from(1), + FragmentRowCount::from(256), + ) + .await + .unwrap(); + + let idx_name = "name".to_string(); + dataset + .create_index( + &["vec"], + IndexType::Vector, + Some(idx_name.clone()), + &VectorIndexParams::ivf_pq(2, 8, 2, MetricType::L2, 10), + true, + ) + .await + .unwrap(); + dataset + .create_index( + &["ints"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + assert_eq!(dataset.load_indices().await.unwrap().len(), 2); + + dataset.drop_index(&idx_name).await.unwrap(); + + assert_eq!(dataset.load_indices().await.unwrap().len(), 1); + + // Even though we didn't give the scalar index a name it still has an auto-generated one we can use + let scalar_idx_name = &dataset.load_indices().await.unwrap()[0].name; + dataset.drop_index(scalar_idx_name).await.unwrap(); + + assert_eq!(dataset.load_indices().await.unwrap().len(), 0); + + // Make sure it returns an error if the index doesn't exist + assert!(dataset.drop_index(scalar_idx_name).await.is_err()); + } + + #[tokio::test] + async fn test_count_index_rows() { + let test_dir = tempdir().unwrap(); + let dimensions = 16; + let column_name = "vec"; + let field = sample_vector_field(); let schema = Arc::new(Schema::new(vec![field])); let float_arr = generate_random_array(512 * dimensions as usize); @@ -1047,7 +1281,6 @@ mod tests { #[tokio::test] async fn test_optimize_delta_indices() { - let test_dir = tempdir().unwrap(); let dimensions = 16; let column_name = "vec"; let vec_field = Field::new( @@ -1083,8 +1316,7 @@ mod tests { schema.clone(), ); - let test_uri = test_dir.path().to_str().unwrap(); - let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap(); + let mut dataset = Dataset::write(reader, "memory://", None).await.unwrap(); let params = VectorIndexParams::ivf_pq(10, 8, 2, MetricType::L2, 10); dataset .create_index( @@ -1107,81 +1339,104 @@ mod tests { .await .unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + async fn get_stats(dataset: &Dataset, name: &str) -> serde_json::Value { + serde_json::from_str(&dataset.index_statistics(name).await.unwrap()).unwrap() + } + async fn get_meta(dataset: &Dataset, name: &str) -> Vec { + dataset + .load_indices() + .await + .unwrap() + .iter() + .filter(|m| m.name == name) + .cloned() + .collect() + } + fn get_bitmap(meta: &IndexMetadata) -> Vec { + meta.fragment_bitmap.as_ref().unwrap().iter().collect() + } + + let stats = get_stats(&dataset, "vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 0); assert_eq!(stats["num_indexed_rows"], 512); assert_eq!(stats["num_indexed_fragments"], 1); assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0]); let reader = RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone()); dataset.append(reader, None).await.unwrap(); - let mut dataset = DatasetBuilder::from_uri(test_uri).load().await.unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + let stats = get_stats(&dataset, "vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 512); assert_eq!(stats["num_indexed_rows"], 512); assert_eq!(stats["num_indexed_fragments"], 1); assert_eq!(stats["num_unindexed_fragments"], 1); assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0]); dataset - .optimize_indices(&OptimizeOptions { - num_indices_to_merge: 0, // Just create index for delta - index_names: Some(vec![]), // Optimize nothing - }) + .optimize_indices(&OptimizeOptions::append().index_names(vec![])) // Does nothing because no index name is passed .await .unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + let stats = get_stats(&dataset, "vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 512); assert_eq!(stats["num_indexed_rows"], 512); assert_eq!(stats["num_indexed_fragments"], 1); assert_eq!(stats["num_unindexed_fragments"], 1); assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0]); // optimize the other index dataset - .optimize_indices(&OptimizeOptions { - num_indices_to_merge: 0, // Just create index for delta - index_names: Some(vec!["other_vec_idx".to_string()]), - }) + .optimize_indices( + &OptimizeOptions::append().index_names(vec!["other_vec_idx".to_owned()]), + ) .await .unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + let stats = get_stats(&dataset, "vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 512); assert_eq!(stats["num_indexed_rows"], 512); assert_eq!(stats["num_indexed_fragments"], 1); assert_eq!(stats["num_unindexed_fragments"], 1); assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0]); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("other_vec_idx").await.unwrap()) - .unwrap(); + let stats = get_stats(&dataset, "other_vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 0); assert_eq!(stats["num_indexed_rows"], 1024); assert_eq!(stats["num_indexed_fragments"], 2); assert_eq!(stats["num_unindexed_fragments"], 0); assert_eq!(stats["num_indices"], 2); + let meta = get_meta(&dataset, "other_vec_idx").await; + assert_eq!(meta.len(), 2); + assert_eq!(get_bitmap(&meta[0]), vec![0]); + assert_eq!(get_bitmap(&meta[1]), vec![1]); dataset .optimize_indices(&OptimizeOptions { - num_indices_to_merge: 0, // Just create index for delta + num_indices_to_merge: 1, // merge the index with new data ..Default::default() }) .await .unwrap(); - let mut dataset = DatasetBuilder::from_uri(test_uri).load().await.unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + let stats = get_stats(&dataset, "vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 0); assert_eq!(stats["num_indexed_rows"], 1024); assert_eq!(stats["num_indexed_fragments"], 2); assert_eq!(stats["num_unindexed_fragments"], 0); - assert_eq!(stats["num_indices"], 2); + assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0, 1]); dataset .optimize_indices(&OptimizeOptions { @@ -1190,13 +1445,15 @@ mod tests { }) .await .unwrap(); - let stats: serde_json::Value = - serde_json::from_str(&dataset.index_statistics("vec_idx").await.unwrap()).unwrap(); + let stats = get_stats(&dataset, "other_vec_idx").await; assert_eq!(stats["num_unindexed_rows"], 0); assert_eq!(stats["num_indexed_rows"], 1024); assert_eq!(stats["num_indexed_fragments"], 2); assert_eq!(stats["num_unindexed_fragments"], 0); assert_eq!(stats["num_indices"], 1); + let meta = get_meta(&dataset, "other_vec_idx").await; + assert_eq!(meta.len(), 1); + assert_eq!(get_bitmap(&meta[0]), vec![0, 1]); } #[tokio::test] @@ -1314,7 +1571,7 @@ mod tests { .await .unwrap(); - let tokenizer_config = TokenizerConfig::default(); + let tokenizer_config = TokenizerConfig::default().lower_case(false); let params = InvertedIndexParams { with_position: true, tokenizer_config, @@ -1324,19 +1581,32 @@ mod tests { .await .unwrap(); + async fn assert_indexed_rows(dataset: &Dataset, expected_indexed_rows: usize) { + let stats = dataset.index_statistics("text_idx").await.unwrap(); + let stats: serde_json::Value = serde_json::from_str(&stats).unwrap(); + let indexed_rows = stats["num_indexed_rows"].as_u64().unwrap() as usize; + let unindexed_rows = stats["num_unindexed_rows"].as_u64().unwrap() as usize; + let num_rows = dataset.count_all_rows().await.unwrap(); + assert_eq!(indexed_rows, expected_indexed_rows); + assert_eq!(unindexed_rows, num_rows - expected_indexed_rows); + } + + let num_rows = dataset.count_all_rows().await.unwrap(); + assert_indexed_rows(&dataset, num_rows).await; + let new_words = ["elephant", "fig", "grape", "honeydew"]; let new_data = StringArray::from_iter_values(new_words.iter().map(|s| s.to_string())); let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(new_data)]).unwrap(); let batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); dataset.append(batch_iter, None).await.unwrap(); + assert_indexed_rows(&dataset, num_rows).await; dataset - .optimize_indices(&OptimizeOptions { - num_indices_to_merge: 0, - index_names: None, - }) + .optimize_indices(&OptimizeOptions::append()) .await .unwrap(); + let num_rows = dataset.count_all_rows().await.unwrap(); + assert_indexed_rows(&dataset, num_rows).await; for &word in words.iter().chain(new_words.iter()) { let query_result = dataset @@ -1363,6 +1633,125 @@ mod tests { assert_eq!(texts.len(), 1); assert_eq!(texts[0], word); } + + let uppercase_words = ["Apple", "Banana", "Cherry", "Date"]; + for &word in uppercase_words.iter() { + let query_result = dataset + .scan() + .project(&["text"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new(word.to_string())) + .unwrap() + .limit(Some(10), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let texts = query_result["text"] + .as_string::() + .iter() + .map(|v| match v { + None => "".to_string(), + Some(v) => v.to_string(), + }) + .collect::>(); + + assert_eq!(texts.len(), 0); + } + let new_data = StringArray::from_iter_values(uppercase_words.iter().map(|s| s.to_string())); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(new_data)]).unwrap(); + let batch_iter = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + dataset.append(batch_iter, None).await.unwrap(); + assert_indexed_rows(&dataset, num_rows).await; + + // we should be able to query the new words + for &word in uppercase_words.iter() { + let query_result = dataset + .scan() + .project(&["text"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new(word.to_string())) + .unwrap() + .limit(Some(10), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let texts = query_result["text"] + .as_string::() + .iter() + .map(|v| match v { + None => "".to_string(), + Some(v) => v.to_string(), + }) + .collect::>(); + + assert_eq!(texts.len(), 1, "query: {}, texts: {:?}", word, texts); + assert_eq!(texts[0], word, "query: {}, texts: {:?}", word, texts); + } + + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .unwrap(); + let num_rows = dataset.count_all_rows().await.unwrap(); + assert_indexed_rows(&dataset, num_rows).await; + + // we should be able to query the new words after optimization + for &word in uppercase_words.iter() { + let query_result = dataset + .scan() + .project(&["text"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new(word.to_string())) + .unwrap() + .limit(Some(10), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let texts = query_result["text"] + .as_string::() + .iter() + .map(|v| match v { + None => "".to_string(), + Some(v) => v.to_string(), + }) + .collect::>(); + + assert_eq!(texts.len(), 1, "query: {}, texts: {:?}", word, texts); + assert_eq!(texts[0], word, "query: {}, texts: {:?}", word, texts); + + // we should be able to query the new words after compaction + compact_files(&mut dataset, CompactionOptions::default(), None) + .await + .unwrap(); + for &word in uppercase_words.iter() { + let query_result = dataset + .scan() + .project(&["text"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new(word.to_string())) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let texts = query_result["text"] + .as_string::() + .iter() + .map(|v| match v { + None => "".to_string(), + Some(v) => v.to_string(), + }) + .collect::>(); + assert_eq!(texts.len(), 1, "query: {}, texts: {:?}", word, texts); + assert_eq!(texts[0], word, "query: {}, texts: {:?}", word, texts); + } + assert_indexed_rows(&dataset, num_rows).await; + } } #[tokio::test] @@ -1433,7 +1822,7 @@ mod tests { .unwrap(); let indices = dataset.load_indices().await.unwrap(); let index = dataset - .open_generic_index("tag", &indices[0].uuid.to_string()) + .open_generic_index("tag", &indices[0].uuid.to_string(), &NoOpMetricsCollector) .await .unwrap(); assert_eq!(index.index_type(), IndexType::Bitmap); diff --git a/rust/lance/src/index/append.rs b/rust/lance/src/index/append.rs index 3c6a377dd5a..dfcf3375532 100644 --- a/rust/lance/src/index/append.rs +++ b/rust/lance/src/index/append.rs @@ -4,12 +4,13 @@ use std::sync::Arc; use lance_core::{Error, Result}; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::IndexType; use lance_table::format::Index as IndexMetadata; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use uuid::Uuid; use super::vector::ivf::optimize_vector_indices; @@ -54,7 +55,7 @@ pub async fn merge_indices<'a>( let mut indices = Vec::with_capacity(old_indices.len()); for idx in old_indices { let index = dataset - .open_generic_index(&column.name, &idx.uuid.to_string()) + .open_generic_index(&column.name, &idx.uuid.to_string(), &NoOpMetricsCollector) .await?; indices.push(index); } @@ -71,17 +72,24 @@ pub async fn merge_indices<'a>( let unindexed = dataset.unindexed_fragments(&old_indices[0].name).await?; let mut frag_bitmap = RoaringBitmap::new(); - old_indices.iter().for_each(|idx| { - frag_bitmap.extend(idx.fragment_bitmap.as_ref().unwrap().iter()); - }); unindexed.iter().for_each(|frag| { frag_bitmap.insert(frag.id as u32); }); let (new_uuid, indices_merged) = match indices[0].index_type() { it if it.is_scalar() => { + // There are no delta indices for scalar, so adding all indexed + // fragments to the new index. + old_indices.iter().for_each(|idx| { + frag_bitmap.extend(idx.fragment_bitmap.as_ref().unwrap().iter()); + }); + let index = dataset - .open_scalar_index(&column.name, &old_indices[0].uuid.to_string()) + .open_scalar_index( + &column.name, + &old_indices[0].uuid.to_string(), + &NoOpMetricsCollector, + ) .await?; let mut scanner = dataset.scan(); @@ -98,20 +106,20 @@ pub async fn merge_indices<'a>( let new_uuid = Uuid::new_v4(); - // The BTree index implementation leverages the legacy format's batch offset, - // which has been removed from new format, so keep using the legacy format for now. - let new_store = match index.index_type() { - IndexType::Scalar | IndexType::BTree => { - LanceIndexStore::from_dataset(&dataset, &new_uuid.to_string()) - .with_legacy_format(true) - } - _ => LanceIndexStore::from_dataset(&dataset, &new_uuid.to_string()), - }; + let new_store = LanceIndexStore::from_dataset(&dataset, &new_uuid.to_string()); index.update(new_data_stream.into(), &new_store).await?; Ok((new_uuid, 1)) } it if it.is_vector() => { + let start_pos = old_indices + .len() + .saturating_sub(options.num_indices_to_merge); + let indices_to_merge = &old_indices[start_pos..]; + indices_to_merge.iter().for_each(|idx| { + frag_bitmap.extend(idx.fragment_bitmap.as_ref().unwrap().iter()); + }); + let new_data_stream = if unindexed.is_empty() { None } else { @@ -120,6 +128,9 @@ pub async fn merge_indices<'a>( .with_fragments(unindexed) .with_row_id() .project(&[&column.name])?; + if column.nullable { + scanner.filter_expr(datafusion_expr::col(&column.name).is_not_null()); + } Some(scanner.try_into_stream().await?) }; @@ -152,6 +163,7 @@ pub async fn merge_indices<'a>( mod tests { use super::*; + use arrow::datatypes::Float32Type; use arrow_array::cast::AsArray; use arrow_array::types::UInt32Type; use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator, UInt32Array}; @@ -160,6 +172,7 @@ mod tests { use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::sq::builder::SQBuildParams; + use lance_index::vector::storage::VectorStore; use lance_index::{ vector::{ivf::IvfBuildParams, pq::PQBuildParams}, DatasetIndexExt, IndexType, @@ -170,8 +183,8 @@ mod tests { use tempfile::tempdir; use crate::dataset::builder::DatasetBuilder; - use crate::index::vector::ivf::IVFIndex; - use crate::index::vector::{pq::PQIndex, VectorIndexParams}; + use crate::index::vector::ivf::v2; + use crate::index::vector::VectorIndexParams; #[tokio::test] async fn test_append_index() { @@ -225,7 +238,9 @@ mod tests { let q = array.value(5); let mut scanner = dataset.scan(); - scanner.nearest("vector", q.as_primitive(), 10).unwrap(); + scanner + .nearest("vector", q.as_primitive::(), 10) + .unwrap(); let results = scanner .try_into_stream() .await @@ -247,17 +262,13 @@ mod tests { // There should be two indices directories existed. let object_store = dataset.object_store(); - let index_dirs = object_store - .read_dir_all(&dataset.indices_dir(), None) - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); + let index_dirs = object_store.read_dir(dataset.indices_dir()).await.unwrap(); assert_eq!(index_dirs.len(), 2); let mut scanner = dataset.scan(); - scanner.nearest("vector", q.as_primitive(), 10).unwrap(); + scanner + .nearest("vector", q.as_primitive::(), 10) + .unwrap(); let results = scanner .try_into_stream() .await @@ -275,15 +286,18 @@ mod tests { // Check that the index has all 2000 rows. let binding = dataset - .open_vector_index("vector", index.uuid.to_string().as_str()) + .open_vector_index( + "vector", + index.uuid.to_string().as_str(), + &NoOpMetricsCollector, + ) .await .unwrap(); - let ivf_index = binding.as_any().downcast_ref::().unwrap(); + let ivf_index = binding.as_any().downcast_ref::().unwrap(); let row_in_index = stream::iter(0..IVF_PARTITIONS) .map(|part_id| async move { - let part = ivf_index.load_partition(part_id, true).await.unwrap(); - let pq_idx = part.as_any().downcast_ref::().unwrap(); - pq_idx.row_ids.as_ref().unwrap().len() + let part = ivf_index.load_partition_storage(part_id).await.unwrap(); + part.len() }) .buffered(2) .collect::>() @@ -385,7 +399,7 @@ mod tests { .scan() .project(&["id"]) .unwrap() - .nearest("vector", array.value(0).as_primitive(), 2) + .nearest("vector", array.value(0).as_primitive::(), 2) .unwrap() .refine(1) .try_into_batch() diff --git a/rust/lance/src/index/cache.rs b/rust/lance/src/index/cache.rs index bb28fa10930..f83faf8d5ba 100644 --- a/rust/lance/src/index/cache.rs +++ b/rust/lance/src/index/cache.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use deepsize::DeepSizeOf; +use lance_index::vector::VectorIndexCacheEntry; use lance_index::{ scalar::{ScalarIndex, ScalarIndexType}, vector::VectorIndex, @@ -13,8 +14,6 @@ use moka::sync::Cache; use std::sync::atomic::{AtomicU64, Ordering}; -use crate::dataset::DEFAULT_INDEX_CACHE_SIZE; - #[derive(Debug, Default, DeepSizeOf)] struct CacheStats { hits: AtomicU64, @@ -36,6 +35,8 @@ pub struct IndexCache { // TODO: Can we merge these two caches into one for uniform memory management? scalar_cache: Arc>>, vector_cache: Arc>>, + // this is for v3 index, sadly we can't use the same cache as the vector index for now + vector_partition_cache: Arc>>, /// Index metadata cache. /// @@ -61,6 +62,11 @@ impl DeepSizeOf for IndexCache { .iter() .map(|(_, v)| v.deep_size_of_children(context)) .sum::() + + self + .vector_partition_cache + .iter() + .map(|(_, v)| v.deep_size_of_children(context)) + .sum::() + self .metadata_cache .iter() @@ -75,19 +81,13 @@ impl IndexCache { Self { scalar_cache: Arc::new(Cache::new(capacity as u64)), vector_cache: Arc::new(Cache::new(capacity as u64)), + vector_partition_cache: Arc::new(Cache::new(capacity as u64)), metadata_cache: Arc::new(Cache::new(capacity as u64)), type_cache: Arc::new(Cache::new(capacity as u64)), cache_stats: Arc::new(CacheStats::default()), } } - pub(crate) fn capacity(&self) -> u64 { - self.vector_cache - .policy() - .max_capacity() - .unwrap_or(DEFAULT_INDEX_CACHE_SIZE as u64) - } - #[allow(dead_code)] pub(crate) fn len_vector(&self) -> usize { self.vector_cache.run_pending_tasks(); @@ -97,9 +97,18 @@ impl IndexCache { pub(crate) fn get_size(&self) -> usize { self.scalar_cache.run_pending_tasks(); self.vector_cache.run_pending_tasks(); + self.vector_partition_cache.run_pending_tasks(); self.metadata_cache.run_pending_tasks(); (self.scalar_cache.entry_count() + self.vector_cache.entry_count() + + self.vector_partition_cache.entry_count() + + self.metadata_cache.entry_count()) as usize + } + + pub(crate) fn approx_size(&self) -> usize { + (self.scalar_cache.entry_count() + + self.vector_cache.entry_count() + + self.vector_partition_cache.entry_count() + self.metadata_cache.entry_count()) as usize } @@ -134,6 +143,16 @@ impl IndexCache { } } + pub(crate) fn get_vector_partition(&self, key: &str) -> Option> { + if let Some(index) = self.vector_partition_cache.get(key) { + self.cache_stats.record_hit(); + Some(index) + } else { + self.cache_stats.record_miss(); + None + } + } + /// Insert a new entry into the cache. pub(crate) fn insert_scalar(&self, key: &str, index: Arc) { self.scalar_cache.insert(key.to_string(), index); @@ -143,6 +162,10 @@ impl IndexCache { self.vector_cache.insert(key.to_string(), index); } + pub(crate) fn insert_vector_partition(&self, key: &str, index: Arc) { + self.vector_partition_cache.insert(key.to_string(), index); + } + /// Construct a key for index metadata arrays. fn metadata_key(dataset_uuid: &str, version: u64) -> String { format!("{}:{}", dataset_uuid, version) diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 8efa3ec8297..6b03ae486c1 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -8,9 +8,13 @@ use std::sync::Arc; use arrow_schema::DataType; use async_trait::async_trait; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::SendableRecordBatchStream; +use futures::TryStreamExt; use lance_core::{Error, Result}; use lance_datafusion::{chunker::chunk_concat_stream, exec::LanceExecutionOptions}; +use lance_index::scalar::btree::DEFAULT_BTREE_BATCH_SIZE; +use lance_index::scalar::ngram::{train_ngram_index, NGramIndex}; use lance_index::scalar::InvertedIndexParams; use lance_index::scalar::{ bitmap::{train_bitmap_index, BitmapIndex, BITMAP_LOOKUP_NAME}, @@ -22,7 +26,8 @@ use lance_index::scalar::{ ScalarIndex, ScalarIndexParams, ScalarIndexType, }; use lance_table::format::Index; -use snafu::{location, Location}; +use log::info; +use snafu::location; use tracing::instrument; use crate::session::Session; @@ -31,6 +36,9 @@ use crate::{ Dataset, }; +// Log an update every TRAINING_UPDATE_FREQ million rows processed +const TRAINING_UPDATE_FREQ: usize = 1000000; + struct TrainingRequest { dataset: Arc, column: String, @@ -59,8 +67,28 @@ impl TrainingRequest { chunk_size: u32, sort: bool, ) -> Result { + let num_rows = self.dataset.count_all_rows().await?; + let mut scan = self.dataset.scan(); + let column_field = + self.dataset + .schema() + .field(&self.column) + .ok_or(Error::InvalidInput { + source: format!("No column with name {}", self.column).into(), + location: location!(), + })?; + + // Datafusion currently has bugs with spilling on string columns + // See https://github.com/apache/datafusion/issues/10073 + // + // One we upgrade we can remove this + let use_spilling = !matches!( + column_field.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + ); + let ordering = match sort { true => Some(vec![ColumnOrdering::asc_nulls_first(self.column.clone())]), false => None, @@ -73,11 +101,34 @@ impl TrainingRequest { let batches = scan .try_into_dfstream(LanceExecutionOptions { - use_spilling: true, + use_spilling, ..Default::default() }) .await?; - Ok(chunk_concat_stream(batches, chunk_size as usize)) + let batches = chunk_concat_stream(batches, chunk_size as usize); + + let schema = batches.schema(); + let mut rows_processed = 0; + let mut next_update = TRAINING_UPDATE_FREQ; + let training_uuid = uuid::Uuid::new_v4().to_string(); + info!( + "Starting index training job with id {} on column {}", + training_uuid, self.column + ); + info!("Training index (job_id={}): 0/{}", training_uuid, num_rows); + let batches = batches.map_ok(move |batch| { + rows_processed += batch.num_rows(); + if rows_processed >= next_update { + next_update += TRAINING_UPDATE_FREQ; + info!( + "Training index (job_id={}): {}/{}", + training_uuid, rows_processed, num_rows + ); + } + batch + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, batches))) } } @@ -87,7 +138,6 @@ impl TrainingRequest { // to make index types "generic" and "pluggable". We will need to create some // kind of core proto for scalar indices that the scanner can read to determine // how and when to use a scalar index. - pub trait ScalarIndexDetails { fn get_type(&self) -> ScalarIndexType; } @@ -107,6 +157,11 @@ fn label_list_index_details() -> prost_types::Any { prost_types::Any::from_msg(&details).unwrap() } +fn ngram_index_details() -> prost_types::Any { + let details = lance_table::format::pb::NGramIndexDetails {}; + prost_types::Any::from_msg(&details).unwrap() +} + pub(super) fn inverted_index_details() -> prost_types::Any { let details = lance_table::format::pb::InvertedIndexDetails::default(); prost_types::Any::from_msg(&details).unwrap() @@ -136,6 +191,12 @@ impl ScalarIndexDetails for lance_table::format::pb::InvertedIndexDetails { } } +impl ScalarIndexDetails for lance_table::format::pb::NGramIndexDetails { + fn get_type(&self) -> ScalarIndexType { + ScalarIndexType::NGram + } +} + fn get_scalar_index_details( details: &prost_types::Any, ) -> Result>> { @@ -155,6 +216,10 @@ fn get_scalar_index_details( Ok(Some(Box::new( details.to_msg::()?, ))) + } else if details.type_url.ends_with("NGramIndexDetails") { + Ok(Some(Box::new( + details.to_msg::()?, + ))) } else { Ok(None) } @@ -224,12 +289,25 @@ pub(super) async fn build_scalar_index( .await?; Ok(inverted_index_details()) } + Some(ScalarIndexType::NGram) => { + if field.data_type() != DataType::Utf8 { + return Err(Error::InvalidInput { + source: "NGram index can only be created on Utf8 type columns".into(), + location: location!(), + }); + } + train_ngram_index(training_request, &index_store).await?; + Ok(ngram_index_details()) + } _ => { - // The BTree index implementation leverages the legacy format's batch offset, - // which has been removed from new format, so keep using the legacy format for now. - let index_store = index_store.with_legacy_format(true); let flat_index_trainer = FlatIndexMetadata::new(field.data_type()); - train_btree_index(training_request, &flat_index_trainer, &index_store).await?; + train_btree_index( + training_request, + &flat_index_trainer, + &index_store, + DEFAULT_BTREE_BATCH_SIZE as u32, + ) + .await?; Ok(btree_index_details()) } } @@ -272,6 +350,10 @@ pub async fn open_scalar_index( let inverted_index = InvertedIndex::load(index_store).await?; Ok(inverted_index as Arc) } + ScalarIndexType::NGram => { + let ngram_index = NGramIndex::load(index_store).await?; + Ok(ngram_index as Arc) + } ScalarIndexType::BTree => { let btree_index = BTreeIndex::load(index_store).await?; Ok(btree_index as Arc) diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index bd05fcc6436..3ba5b52e4cf 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -10,14 +10,16 @@ use std::{any::Any, collections::HashMap}; pub mod builder; pub mod ivf; pub mod pq; -mod utils; +pub mod utils; #[cfg(test)] mod fixture_test; +use arrow_schema::DataType; use builder::IvfIndexBuilder; use lance_file::reader::FileReader; -use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::metrics::NoOpMetricsCollector; +use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizer; @@ -37,9 +39,10 @@ use lance_io::traits::Reader; use lance_linalg::distance::*; use lance_table::format::Index as IndexMetadata; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tempfile::tempdir; use tracing::instrument; +use utils::get_vector_type; use uuid::Uuid; use self::{ivf::*, pq::PQIndex}; @@ -124,7 +127,7 @@ impl VectorIndexParams { Self { stages, metric_type, - version: IndexFileVersion::Legacy, + version: IndexFileVersion::V3, } } @@ -138,7 +141,7 @@ impl VectorIndexParams { Self { stages, metric_type, - version: IndexFileVersion::Legacy, + version: IndexFileVersion::V3, } } @@ -248,22 +251,57 @@ pub(crate) async fn build_vector_index( }); }; + let (vector_type, element_type) = get_vector_type(dataset.schema(), column)?; + if let DataType::List(_) = vector_type { + if params.metric_type != DistanceType::Cosine { + return Err(Error::Index { + message: "Build Vector Index: multivector type supports only cosine distance" + .to_string(), + location: location!(), + }); + } + } + let temp_dir = tempdir()?; let temp_dir_path = Path::from_filesystem_path(temp_dir.path())?; let shuffler = IvfShuffler::new(temp_dir_path, ivf_params.num_partitions); if is_ivf_flat(stages) { - IvfIndexBuilder::::new( - dataset.clone(), - column.to_owned(), - dataset.indices_dir().child(uuid), - params.metric_type, - Box::new(shuffler), - Some(ivf_params.clone()), - Some(()), - (), - )? - .build() - .await?; + match element_type { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), + params.metric_type, + Box::new(shuffler), + Some(ivf_params.clone()), + Some(()), + (), + )? + .build() + .await?; + } + DataType::UInt8 => { + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), + params.metric_type, + Box::new(shuffler), + Some(ivf_params.clone()), + Some(()), + (), + )? + .build() + .await?; + } + _ => { + return Err(Error::Index { + message: format!("Build Vector Index: invalid data type: {:?}", element_type), + location: location!(), + }); + } + } } else if is_ivf_pq(stages) { let len = stages.len(); let StageParams::PQ(pq_params) = &stages[len - 1] else { @@ -369,33 +407,38 @@ pub(crate) async fn remap_vector_index( mapping: &HashMap>, ) -> Result<()> { let old_index = dataset - .open_vector_index(column, &old_uuid.to_string()) + .open_vector_index(column, &old_uuid.to_string(), &NoOpMetricsCollector) .await?; old_index.check_can_remap()?; - let ivf_index: &IVFIndex = - old_index - .as_any() - .downcast_ref() - .ok_or_else(|| Error::NotSupported { - source: "Only IVF indexes can be remapped currently".into(), - location: location!(), - })?; - - remap_index_file( - dataset.as_ref(), - &old_uuid.to_string(), - &new_uuid.to_string(), - old_metadata.dataset_version, - ivf_index, - mapping, - old_metadata.name.clone(), - column.to_string(), - // We can safely assume there are no transforms today. We assert above that the - // top stage is IVF and IVF does not support transforms between IVF and PQ. This - // will be fixed in the future. - vec![], - ) - .await?; + + if let Some(ivf_index) = old_index.as_any().downcast_ref::() { + remap_index_file( + dataset.as_ref(), + &old_uuid.to_string(), + &new_uuid.to_string(), + old_metadata.dataset_version, + ivf_index, + mapping, + old_metadata.name.clone(), + column.to_string(), + // We can safely assume there are no transforms today. We assert above that the + // top stage is IVF and IVF does not support transforms between IVF and PQ. This + // will be fixed in the future. + vec![], + ) + .await?; + } else { + // it's v3 index + remap_index_file_v3( + dataset.as_ref(), + &new_uuid.to_string(), + old_index, + mapping, + column.to_string(), + ) + .await?; + } + Ok(()) } @@ -403,7 +446,6 @@ pub(crate) async fn remap_vector_index( #[instrument(level = "debug", skip(dataset, vec_idx, reader))] pub(crate) async fn open_vector_index( dataset: Arc, - column: &str, uuid: &str, vec_idx: &lance_index::pb::VectorIndex, reader: Arc, diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index c4c22265c4a..2f62370f381 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -1,28 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow::array::AsArray; -use arrow_array::{RecordBatch, UInt64Array}; +use arrow::datatypes; +use arrow_array::{FixedSizeListArray, RecordBatch, UInt64Array}; use futures::prelude::stream::{StreamExt, TryStreamExt}; +use futures::{stream, FutureExt}; use itertools::Itertools; -use lance_arrow::RecordBatchExt; +use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::cache::FileMetadataCache; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::v2::reader::FileReaderOptions; use lance_file::v2::{reader::FileReader, writer::FileWriter}; -use lance_index::vector::flat::storage::FlatStorage; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::pq::storage::transpose; use lance_index::vector::quantizer::{ QuantizationMetadata, QuantizationType, QuantizerBuildParams, }; use lance_index::vector::storage::STORAGE_METADATA_KEY; +use lance_index::vector::utils::is_finite; use lance_index::vector::v3::shuffler::IvfShufflerReader; use lance_index::vector::v3::subindex::SubIndexType; -use lance_index::vector::VectorIndex; +use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN}; use lance_index::{ pb, vector::{ @@ -49,11 +54,12 @@ use lance_linalg::distance::DistanceType; use log::info; use object_store::path::Path; use prost::Message; -use snafu::{location, Location}; +use snafu::location; use tempfile::{tempdir, TempDir}; -use tracing::{span, Level}; +use tracing::{instrument, span, Level}; use crate::dataset::ProjectionRequest; +use crate::index::vector::ivf::v2::PartitionEntry; use crate::Dataset; use super::utils; @@ -65,30 +71,32 @@ use super::v2::IVFIndex; // To build the index for the whole dataset, call `build` method. // To build the index for given IVF, quantizer, data stream, // call `with_ivf`, `with_quantizer`, `shuffle_data`, and `build` in order. -pub struct IvfIndexBuilder { - dataset: Dataset, +pub struct IvfIndexBuilder { + store: ObjectStore, column: String, index_dir: Path, distance_type: DistanceType, - shuffler: Arc, + retrain: bool, // build params, only needed for building new IVF, quantizer + dataset: Option, + shuffler: Option>, ivf_params: Option, quantizer_params: Option, - sub_index_params: S::BuildParams, + sub_index_params: Option, _temp_dir: TempDir, // store this for keeping the temp dir alive and clean up after build temp_dir: Path, // fields will be set during build ivf: Option, quantizer: Option, - shuffle_reader: Option>, + shuffle_reader: Option>, partition_sizes: Vec<(usize, usize)>, - // fields for merging indices + // fields for merging indices / remapping existing_indices: Vec>, } -impl IvfIndexBuilder { +impl IvfIndexBuilder { #[allow(clippy::too_many_arguments)] pub fn new( dataset: Dataset, @@ -103,14 +111,16 @@ impl IvfIndexBuilde let temp_dir = tempdir()?; let temp_dir_path = Path::from_filesystem_path(temp_dir.path())?; Ok(Self { - dataset, + store: dataset.object_store().clone(), column, index_dir, distance_type, - shuffler: shuffler.into(), + retrain: false, + dataset: Some(dataset), + shuffler: Some(shuffler.into()), ivf_params, quantizer_params, - sub_index_params, + sub_index_params: Some(sub_index_params), _temp_dir: temp_dir, temp_dir: temp_dir_path, // fields will be set during build @@ -142,16 +152,56 @@ impl IvfIndexBuilde ) } + pub fn new_remapper( + store: ObjectStore, + column: String, + index_dir: Path, + index: Arc, + ) -> Result { + let ivf_index = + index + .as_any() + .downcast_ref::>() + .ok_or(Error::invalid_input( + "existing index is not IVF index", + location!(), + ))?; + + let temp_dir = tempdir()?; + let temp_dir_path = Path::from_filesystem_path(temp_dir.path())?; + Ok(Self { + store, + column, + index_dir, + distance_type: ivf_index.metric_type(), + retrain: false, + dataset: None, + shuffler: None, + ivf_params: None, + quantizer_params: None, + sub_index_params: None, + _temp_dir: temp_dir, + temp_dir: temp_dir_path, + ivf: Some(ivf_index.ivf_model().clone()), + quantizer: Some(ivf_index.quantizer().try_into()?), + shuffle_reader: None, + partition_sizes: Vec::new(), + existing_indices: vec![index], + }) + } + // build the index with the all data in the dataset, pub async fn build(&mut self) -> Result<()> { - // step 1. train IVF & quantizer - if self.ivf.is_none() { - self.with_ivf(self.load_or_build_ivf().await?); - } - if self.quantizer.is_none() { - self.with_quantizer(self.load_or_build_quantizer().await?); + if self.retrain { + self.shuffle_reader = None; + self.existing_indices = Vec::new(); } + // step 1. train IVF & quantizer + self.with_ivf(self.load_or_build_ivf().await?); + + self.with_quantizer(self.load_or_build_quantizer().await?); + // step 2. shuffle the dataset if self.shuffle_reader.is_none() { self.shuffle_dataset().await?; @@ -166,6 +216,68 @@ impl IvfIndexBuilde Ok(()) } + pub async fn remap(&mut self, mapping: &HashMap>) -> Result<()> { + debug_assert_eq!(self.existing_indices.len(), 1); + let ivf_index = self.existing_indices[0] + .as_any() + .downcast_ref::>() + .ok_or(Error::invalid_input( + "existing index is not IVF index", + location!(), + ))?; + + let model = ivf_index.ivf_model(); + let mapped = stream::iter(0..model.num_partitions()) + .map(|part_id| async move { + let part = ivf_index + .load_partition(part_id, false, &NoOpMetricsCollector) + .await?; + let part = part.as_any().downcast_ref::>().ok_or( + Error::Internal { + message: "failed to downcast partition entry".to_string(), + location: location!(), + }, + )?; + Result::Ok((part.storage.remap(mapping)?, part.index.remap(mapping)?)) + }) + .buffered(get_num_compute_intensive_cpus()) + .try_collect::>() + .await?; + + self.partition_sizes = vec![(0, 0); model.num_partitions()]; + let local_store = ObjectStore::local(); + for (part_id, (store, index)) in mapped.into_iter().enumerate() { + let path = self.temp_dir.child(format!("storage_part{}", part_id)); + let batches = store.to_batches()?; + let schema = store.schema().as_ref().try_into()?; + let store_len = FileWriter::create_file_with_batches( + &local_store, + &path, + schema, + batches, + Default::default(), + ) + .await?; + + let path = self.temp_dir.child(format!("index_part{}", part_id)); + let batch = index.to_batch()?; + let schema = batch.schema().as_ref().try_into()?; + let index_len = FileWriter::create_file_with_batches( + &local_store, + &path, + schema, + std::iter::once(batch), + Default::default(), + ) + .await?; + + self.partition_sizes[part_id] = (store_len, index_len); + } + + self.merge_partitions().await?; + Ok(()) + } + pub fn with_ivf(&mut self, ivf: IvfModel) -> &mut Self { self.ivf = Some(ivf); self @@ -181,30 +293,63 @@ impl IvfIndexBuilde self } + pub fn retrain(&mut self, retrain: bool) -> &mut Self { + self.retrain = retrain; + self + } + + #[instrument(name = "load_or_build_ivf", level = "debug", skip_all)] async fn load_or_build_ivf(&self) -> Result { - let ivf_params = self.ivf_params.as_ref().ok_or(Error::invalid_input( - "IVF build params not set", + let dataset = self.dataset.as_ref().ok_or(Error::invalid_input( + "dataset not set before loading or building IVF", location!(), ))?; - let dim = utils::get_vector_dim(&self.dataset, &self.column)?; - super::build_ivf_model( - &self.dataset, - &self.column, - dim, - self.distance_type, - ivf_params, - ) - .await - // TODO: load ivf model + let dim = utils::get_vector_dim(dataset.schema(), &self.column)?; + match &self.ivf { + Some(ivf) => { + if self.retrain { + // retrain the IVF model with the existing indices + let mut ivf_params = IvfBuildParams::new(ivf.num_partitions()); + ivf_params.retrain = true; + + super::build_ivf_model( + dataset, + &self.column, + dim, + self.distance_type, + &ivf_params, + ) + .await + } else { + Ok(ivf.clone()) + } + } + None => { + let ivf_params = self.ivf_params.as_ref().ok_or(Error::invalid_input( + "IVF build params not set", + location!(), + ))?; + super::build_ivf_model(dataset, &self.column, dim, self.distance_type, ivf_params) + .await + } + } } + #[instrument(name = "load_or_build_quantizer", level = "debug", skip_all)] async fn load_or_build_quantizer(&self) -> Result { - let quantizer_params = self.quantizer_params.as_ref().ok_or(Error::invalid_input( - "quantizer build params not set", + if self.quantizer.is_some() && !self.retrain { + return Ok(self.quantizer.clone().unwrap()); + } + + let dataset = self.dataset.as_ref().ok_or(Error::invalid_input( + "dataset not set before loading or building quantizer", location!(), ))?; - let sample_size_hint = quantizer_params.sample_size(); + let sample_size_hint = match &self.quantizer_params { + Some(params) => params.sample_size(), + None => 256 * 256, // here it must be retrain, let's just set sample size to the default value + }; let start = std::time::Instant::now(); info!( @@ -212,8 +357,7 @@ impl IvfIndexBuilde sample_size_hint ); let training_data = - utils::maybe_sample_training_data(&self.dataset, &self.column, sample_size_hint) - .await?; + utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint).await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() @@ -221,29 +365,46 @@ impl IvfIndexBuilde // If metric type is cosine, normalize the training data, and after this point, // treat the metric type as L2. - let (training_data, dt) = if self.distance_type == DistanceType::Cosine { - let training_data = lance_linalg::kernels::normalize_fsl(&training_data)?; - (training_data, DistanceType::L2) + let training_data = if self.distance_type == DistanceType::Cosine { + lance_linalg::kernels::normalize_fsl(&training_data)? } else { - (training_data, self.distance_type) + training_data }; + // we filtered out nulls when sampling, but we still need to filter out NaNs and INFs here + let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?; + let training_data = training_data.as_fixed_size_list(); + let training_data = match (self.ivf.as_ref(), Q::use_residual(self.distance_type)) { (Some(ivf), true) => { let ivf_transformer = lance_index::vector::ivf::new_ivf_transformer( ivf.centroids.clone().unwrap(), - dt, + DistanceType::L2, vec![], ); span!(Level::INFO, "compute residual for PQ training") - .in_scope(|| ivf_transformer.compute_residual(&training_data))? + .in_scope(|| ivf_transformer.compute_residual(training_data))? } - _ => training_data, + _ => training_data.clone(), }; info!("Start to train quantizer"); let start = std::time::Instant::now(); - let quantizer = Q::build(&training_data, DistanceType::L2, quantizer_params)?; + let quantizer = match &self.quantizer { + Some(q) => { + let mut q = q.clone(); + if self.retrain { + q.retrain(&training_data)?; + } + q + } + None => { + let quantizer_params = self.quantizer_params.as_ref().ok_or( + Error::invalid_input("quantizer build params not set", location!()), + )?; + Q::build(&training_data, DistanceType::L2, quantizer_params)? + } + }; info!( "Trained quantizer in {:02} seconds", start.elapsed().as_secs_f32() @@ -252,14 +413,16 @@ impl IvfIndexBuilde } async fn shuffle_dataset(&mut self) -> Result<()> { - let stream = self - .dataset - .scan() + let dataset = self.dataset.as_ref().ok_or(Error::invalid_input( + "dataset not set before shuffling", + location!(), + ))?; + let mut builder = dataset.scan(); + builder .batch_readahead(get_num_compute_intensive_cpus()) .project(&[self.column.as_str()])? - .with_row_id() - .try_into_stream() - .await?; + .with_row_id(); + let stream = builder.try_into_stream().await?; self.shuffle_data(Some(stream)).await?; Ok(()) } @@ -284,6 +447,10 @@ impl IvfIndexBuilde "quantizer not set before shuffle data", location!(), ))?; + let shuffler = self.shuffler.as_ref().ok_or(Error::invalid_input( + "shuffler not set before shuffle data", + location!(), + ))?; let transformer = Arc::new( lance_index::vector::ivf::new_ivf_transformer_with_quantizer( @@ -291,7 +458,7 @@ impl IvfIndexBuilde self.distance_type, &self.column, quantizer.into(), - Some(0..ivf.num_partitions() as u32), + None, )?, ); let mut transformed_stream = Box::pin( @@ -310,33 +477,43 @@ impl IvfIndexBuilde Some(Err(e)) => panic!("do this better: error reading first batch: {:?}", e), None => { log::info!("no data to shuffle"); - self.shuffle_reader = Some(Box::new(IvfShufflerReader::new( - self.dataset.object_store.clone(), + self.shuffle_reader = Some(Arc::new(IvfShufflerReader::new( + Arc::new(self.store.clone()), self.temp_dir.clone(), vec![0; ivf.num_partitions()], + 0.0, ))); return Ok(self); } }; self.shuffle_reader = Some( - self.shuffler + shuffler .shuffle(Box::new(RecordBatchStreamAdapter::new( schema, transformed_stream, ))) - .await?, + .await? + .into(), ); Ok(self) } + #[instrument(name = "build_partitions", level = "debug", skip_all)] async fn build_partitions(&mut self) -> Result<&mut Self> { - let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( + let ivf = self.ivf.as_mut().ok_or(Error::invalid_input( "IVF not set before building partitions", location!(), ))?; - + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before building partition", + location!(), + ))?; + let sub_index_params = self.sub_index_params.clone().ok_or(Error::invalid_input( + "sub index params not set before building partition", + location!(), + ))?; let reader = self.shuffle_reader.as_ref().ok_or(Error::invalid_input( "shuffle reader not set before building partitions", location!(), @@ -352,105 +529,164 @@ impl IvfIndexBuilde .map(|(idx, _)| idx) .collect::>(); + let reader = reader.clone(); + let existing_indices = Arc::new(self.existing_indices.clone()); + let distance_type = self.distance_type; let mut partition_sizes = vec![(0, 0); ivf.num_partitions()]; - for (i, &partition) in partition_build_order.iter().enumerate() { - log::info!( - "building partition {}, progress {}/{}", - partition, - i + 1, - ivf.num_partitions(), - ); - let mut batches = Vec::new(); - for existing_index in self.existing_indices.iter() { - let existing_index = existing_index - .as_any() - .downcast_ref::>() - .ok_or(Error::invalid_input( - "existing index is not IVF index", - location!(), - ))?; - - let part_storage = existing_index.load_partition_storage(partition).await?; - batches.extend( - self.take_vectors(part_storage.row_ids().cloned().collect_vec().as_ref()) - .await?, - ); - } + let build_iter = partition_build_order.iter().map(|&partition| { + let reader = reader.clone(); + let existing_indices = existing_indices.clone(); + let temp_dir = self.temp_dir.clone(); + let quantizer = quantizer.clone(); + let sub_index_params = sub_index_params.clone(); + let column = self.column.clone(); + async move { + let (batches, loss) = Self::take_partition_batches( + partition, + existing_indices.as_ref(), + reader.as_ref(), + ) + .await?; - match reader.partition_size(partition)? { - 0 => continue, - _ => { - let partition_data = - reader.read_partition(partition).await?.ok_or(Error::io( - format!("partition {} is empty", partition).as_str(), - location!(), - ))?; - batches.extend(partition_data.try_collect::>().await?); + let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); + if num_rows == 0 { + return Ok(((0, 0), 0.0)); } - } - let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); - if num_rows == 0 { - continue; + Self::build_partition( + &temp_dir, + distance_type, + quantizer, + sub_index_params, + batches, + partition, + column, + ) + .await + .map(|res| (res, loss)) } - let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; - let sizes = self.build_partition(partition, &batch).await?; - partition_sizes[partition] = sizes; - log::info!( - "partition {} built, progress {}/{}", - partition, - i + 1, - ivf.num_partitions() - ); + }); + let results = stream::iter(build_iter) + .buffered(get_num_compute_intensive_cpus()) + .try_collect::>() + .boxed() + .await?; + + let mut total_loss = 0.0; + for (i, (res, loss)) in results.into_iter().enumerate() { + total_loss += loss; + partition_sizes[partition_build_order[i]] = res; + } + if let Some(loss) = reader.total_loss() { + total_loss += loss; } + ivf.loss = Some(total_loss); + self.partition_sizes = partition_sizes; Ok(self) } - async fn build_partition(&self, part_id: usize, batch: &RecordBatch) -> Result<(usize, usize)> { - let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( - "quantizer not set before building partition", - location!(), - ))?; - + #[instrument(name = "build_partition", level = "debug", skip_all)] + #[allow(clippy::too_many_arguments)] + async fn build_partition( + temp_dir: &Path, + distance_type: DistanceType, + quantizer: Q, + sub_index_params: S::BuildParams, + batches: Vec, + part_id: usize, + column: String, + ) -> Result<(usize, usize)> { + let local_store = ObjectStore::local(); // build quantized vector storage - let object_store = ObjectStore::local(); - let storage_len = { - let storage = StorageBuilder::new(self.column.clone(), self.distance_type, quantizer) - .build(batch)?; - let path = self.temp_dir.child(format!("storage_part{}", part_id)); - let writer = object_store.create(&path).await?; - let mut writer = FileWriter::try_new( - writer, - storage.schema().as_ref().try_into()?, - Default::default(), - )?; - for batch in storage.to_batches()? { - writer.write_batch(&batch).await?; - } - writer.finish().await? as usize - }; + let storage = StorageBuilder::new(column, distance_type, quantizer)?.build(batches)?; + + let path = temp_dir.child(format!("storage_part{}", part_id)); + let batches = storage.to_batches()?; + let write_storage_fut = FileWriter::create_file_with_batches( + &local_store, + &path, + storage.schema().as_ref().try_into()?, + batches, + Default::default(), + ); // build the sub index, with in-memory storage - let index_len = { - let vectors = batch[&self.column].as_fixed_size_list(); - let flat_storage = FlatStorage::new(vectors.clone(), self.distance_type); - let sub_index = S::index_vectors(&flat_storage, self.sub_index_params.clone())?; - let path = self.temp_dir.child(format!("index_part{}", part_id)); - let writer = object_store.create(&path).await?; - let index_batch = sub_index.to_batch()?; - let mut writer = FileWriter::try_new( - writer, - index_batch.schema_ref().as_ref().try_into()?, - Default::default(), - )?; - writer.write_batch(&index_batch).await?; - writer.finish().await? as usize - }; + let sub_index = S::index_vectors(&storage, sub_index_params)?; + let path = temp_dir.child(format!("index_part{}", part_id)); + let index_batch = sub_index.to_batch()?; + let schema = index_batch.schema().as_ref().try_into()?; + let write_index_fut = FileWriter::create_file_with_batches( + &local_store, + &path, + schema, + std::iter::once(index_batch), + Default::default(), + ); - Ok((storage_len, index_len)) + futures::try_join!(write_storage_fut, write_index_fut) } + #[instrument(name = "take_partition_batches", level = "debug", skip_all)] + async fn take_partition_batches( + part_id: usize, + existing_indices: &[Arc], + reader: &dyn ShuffleReader, + ) -> Result<(Vec, f64)> { + let mut batches = Vec::new(); + for existing_index in existing_indices.iter() { + let existing_index = existing_index + .as_any() + .downcast_ref::>() + .ok_or(Error::invalid_input( + "existing index is not IVF index", + location!(), + ))?; + + let part_storage = existing_index.load_partition_storage(part_id).await?; + let mut part_batches = part_storage.to_batches()?.collect::>(); + // for PQ, the PQ codes are transposed, so we need to transpose them back + if matches!(Q::quantization_type(), QuantizationType::Product) { + for batch in part_batches.iter_mut() { + let codes = batch[PQ_CODE_COLUMN] + .as_fixed_size_list() + .values() + .as_primitive::(); + let codes_num_bytes = codes.len() / batch.num_rows(); + let original_codes = transpose(codes, codes_num_bytes, batch.num_rows()); + let original_codes = FixedSizeListArray::try_new_from_values( + original_codes, + codes_num_bytes as i32, + )?; + *batch = batch + .replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))? + .drop_column(PART_ID_COLUMN)?; + } + } + batches.extend(part_batches); + } + + let mut loss = 0.0; + if reader.partition_size(part_id)? > 0 { + let mut partition_data = reader.read_partition(part_id).await?.ok_or(Error::io( + format!("partition {} is empty", part_id).as_str(), + location!(), + ))?; + while let Some(batch) = partition_data.try_next().await? { + loss += batch + .metadata() + .get(LOSS_METADATA_KEY) + .map(|s| s.parse::().unwrap_or(0.0)) + .unwrap_or(0.0); + let batch = batch.drop_column(PART_ID_COLUMN)?; + batches.push(batch); + } + } + + Ok((batches, loss)) + } + + #[instrument(name = "merge_partitions", level = "debug", skip_all)] async fn merge_partitions(&mut self) -> Result<()> { let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( "IVF not set before merge partitions", @@ -470,14 +706,14 @@ impl IvfIndexBuilde let index_path = self.index_dir.child(INDEX_FILE_NAME); let mut storage_writer = None; let mut index_writer = FileWriter::try_new( - self.dataset.object_store().create(&index_path).await?, + self.store.create(&index_path).await?, S::schema().as_ref().try_into()?, Default::default(), )?; // maintain the IVF partitions let mut storage_ivf = IvfModel::empty(); - let mut index_ivf = IvfModel::new(ivf.centroids.clone().unwrap()); + let mut index_ivf = IvfModel::new(ivf.centroids.clone().unwrap(), ivf.loss); let mut partition_index_metadata = Vec::with_capacity(partition_sizes.len()); let obj_store = Arc::new(ObjectStore::local()); let scheduler_config = SchedulerConfig::max_bandwidth(&obj_store); @@ -508,7 +744,7 @@ impl IvfIndexBuilde let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; if storage_writer.is_none() { storage_writer = Some(FileWriter::try_new( - self.dataset.object_store().create(&storage_path).await?, + self.store.create(&storage_path).await?, batch.schema_ref().as_ref().try_into()?, Default::default(), )?); @@ -601,15 +837,19 @@ impl IvfIndexBuilde // take vectors from the dataset // used for reading vectors from existing indices - async fn take_vectors(&self, row_ids: &[u64]) -> Result> { - let column = self.column.clone(); - let object_store = self.dataset.object_store().clone(); - let projection = Arc::new(self.dataset.schema().project(&[column.as_str()])?); + #[allow(dead_code)] + async fn take_vectors( + dataset: &Arc, + column: &str, + store: &ObjectStore, + row_ids: &[u64], + ) -> Result> { + let projection = Arc::new(dataset.schema().project(&[column])?); // arrow uses i32 for index, so we chunk the row ids to avoid large batch causing overflow let mut batches = Vec::new(); - for chunk in row_ids.chunks(object_store.block_size()) { - let batch = self - .dataset + let row_ids = dataset.filter_deleted_ids(row_ids).await?; + for chunk in row_ids.chunks(store.block_size()) { + let batch = dataset .take_rows(chunk, ProjectionRequest::Schema(projection.clone())) .await?; let batch = batch.try_with_column( @@ -639,165 +879,3 @@ pub(crate) fn index_type_string(sub_index: SubIndexType, quantizer: Quantization } } } - -#[cfg(test)] -mod tests { - use crate::Dataset; - use arrow::datatypes::Float32Type; - use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator}; - use arrow_schema::{DataType, Field, Schema}; - use lance_arrow::FixedSizeListArrayExt; - use lance_index::vector::hnsw::builder::HnswBuildParams; - use lance_index::vector::hnsw::HNSW; - use lance_index::vector::pq::{PQBuildParams, ProductQuantizer}; - use lance_index::vector::sq::builder::SQBuildParams; - use lance_index::vector::sq::ScalarQuantizer; - use lance_index::vector::{ - flat::index::{FlatIndex, FlatQuantizer}, - ivf::IvfBuildParams, - v3::shuffler::IvfShuffler, - }; - use lance_linalg::distance::DistanceType; - use lance_testing::datagen::generate_random_array_with_range; - use object_store::path::Path; - use std::{collections::HashMap, ops::Range, sync::Arc}; - use tempfile::tempdir; - - const DIM: usize = 32; - - async fn generate_test_dataset( - test_uri: &str, - range: Range, - ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range::(1000 * DIM, range); - let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] - .into_iter() - .collect(); - - let schema: Arc<_> = Schema::new(vec![Field::new( - "vector", - DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float32, true)), - DIM as i32, - ), - true, - )]) - .with_metadata(metadata) - .into(); - let array = Arc::new(FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap()); - let batch = RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap(); - - let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); - let dataset = Dataset::write(batches, test_uri, None).await.unwrap(); - (dataset, array) - } - - #[tokio::test] - async fn test_build_ivf_flat() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::default(); - let index_dir: Path = tempdir().unwrap().path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - - super::IvfIndexBuilder::::new( - dataset, - "vector".to_owned(), - index_dir, - DistanceType::L2, - Box::new(shuffler), - Some(ivf_params), - Some(()), - (), - ) - .unwrap() - .build() - .await - .unwrap(); - } - - #[tokio::test] - async fn test_build_ivf_pq() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::default(); - let pq_params = PQBuildParams::default(); - let index_dir: Path = tempdir().unwrap().path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - - super::IvfIndexBuilder::::new( - dataset, - "vector".to_owned(), - index_dir, - DistanceType::L2, - Box::new(shuffler), - Some(ivf_params), - Some(pq_params), - (), - ) - .unwrap() - .build() - .await - .unwrap(); - } - - #[tokio::test] - async fn test_build_ivf_hnsw_sq() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let sq_params = SQBuildParams::default(); - let index_dir: Path = tempdir().unwrap().path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - - super::IvfIndexBuilder::::new( - dataset, - "vector".to_owned(), - index_dir, - DistanceType::L2, - Box::new(shuffler), - Some(ivf_params), - Some(sq_params), - hnsw_params, - ) - .unwrap() - .build() - .await - .unwrap(); - } - - #[tokio::test] - async fn test_build_ivf_hnsw_pq() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let pq_params = PQBuildParams::default(); - let index_dir: Path = tempdir().unwrap().path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - - super::IvfIndexBuilder::::new( - dataset, - "vector".to_owned(), - index_dir, - DistanceType::L2, - Box::new(shuffler), - Some(ivf_params), - Some(pq_params), - hnsw_params, - ) - .unwrap() - .build() - .await - .unwrap(); - } -} diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 0214e83a998..95c21e25d67 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -17,11 +17,15 @@ mod test { use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; + use datafusion::execution::SendableRecordBatchStream; use deepsize::{Context, DeepSizeOf}; use lance_arrow::FixedSizeListArrayExt; - use lance_index::vector::ivf::storage::IvfModel; - use lance_index::vector::quantizer::{QuantizationType, Quantizer}; use lance_index::vector::v3::subindex::SubIndexType; + use lance_index::{metrics::MetricsCollector, vector::ivf::storage::IvfModel}; + use lance_index::{ + metrics::NoOpMetricsCollector, + vector::quantizer::{QuantizationType, Quantizer}, + }; use lance_index::{vector::Query, Index, IndexType}; use lance_io::{local::LocalObjectReader, traits::Reader}; use lance_linalg::distance::MetricType; @@ -70,6 +74,10 @@ mod test { Ok(self) } + async fn prewarm(&self) -> Result<()> { + Ok(()) + } + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result { Ok(serde_json::Value::Null) @@ -91,6 +99,7 @@ mod test { &self, query: &Query, _pre_filter: Arc, + _metrics: &dyn MetricsCollector, ) -> Result { let key: &Float32Array = query.key.as_primitive(); assert_eq!(key.len(), self.assert_query_value.len()); @@ -109,6 +118,7 @@ mod test { _: usize, _: &Query, _: Arc, + _: &dyn MetricsCollector, ) -> Result { unimplemented!("only for IVF") } @@ -134,15 +144,23 @@ mod test { Ok(Box::new(self.clone())) } + fn num_rows(&self) -> u64 { + self.ret_val.num_rows() as u64 + } + fn row_ids(&self) -> Box> { todo!("this method is for only IVF_HNSW_* index"); } - fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { + async fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { Ok(()) } - fn ivf_model(&self) -> IvfModel { + async fn to_batch_stream(&self, _with_vector: bool) -> Result { + unimplemented!("only for SubIndex") + } + + fn ivf_model(&self) -> &IvfModel { unimplemented!("only for IVF") } fn quantizer(&self) -> Quantizer { @@ -164,7 +182,7 @@ mod test { async fn test_ivf_residual_handling() { let centroids = Float32Array::from_iter(vec![1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0]); let centroids = FixedSizeListArray::try_new_from_values(centroids, 2).unwrap(); - let mut ivf = IvfModel::new(centroids); + let mut ivf = IvfModel::new(centroids, None); // Add 4 partitions for _ in 0..4 { ivf.add_partition(0); @@ -233,6 +251,8 @@ mod test { column: "test".to_string(), key: Arc::new(Float32Array::from(query)), k: 1, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, @@ -247,6 +267,7 @@ mod test { filtered_ids: None, final_mask: Mutex::new(OnceCell::new()), }), + &NoOpMetricsCollector, ) .await .unwrap(); diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 25dfee8b364..8c943649ac7 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -9,6 +9,7 @@ use std::{ sync::{Arc, Weak}, }; +use arrow::datatypes::UInt8Type; use arrow_arith::numeric::sub; use arrow_array::{ cast::{as_struct_array, AsArray}, @@ -19,6 +20,7 @@ use arrow_ord::sort::sort_to_indices; use arrow_schema::{DataType, Schema}; use arrow_select::{concat::concat_batches, take::take}; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; use deepsize::DeepSizeOf; use futures::{ stream::{self, StreamExt}, @@ -27,17 +29,24 @@ use futures::{ use io::write_hnsw_quantization_index_partitions; use lance_arrow::*; use lance_core::{ - datatypes::Field, traits::DatasetTakeRows, utils::tokio::get_num_compute_intensive_cpus, Error, - Result, ROW_ID_FIELD, + traits::DatasetTakeRows, + utils::{ + tokio::get_num_compute_intensive_cpus, + tracing::{IO_TYPE_LOAD_VECTOR_PART, TRACE_IO_EVENTS}, + }, + Error, Result, ROW_ID_FIELD, }; use lance_file::{ format::MAGIC, writer::{FileWriter, FileWriterOptions}, }; -use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::metrics::MetricsCollector; +use lance_index::metrics::NoOpMetricsCollector; +use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::storage::transpose; use lance_index::vector::quantizer::QuantizationType; +use lance_index::vector::utils::is_finite; use lance_index::vector::v3::shuffler::IvfShuffler; use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType}; use lance_index::{ @@ -68,13 +77,13 @@ use lance_linalg::{ distance::Normalize, kernels::{normalize_arrow, normalize_fsl}, }; -use log::info; +use log::{info, warn}; use object_store::path::Path; use rand::{rngs::SmallRng, SeedableRng}; use roaring::RoaringBitmap; use serde::Serialize; use serde_json::json; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use uuid::Uuid; @@ -84,6 +93,7 @@ use super::{ utils::maybe_sample_training_data, }; use crate::dataset::builder::DatasetBuilder; +use crate::index::vector::utils::{get_vector_dim, get_vector_type}; use crate::{ dataset::Dataset, index::{pb, prefilter::PreFilter, vector::ivf::io::write_pq_partitions, INDEX_FILE_NAME}, @@ -162,11 +172,12 @@ impl IVFIndex { /// Parameters /// ---------- /// - partition_id: partition ID. - #[instrument(level = "debug", skip(self))] + #[instrument(level = "debug", skip(self, metrics))] pub async fn load_partition( &self, partition_id: usize, write_cache: bool, + metrics: &dyn MetricsCollector, ) -> Result> { let cache_key = format!("{}-ivf-{}", self.uuid, partition_id); let session = self.session.upgrade().ok_or(Error::Internal { @@ -176,6 +187,9 @@ impl IVFIndex { let part_index = if let Some(part_idx) = session.index_cache.get_vector(&cache_key) { part_idx } else { + metrics.record_part_load(); + tracing::info!(target: TRACE_IO_EVENTS, type=IO_TYPE_LOAD_VECTOR_PART, index_type="ivf", part_id=cache_key); + let mtx = self.partition_locks.get_partition_mutex(partition_id); let _guard = mtx.lock().await; // check the cache again, as the partition may have been loaded by another @@ -267,6 +281,12 @@ pub(crate) async fn optimize_vector_indices( .await; } + if options.retrain { + warn!( + "optimizing vector index: retrain is only supported for v3 vector indices, falling back to normal optimization. please re-create the index with lance>=0.25.0 to enable retrain." + ); + } + let new_uuid = Uuid::new_v4(); let object_store = dataset.object_store(); let index_file = dataset @@ -355,34 +375,61 @@ pub(crate) async fn optimize_vector_indices_v2( let num_partitions = ivf_model.num_partitions(); let index_type = existing_indices[0].sub_index_type(); + let num_indices_to_merge = if options.retrain { + existing_indices.len() + } else { + options.num_indices_to_merge + }; let temp_dir = tempfile::tempdir()?; - let temp_dir = temp_dir.path().to_str().unwrap().into(); - let shuffler = Box::new(IvfShuffler::new(temp_dir, num_partitions)); + let temp_dir_path = Path::from_filesystem_path(temp_dir.path())?; + let shuffler = Box::new(IvfShuffler::new(temp_dir_path, num_partitions)); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 } else { - existing_indices.len() - options.num_indices_to_merge + existing_indices.len() - num_indices_to_merge }; let indices_to_merge = existing_indices[start_pos..].to_vec(); let merged_num = indices_to_merge.len(); + + let (_, element_type) = get_vector_type(dataset.schema(), vector_column)?; match index_type { // IVF_FLAT (SubIndexType::Flat, QuantizationType::Flat) => { - IvfIndexBuilder::::new_incremental( - dataset.clone(), - vector_column.to_owned(), - index_dir, - distance_type, - shuffler, - (), - )? - .with_ivf(ivf_model) - .with_quantizer(quantizer.try_into()?) - .with_existing_indices(indices_to_merge) - .shuffle_data(unindexed) - .await? - .build() - .await?; + if element_type == DataType::UInt8 { + IvfIndexBuilder::::new_incremental( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + (), + )? + .with_ivf(ivf_model.clone()) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .retrain(options.retrain) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } else { + IvfIndexBuilder::::new_incremental( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + (), + )? + .with_ivf(ivf_model.clone()) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .retrain(options.retrain) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } } // IVF_PQ (SubIndexType::Flat, QuantizationType::Product) => { @@ -394,9 +441,10 @@ pub(crate) async fn optimize_vector_indices_v2( shuffler, (), )? - .with_ivf(ivf_model) + .with_ivf(ivf_model.clone()) .with_quantizer(quantizer.try_into()?) .with_existing_indices(indices_to_merge) + .retrain(options.retrain) .shuffle_data(unindexed) .await? .build() @@ -415,9 +463,10 @@ pub(crate) async fn optimize_vector_indices_v2( // TODO: get the HNSW parameters from the existing indices HnswBuildParams::default(), )? - .with_ivf(ivf_model) + .with_ivf(ivf_model.clone()) .with_quantizer(quantizer.try_into()?) .with_existing_indices(indices_to_merge) + .retrain(options.retrain) .shuffle_data(unindexed) .await? .build() @@ -436,9 +485,10 @@ pub(crate) async fn optimize_vector_indices_v2( // TODO: get the HNSW parameters from the existing indices HnswBuildParams::default(), )? - .with_ivf(ivf_model) + .with_ivf(ivf_model.clone()) .with_quantizer(quantizer.try_into()?) .with_existing_indices(indices_to_merge) + .retrain(options.retrain) .shuffle_data(unindexed) .await? .build() @@ -479,7 +529,6 @@ async fn optimize_ivf_pq_indices( vector_column, pq_index.pq.clone(), None, - true, ); // Shuffled un-indexed data with partition. @@ -487,7 +536,6 @@ async fn optimize_ivf_pq_indices( Some(unindexed) => Some( shuffle_dataset( unindexed, - vector_column, ivf.into(), None, first_idx.ivf.num_partitions() as u32, @@ -500,13 +548,11 @@ async fn optimize_ivf_pq_indices( None => None, }; - let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap(), first_idx.ivf.loss); - let start_pos = if options.num_indices_to_merge > existing_indices.len() { - 0 - } else { - existing_indices.len() - options.num_indices_to_merge - }; + let start_pos = existing_indices + .len() + .saturating_sub(options.num_indices_to_merge); let indices_to_merge = existing_indices[start_pos..] .iter() @@ -566,7 +612,6 @@ async fn optimize_ivf_hnsw_indices( Some(unindexed) => Some( shuffle_dataset( unindexed, - vector_column, Arc::new(ivf), None, first_idx.ivf.num_partitions() as u32, @@ -579,7 +624,7 @@ async fn optimize_ivf_hnsw_indices( None => None, }; - let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap(), first_idx.ivf.loss); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 @@ -638,6 +683,7 @@ async fn optimize_ivf_hnsw_indices( // Write the metadata of quantizer let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => { let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; @@ -706,6 +752,7 @@ pub struct IvfIndexStatistics { sub_index: serde_json::Value, partitions: Vec, centroids: Vec>, + loss: Option, } fn centroids_to_vectors(centroids: &FixedSizeListArray) -> Result>> { @@ -727,6 +774,12 @@ fn centroids_to_vectors(centroids: &FixedSizeListArray) -> Result>> .iter() .map(|v| *v as f32) .collect::>()), + DataType::UInt8 => Ok(row + .as_primitive::() + .values() + .iter() + .map(|v| *v as f32) + .collect::>()), _ => Err(Error::Index { message: format!( "IVF centroids must be FixedSizeList of floating number, got: {}", @@ -781,6 +834,11 @@ impl Index for IVFIndex { } } + async fn prewarm(&self) -> Result<()> { + // TODO: We should prewarm the IVF index by loading the partitions into memory + Ok(()) + } + fn statistics(&self) -> Result { let partitions_statistics = (0..self.ivf.num_partitions()) .map(|part_id| IvfIndexPartitionStatistics { @@ -799,6 +857,7 @@ impl Index for IVFIndex { sub_index: self.sub_index.statistics()?, partitions: partitions_statistics, centroids: centroid_vecs, + loss: self.ivf.loss(), })?) } @@ -806,7 +865,9 @@ impl Index for IVFIndex { let mut frag_ids = RoaringBitmap::default(); let part_ids = 0..self.ivf.num_partitions(); for part_id in part_ids { - let part = self.load_partition(part_id, false).await?; + let part = self + .load_partition(part_id, false, &NoOpMetricsCollector) + .await?; frag_ids |= part.calculate_included_frags().await?; } Ok(frag_ids) @@ -816,7 +877,12 @@ impl Index for IVFIndex { #[async_trait] impl VectorIndex for IVFIndex { #[instrument(level = "debug", skip_all, name = "IVFIndex::search")] - async fn search(&self, query: &Query, pre_filter: Arc) -> Result { + async fn search( + &self, + query: &Query, + pre_filter: Arc, + metrics: &dyn MetricsCollector, + ) -> Result { let mut query = query.clone(); if self.metric_type == MetricType::Cosine { let key = normalize_arrow(&query.key)?; @@ -827,7 +893,9 @@ impl VectorIndex for IVFIndex { assert!(partition_ids.len() <= query.nprobes); let part_ids = partition_ids.values().to_vec(); let batches = stream::iter(part_ids) - .map(|part_id| self.search_in_partition(part_id as usize, &query, pre_filter.clone())) + .map(|part_id| { + self.search_in_partition(part_id as usize, &query, pre_filter.clone(), metrics) + }) .buffer_unordered(get_num_compute_intensive_cpus()) .try_collect::>() .await?; @@ -871,11 +939,12 @@ impl VectorIndex for IVFIndex { partition_id: usize, query: &Query, pre_filter: Arc, + metrics: &dyn MetricsCollector, ) -> Result { - let part_index = self.load_partition(partition_id, true).await?; + let part_index = self.load_partition(partition_id, true, metrics).await?; let query = self.preprocess_query(partition_id, query)?; - let batch = part_index.search(&query, pre_filter).await?; + let batch = part_index.search(&query, pre_filter, metrics).await?; Ok(batch) } @@ -903,11 +972,29 @@ impl VectorIndex for IVFIndex { }) } + async fn partition_reader( + &self, + partition_id: usize, + with_vector: bool, + metrics: &dyn MetricsCollector, + ) -> Result { + let partition = self.load_partition(partition_id, false, metrics).await?; + partition.to_batch_stream(with_vector).await + } + + async fn to_batch_stream(&self, _with_vector: bool) -> Result { + unimplemented!("this method is for only sub index") + } + + fn num_rows(&self) -> u64 { + self.ivf.num_rows() + } + fn row_ids(&self) -> Box> { todo!("this method is for only IVF_HNSW_* index"); } - fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { + async fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { // This will be needed if we want to clean up IVF to allow more than just // one layer (e.g. IVF -> IVF -> PQ). We need to pass on the call to // remap to the lower layers. @@ -920,8 +1007,8 @@ impl VectorIndex for IVFIndex { }) } - fn ivf_model(&self) -> IvfModel { - self.ivf.clone() + fn ivf_model(&self) -> &IvfModel { + &self.ivf } fn quantizer(&self) -> Quantizer { @@ -1041,38 +1128,6 @@ impl TryFrom<&IvfPQIndexMetadata> for pb::Index { } } -fn sanity_check<'a>(dataset: &'a Dataset, column: &str) -> Result<&'a Field> { - let Some(field) = dataset.schema().field(column) else { - return Err(Error::io( - format!( - "Building index: column {} does not exist in dataset: {:?}", - column, dataset - ), - location!(), - )); - }; - if let DataType::FixedSizeList(elem_type, _) = field.data_type() { - if !elem_type.data_type().is_floating() { - return Err(Error::Index{ - message:format!( - "VectorIndex requires the column data type to be fixed size list of f16/f32/f64, got {}", - elem_type.data_type() - ), - location: location!() - }); - } - } else { - return Err(Error::Index { - message: format!( - "VectorIndex requires the column data type to be fixed size list of float32s, got {}", - field.data_type() - ), - location: location!(), - }); - } - Ok(field) -} - fn sanity_check_ivf_params(ivf: &IvfBuildParams) -> Result<()> { if ivf.precomputed_partitions_file.is_some() && ivf.centroids.is_none() { return Err(Error::Index { @@ -1135,7 +1190,9 @@ pub async fn build_ivf_model( metric_type: MetricType, params: &IvfBuildParams, ) -> Result { - if let Some(centroids) = params.centroids.as_ref() { + let centroids = params.centroids.clone(); + if centroids.is_some() && !params.retrain { + let centroids = centroids.unwrap(); info!("Pre-computed IVF centroids is provided, skip IVF training"); if centroids.values().len() != params.num_partitions * dim { return Err(Error::Index { @@ -1147,7 +1204,7 @@ pub async fn build_ivf_model( location: location!(), }); } - return Ok(IvfModel::new(centroids.as_ref().clone())); + return Ok(IvfModel::new(centroids.as_ref().clone(), None)); } let sample_size_hint = params.num_partitions * params.sample_rate; @@ -1171,9 +1228,13 @@ pub async fn build_ivf_model( (training_data, metric_type) }; + // we filtered out nulls when sampling, but we still need to filter out NaNs and INFs here + let training_data = arrow::compute::filter(&training_data, &is_finite(&training_data))?; + let training_data = training_data.as_fixed_size_list(); + info!("Start to train IVF model"); let start = std::time::Instant::now(); - let ivf = train_ivf_model(&training_data, mt, params).await?; + let ivf = train_ivf_model(centroids, training_data, mt, params).await?; info!( "Trained IVF model in {:02} seconds", start.elapsed().as_secs_f32() @@ -1195,18 +1256,9 @@ async fn build_ivf_model_and_pq( ivf_params.num_partitions, pq_params.num_sub_vectors, metric_type, ); - let field = sanity_check(dataset, column)?; - let dim = if let DataType::FixedSizeList(_, d) = field.data_type() { - d as usize - } else { - return Err(Error::Index { - message: format!( - "VectorIndex requires the column data type to be fixed size list of floats, got {}", - field.data_type() - ), - location: location!(), - }); - }; + // sanity check + get_vector_type(dataset.schema(), column)?; + let dim = get_vector_dim(dataset.schema(), column)?; let ivf_model = build_ivf_model(dataset, column, dim, metric_type, ivf_params).await?; @@ -1345,7 +1397,7 @@ impl RemapPageTask { .sub_index .load(reader, self.offset, self.length as usize) .await?; - page.remap(mapping)?; + page.remap(mapping).await?; self.page = Some(page); Ok(self) } @@ -1380,6 +1432,20 @@ fn generate_remap_tasks(offsets: &[usize], lengths: &[u32]) -> Result, + mapping: &HashMap>, + column: String, +) -> Result<()> { + let index_dir = dataset.indices_dir().child(new_uuid); + index + .remap_to(dataset.object_store().clone(), mapping, column, index_dir) + .await +} + #[allow(clippy::too_many_arguments)] pub(crate) async fn remap_index_file( dataset: &Dataset, @@ -1409,6 +1475,7 @@ pub(crate) async fn remap_index_file( centroids: index.ivf.centroids.clone(), offsets: Vec::with_capacity(index.ivf.offsets.len()), lengths: Vec::with_capacity(index.ivf.lengths.len()), + loss: index.ivf.loss, }; while let Some(write_task) = task_stream.try_next().await? { write_task.write(&mut writer, &mut ivf).await?; @@ -1604,6 +1671,7 @@ async fn write_ivf_hnsw_file( // For PQ, we need to store the codebook let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => { let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; @@ -1664,6 +1732,7 @@ async fn write_ivf_hnsw_file( } async fn do_train_ivf_model( + centroids: Option>, data: &[T::Native], dimension: usize, metric_type: MetricType, @@ -1675,7 +1744,8 @@ where { let rng = SmallRng::from_entropy(); const REDOS: usize = 1; - let centroids = lance_index::vector::kmeans::train_kmeans::( + let kmeans = lance_index::vector::kmeans::train_kmeans::( + centroids, data, dimension, params.num_partitions, @@ -1685,14 +1755,15 @@ where metric_type, params.sample_rate, )?; - Ok(IvfModel::new(FixedSizeListArray::try_new_from_values( - centroids, - dimension as i32, - )?)) + Ok(IvfModel::new( + FixedSizeListArray::try_new_from_values(kmeans.centroids, dimension as i32)?, + Some(kmeans.loss), + )) } /// Train IVF partitions using kmeans. async fn train_ivf_model( + centroids: Option>, data: &FixedSizeListArray, distance_type: DistanceType, params: &IvfBuildParams, @@ -1706,6 +1777,7 @@ async fn train_ivf_model( match (values.data_type(), distance_type) { (DataType::Float16, _) => { do_train_ivf_model::( + centroids, values.as_primitive::().values(), dim, distance_type, @@ -1715,6 +1787,7 @@ async fn train_ivf_model( } (DataType::Float32, _) => { do_train_ivf_model::( + centroids, values.as_primitive::().values(), dim, distance_type, @@ -1724,6 +1797,7 @@ async fn train_ivf_model( } (DataType::Float64, _) => { do_train_ivf_model::( + centroids, values.as_primitive::().values(), dim, distance_type, @@ -1731,8 +1805,37 @@ async fn train_ivf_model( ) .await } + (DataType::Int8, DistanceType::L2) + | (DataType::Int8, DistanceType::Dot) + | (DataType::Int8, DistanceType::Cosine) => { + do_train_ivf_model::( + centroids, + data.convert_to_floating_point()? + .values() + .as_primitive::() + .values(), + dim, + distance_type, + params, + ) + .await + } + (DataType::UInt8, DistanceType::Hamming) => { + do_train_ivf_model::( + centroids, + values.as_primitive::().values(), + dim, + distance_type, + params, + ) + .await + } _ => Err(Error::Index { - message: "Unsupported data type".to_string(), + message: format!( + "Unsupported data type {} with distance type {}", + values.data_type(), + distance_type + ), location: location!(), }), } @@ -1743,15 +1846,21 @@ mod tests { use super::*; use std::collections::HashSet; - use std::iter::repeat; + use std::iter::repeat_n; use std::ops::Range; use arrow_array::types::UInt64Type; - use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array}; - use arrow_schema::Field; + use arrow_array::{ + make_array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, + RecordBatchReader, UInt64Array, + }; + use arrow_buffer::{BooleanBuffer, NullBuffer}; + use arrow_schema::{DataType, Field, Schema}; use itertools::Itertools; use lance_core::utils::address::RowAddress; use lance_core::ROW_ID; + use lance_datagen::{array, gen, ArrayGeneratorExt, Dimension, RowCount}; + use lance_index::metrics::NoOpMetricsCollector; use lance_index::vector::sq::builder::SQBuildParams; use lance_linalg::distance::l2_distance_batch; use lance_testing::datagen::{ @@ -1759,9 +1868,12 @@ mod tests { generate_scaled_random_array, sample_without_replacement, }; use rand::{seq::SliceRandom, thread_rng}; + use rstest::rstest; use tempfile::tempdir; + use crate::dataset::{InsertBuilder, WriteMode, WriteParams}; use crate::index::prefilter::DatasetPreFilter; + use crate::index::vector::IndexFileVersion; use crate::index::vector_index_details; use crate::index::{vector::VectorIndexParams, DatasetIndexExt, DatasetIndexInternalExt}; @@ -1946,13 +2058,18 @@ mod tests { column: Self::COLUMN.to_string(), key: Arc::new(row), k: 5, + lower_bound: None, + upper_bound: None, nprobes: 1, ef: None, refine_factor: None, metric_type: MetricType::L2, use_index: true, }; - let search_result = index.search(&query, prefilter.clone()).await.unwrap(); + let search_result = index + .search(&query, prefilter.clone(), &NoOpMetricsCollector) + .await + .unwrap(); let found_ids = search_result.column(1); let found_ids = found_ids.as_any().downcast_ref::().unwrap(); @@ -2129,7 +2246,7 @@ mod tests { .unwrap(); let index = dataset - .open_vector_index(WellKnownIvfPqData::COLUMN, &uuid_str) + .open_vector_index(WellKnownIvfPqData::COLUMN, &uuid_str, &NoOpMetricsCollector) .await .unwrap(); let ivf_index = index.as_any().downcast_ref::().unwrap(); @@ -2185,7 +2302,11 @@ mod tests { .unwrap(); let remapped = dataset - .open_vector_index(WellKnownIvfPqData::COLUMN, &new_uuid.to_string()) + .open_vector_index( + WellKnownIvfPqData::COLUMN, + &new_uuid.to_string(), + &NoOpMetricsCollector, + ) .await .unwrap(); let ivf_remapped = remapped.as_any().downcast_ref::().unwrap(); @@ -2219,6 +2340,215 @@ mod tests { .await; } + struct TestPqParams { + num_sub_vectors: usize, + num_bits: usize, + } + + impl TestPqParams { + fn small() -> Self { + Self { + num_sub_vectors: 2, + num_bits: 8, + } + } + } + + // Clippy doesn't like that all start with Ivf but we might have some in the future + // that _don't_ start with Ivf so I feel it is meaningful to keep the prefix + #[allow(clippy::enum_variant_names)] + enum TestIndexType { + IvfPq { pq: TestPqParams }, + IvfHnswPq { pq: TestPqParams, num_edges: usize }, + IvfHnswSq { num_edges: usize }, + IvfFlat, + } + + struct CreateIndexCase { + metric_type: MetricType, + num_partitions: usize, + dimension: usize, + index_type: TestIndexType, + } + + // We test L2 and Dot, because L2 PQ uses residuals while Dot doesn't, + // so they have slightly different code paths. + #[tokio::test] + #[rstest] + #[case::ivf_pq_l2(CreateIndexCase { + metric_type: MetricType::L2, + num_partitions: 2, + dimension: 16, + index_type: TestIndexType::IvfPq { pq: TestPqParams::small() }, + })] + #[case::ivf_pq_dot(CreateIndexCase { + metric_type: MetricType::Dot, + num_partitions: 2, + dimension: 2000, + index_type: TestIndexType::IvfPq { pq: TestPqParams::small() }, + })] + #[case::ivf_flat(CreateIndexCase { num_partitions: 1, metric_type: MetricType::Dot, dimension: 16, index_type: TestIndexType::IvfFlat })] + #[case::ivf_hnsw_pq(CreateIndexCase { + num_partitions: 2, + metric_type: MetricType::Dot, + dimension: 16, + index_type: TestIndexType::IvfHnswPq { pq: TestPqParams::small(), num_edges: 100 }, + })] + #[case::ivf_hnsw_sq(CreateIndexCase { + metric_type: MetricType::Dot, + num_partitions: 2, + dimension: 16, + index_type: TestIndexType::IvfHnswSq { num_edges: 100 }, + })] + async fn test_create_index_nulls( + #[case] test_case: CreateIndexCase, + #[values(IndexFileVersion::Legacy, IndexFileVersion::V3)] index_version: IndexFileVersion, + ) { + let mut index_params = match test_case.index_type { + TestIndexType::IvfPq { pq } => VectorIndexParams::with_ivf_pq_params( + test_case.metric_type, + IvfBuildParams::new(test_case.num_partitions), + PQBuildParams::new(pq.num_sub_vectors, pq.num_bits), + ), + TestIndexType::IvfHnswPq { pq, num_edges } => { + VectorIndexParams::with_ivf_hnsw_pq_params( + test_case.metric_type, + IvfBuildParams::new(test_case.num_partitions), + HnswBuildParams::default().num_edges(num_edges), + PQBuildParams::new(pq.num_sub_vectors, pq.num_bits), + ) + } + TestIndexType::IvfFlat => { + VectorIndexParams::ivf_flat(test_case.num_partitions, test_case.metric_type) + } + TestIndexType::IvfHnswSq { num_edges } => VectorIndexParams::with_ivf_hnsw_sq_params( + test_case.metric_type, + IvfBuildParams::new(test_case.num_partitions), + HnswBuildParams::default().num_edges(num_edges), + SQBuildParams::default(), + ), + }; + index_params.version(index_version); + + let nrows = 2_000; + let data = gen() + .col( + "vec", + array::rand_vec::(Dimension::from(test_case.dimension as u32)), + ) + .into_batch_rows(RowCount::from(nrows)) + .unwrap(); + + // Make every other row null + let null_buffer = (0..nrows).map(|i| i % 2 == 0).collect::(); + let null_buffer = NullBuffer::new(null_buffer); + let vectors = data["vec"] + .clone() + .to_data() + .into_builder() + .nulls(Some(null_buffer)) + .build() + .unwrap(); + let vectors = make_array(vectors); + let num_non_null = vectors.len() - vectors.logical_null_count(); + let data = RecordBatch::try_new(data.schema(), vec![vectors]).unwrap(); + + let mut dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + // Create index + dataset + .create_index(&["vec"], IndexType::Vector, None, &index_params, false) + .await + .unwrap(); + + let query = vec![0.0; test_case.dimension] + .into_iter() + .collect::(); + let results = dataset + .scan() + .nearest("vec", &query, 2_000) + .unwrap() + .ef(100_000) + .nprobs(2) + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), num_non_null); + assert_eq!(results["vec"].logical_null_count(), 0); + } + + #[tokio::test] + async fn test_index_lifecycle_nulls() { + // Generate random data with nulls + let nrows = 2_000; + let dims = 32; + let data = gen() + .col( + "vec", + array::rand_vec::(Dimension::from(dims as u32)).with_random_nulls(0.5), + ) + .into_batch_rows(RowCount::from(nrows)) + .unwrap(); + let num_non_null = data["vec"].len() - data["vec"].logical_null_count(); + + let mut dataset = InsertBuilder::new("memory://") + .execute(vec![data]) + .await + .unwrap(); + + // Create index + let index_params = VectorIndexParams::with_ivf_pq_params( + MetricType::L2, + IvfBuildParams::new(2), + PQBuildParams::new(2, 8), + ); + dataset + .create_index(&["vec"], IndexType::Vector, None, &index_params, false) + .await + .unwrap(); + + // Check that the index is working + async fn check_index(dataset: &Dataset, num_non_null: usize, dims: usize) { + let query = vec![0.0; dims].into_iter().collect::(); + let results = dataset + .scan() + .nearest("vec", &query, 2_000) + .unwrap() + .nprobs(2) + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), num_non_null); + } + check_index(&dataset, num_non_null, dims).await; + + // Append more data + let data = gen() + .col( + "vec", + array::rand_vec::(Dimension::from(dims as u32)).with_random_nulls(0.5), + ) + .into_batch_rows(RowCount::from(500)) + .unwrap(); + let num_non_null = data["vec"].len() - data["vec"].logical_null_count() + num_non_null; + let mut dataset = InsertBuilder::new(Arc::new(dataset)) + .with_params(&WriteParams { + mode: WriteMode::Append, + ..Default::default() + }) + .execute(vec![data]) + .await + .unwrap(); + check_index(&dataset, num_non_null, dims).await; + + // Optimize the index + dataset.optimize_indices(&Default::default()).await.unwrap(); + check_index(&dataset, num_non_null, dims).await; + } + #[tokio::test] async fn test_create_ivf_pq_cosine() { let test_dir = tempdir().unwrap(); @@ -2415,7 +2745,7 @@ mod tests { .scan() .nearest( "vector", - &Float32Array::from_iter_values(repeat(0.5).take(DIM)), + &Float32Array::from_iter_values(repeat_n(0.5, DIM)), 5, ) .unwrap() @@ -2483,7 +2813,7 @@ mod tests { .scan() .nearest( "vector", - &Float32Array::from_iter_values(repeat(0.5).take(DIM)), + &Float32Array::from_iter_values(repeat_n(0.5, DIM)), 5, ) .unwrap() @@ -2750,7 +3080,7 @@ mod tests { true, )])); - let arr = generate_random_array_with_range(1000 * DIM, 1000.0..1001.0); + let arr = generate_random_array_with_range::(1000 * DIM, 1000.0..1001.0); let fsl = FixedSizeListArray::try_new_from_values(arr.clone(), DIM as i32).unwrap(); let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl)]).unwrap(); let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); @@ -2763,10 +3093,14 @@ mod tests { .unwrap(); let indices = dataset.load_indices().await.unwrap(); let idx = dataset - .open_generic_index("vector", indices[0].uuid.to_string().as_str()) + .open_generic_index( + "vector", + indices[0].uuid.to_string().as_str(), + &NoOpMetricsCollector, + ) .await .unwrap(); - let ivf_idx = idx.as_any().downcast_ref::().unwrap(); + let ivf_idx = idx.as_any().downcast_ref::().unwrap(); assert!(ivf_idx .ivf_model() @@ -2779,16 +3113,10 @@ mod tests { .iter() .all(|v| (0.0..=1.0).contains(v))); - let pq_idx = ivf_idx - .sub_index - .as_any() - .downcast_ref::() - .unwrap(); - // PQ code is on residual space - pq_idx - .pq - .codebook + let pq_store = ivf_idx.load_partition_storage(0).await.unwrap(); + pq_store + .codebook() .values() .as_primitive::() .values() diff --git a/rust/lance/src/index/vector/ivf/builder.rs b/rust/lance/src/index/vector/ivf/builder.rs index 02df4cc0b32..33557d301e3 100644 --- a/rust/lance/src/index/vector/ivf/builder.rs +++ b/rust/lance/src/index/vector/ivf/builder.rs @@ -24,7 +24,7 @@ use lance_io::stream::RecordBatchStreamAdapter; use lance_table::io::manifest::ManifestDescribing; use log::info; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tracing::instrument; use lance_core::{traits::DatasetTakeRows, Error, Result, ROW_ID}; @@ -79,12 +79,10 @@ pub(super) async fn build_partitions( column, pq.clone(), Some(part_range), - true, ); let stream = shuffle_dataset( data, - column, ivf_transformer.into(), precomputed_partitions, ivf.num_partitions() as u32, @@ -213,7 +211,6 @@ pub async fn write_vector_storage( column, pq, None, - true, )); let data = if let Some(partitions_ds_uri) = precomputed_partitions_ds_uri { @@ -288,7 +285,6 @@ pub(super) async fn build_hnsw_partitions( let stream = shuffle_dataset( data, - column, ivf_model.into(), precomputed_partitions, ivf.num_partitions() as u32, diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index 8290f88ab26..5c0204d04f2 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -20,6 +20,7 @@ use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu}; use lance_core::Error; use lance_file::reader::FileReader; use lance_file::writer::FileWriter; +use lance_index::metrics::NoOpMetricsCollector; use lance_index::scalar::IndexWriter; use lance_index::vector::hnsw::HNSW; use lance_index::vector::hnsw::{builder::HnswBuildParams, HnswMetadata}; @@ -38,7 +39,7 @@ use lance_linalg::kernels::normalize_fsl; use lance_table::format::SelfDescribingFileReader; use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; -use snafu::{location, Location}; +use snafu::location; use tempfile::TempDir; use tokio::sync::Semaphore; @@ -190,7 +191,9 @@ pub(super) async fn write_pq_partitions( if let Some(&previous_indices) = existing_indices.as_ref() { for &idx in previous_indices.iter() { - let sub_index = idx.load_partition(part_id as usize, true).await?; + let sub_index = idx + .load_partition(part_id as usize, true, &NoOpMetricsCollector) + .await?; let pq_index = sub_index .as_any() @@ -312,7 +315,9 @@ pub(super) async fn write_hnsw_quantization_index_partitions( if let Some(&previous_indices) = existing_indices.as_ref() { for &idx in previous_indices.iter() { - let sub_index = idx.load_partition(part_id, true).await?; + let sub_index = idx + .load_partition(part_id, true, &NoOpMetricsCollector) + .await?; let row_ids = Arc::new(UInt64Array::from_iter_values(sub_index.row_ids().cloned())); row_id_array.push(row_ids); } @@ -320,6 +325,7 @@ pub(super) async fn write_hnsw_quantization_index_partitions( let code_column = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => Some(pq.column()), Quantizer::Scalar(_) => None, }; @@ -547,10 +553,12 @@ async fn build_and_write_pq_storage( mod tests { use super::*; + use crate::index::vector::ivf::v2; use crate::index::{vector::VectorIndexParams, DatasetIndexExt, DatasetIndexInternalExt}; use crate::Dataset; use arrow_array::RecordBatchIterator; use arrow_schema::{Field, Schema}; + use lance_index::metrics::NoOpMetricsCollector; use lance_index::IndexType; use lance_testing::datagen::generate_random_array; @@ -597,12 +605,16 @@ mod tests { assert_eq!(ds.get_fragments().len(), 2); let idx = ds - .open_vector_index("vector", &indices[0].uuid.to_string()) + .open_vector_index( + "vector", + &indices[0].uuid.to_string(), + &NoOpMetricsCollector, + ) .await .unwrap(); let _ivf_idx = idx .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Invalid index type"); //let indices = /ds. diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 727f50ecea7..feaa820296e 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -3,7 +3,6 @@ //! IVF - Inverted File index. -use core::fmt; use std::marker::PhantomData; use std::{ any::Any, @@ -18,21 +17,27 @@ use arrow::{ use arrow_arith::numeric::sub; use arrow_array::{RecordBatch, StructArray, UInt32Array}; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use deepsize::DeepSizeOf; use futures::prelude::stream::{self, StreamExt, TryStreamExt}; use lance_arrow::RecordBatchExt; use lance_core::cache::FileMetadataCache; use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu}; -use lance_core::{Error, Result}; +use lance_core::utils::tracing::{IO_TYPE_LOAD_VECTOR_PART, TRACE_IO_EVENTS}; +use lance_core::{Error, Result, ROW_ID}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::v2::reader::{FileReader, FileReaderOptions}; +use lance_index::metrics::{LocalMetricsCollector, MetricsCollector}; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::quantizer::{QuantizationType, Quantizer}; use lance_index::vector::sq::ScalarQuantizer; +use lance_index::vector::storage::VectorStore; use lance_index::vector::v3::subindex::SubIndexType; +use lance_index::vector::VectorIndexCacheEntry; use lance_index::{ pb, vector::{ @@ -42,20 +47,19 @@ use lance_index::{ Index, IndexType, INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, }; use lance_index::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; +use lance_io::local::to_local_path; use lance_io::scheduler::SchedulerConfig; use lance_io::{ object_store::ObjectStore, scheduler::ScanScheduler, traits::Reader, ReadBatchParams, }; use lance_linalg::{distance::DistanceType, kernels::normalize_arrow}; -use moka::sync::Cache; use object_store::path::Path; use prost::Message; use roaring::RoaringBitmap; -use serde_json::json; -use snafu::{location, Location}; -use tracing::instrument; +use snafu::location; +use tracing::{info, instrument}; -use crate::index::vector::builder::index_type_string; +use crate::index::vector::builder::{index_type_string, IvfIndexBuilder}; use crate::{ index::{ vector::{utils::PartitionLoadLock, VectorIndex}, @@ -66,15 +70,24 @@ use crate::{ use super::{centroids_to_vectors, IvfIndexPartitionStatistics, IvfIndexStatistics}; -#[derive(Debug)] -struct PartitionEntry { - index: S, - storage: Q::Storage, +#[derive(Debug, DeepSizeOf)] +pub struct PartitionEntry { + pub index: S, + pub storage: Q::Storage, +} + +impl VectorIndexCacheEntry + for PartitionEntry +{ + fn as_any(&self) -> &dyn Any { + self + } } /// IVF Index. #[derive(Debug)] pub struct IVFIndex { + uri: String, uuid: String, /// Ivf model @@ -84,9 +97,6 @@ pub struct IVFIndex { sub_index_metadata: Vec, storage: IvfQuantizationStorage, - /// Index in each partition. - partition_cache: Cache>>, - partition_locks: PartitionLoadLock, distance_type: DistanceType, @@ -96,8 +106,7 @@ pub struct IVFIndex { /// The session cache, used when fetching pages #[allow(dead_code)] session: Weak, - - _marker: PhantomData, + _marker: PhantomData<(S, Q)>, } impl DeepSizeOf for IVFIndex { @@ -122,11 +131,9 @@ impl IVFIndex { .upgrade() .map(|sess| sess.file_metadata_cache.clone()) .unwrap_or_else(FileMetadataCache::no_cache); - let index_cache_capacity = session.upgrade().unwrap().index_cache.capacity(); + let uri = index_dir.child(uuid.as_str()).child(INDEX_FILE_NAME); let index_reader = FileReader::try_open( - scheduler - .open_file(&index_dir.child(uuid.as_str()).child(INDEX_FILE_NAME)) - .await?, + scheduler.open_file(&uri).await?, None, Arc::::default(), &file_metadata_cache, @@ -190,11 +197,11 @@ impl IVFIndex { let num_partitions = ivf.num_partitions(); Ok(Self { + uri: to_local_path(&uri), uuid, ivf, reader: index_reader, storage, - partition_cache: Cache::new(index_cache_capacity), partition_locks: PartitionLoadLock::new(num_partitions), sub_index_metadata, distance_type, @@ -203,16 +210,25 @@ impl IVFIndex { }) } - #[instrument(level = "debug", skip(self))] + #[instrument(level = "debug", skip(self, metrics))] pub async fn load_partition( &self, partition_id: usize, write_cache: bool, - ) -> Result>> { + metrics: &dyn MetricsCollector, + ) -> Result> { let cache_key = format!("{}-ivf-{}", self.uuid, partition_id); - let part_entry = if let Some(part_idx) = self.partition_cache.get(&cache_key) { + let session = self.session.upgrade().ok_or(Error::Internal { + message: "attempt to use index after dataset was destroyed".into(), + location: location!(), + })?; + let part_entry = if let Some(part_idx) = + session.index_cache.get_vector_partition(&cache_key) + { part_idx } else { + info!(target: TRACE_IO_EVENTS, type=IO_TYPE_LOAD_VECTOR_PART, index_type="ivf", part_id=cache_key); + metrics.record_part_load(); if partition_id >= self.ivf.num_partitions() { return Err(Error::Index { message: format!( @@ -229,7 +245,7 @@ impl IVFIndex { // check the cache again, as the partition may have been loaded by another // thread that held the lock on loading the partition - if let Some(part_idx) = self.partition_cache.get(&cache_key) { + if let Some(part_idx) = session.index_cache.get_vector_partition(&cache_key) { part_idx } else { let schema = Arc::new(self.reader.schema().as_ref().into()); @@ -260,13 +276,14 @@ impl IVFIndex { )?; let idx = S::load(batch)?; let storage = self.load_partition_storage(partition_id).await?; - let partition_entry = Arc::new(PartitionEntry { + let partition_entry = Arc::new(PartitionEntry:: { index: idx, storage, }); if write_cache { - self.partition_cache - .insert(cache_key.clone(), partition_entry.clone()); + session + .index_cache + .insert_vector_partition(&cache_key, partition_entry.clone()); } partition_entry @@ -317,6 +334,11 @@ impl Index for IVFIndex Result<()> { + // TODO: We should prewarm the IVF index by loading the partitions into memory + Ok(()) + } + fn index_type(&self) -> IndexType { match self.sub_index_type() { (SubIndexType::Flat, QuantizationType::Flat) => IndexType::IvfFlat, @@ -339,22 +361,55 @@ impl Index for IVFIndex = if let Some(metadata) = self.sub_index_metadata.iter().find(|m| !m.is_empty()) { serde_json::from_str(metadata)? } else { - json!({}) + serde_json::map::Map::new() }; - sub_index_stats["index_type"] = S::name().into(); + let mut store_stats = serde_json::to_value(self.storage.metadata::()?)?; + let store_stats = store_stats.as_object_mut().ok_or(Error::Internal { + message: "failed to get storage metadata".to_string(), + location: location!(), + })?; + + sub_index_stats.append(store_stats); + if S::name() == "FLAT" { + sub_index_stats.insert( + "index_type".to_string(), + Q::quantization_type().to_string().into(), + ); + } else { + sub_index_stats.insert("index_type".to_string(), S::name().into()); + } + + let sub_index_distance_type = if matches!(Q::quantization_type(), QuantizationType::Product) + && self.distance_type == DistanceType::Cosine + { + DistanceType::L2 + } else { + self.distance_type + }; + sub_index_stats.insert( + "metric_type".to_string(), + sub_index_distance_type.to_string().into(), + ); + + // we need to drop some stats from the metadata + sub_index_stats.remove("codebook_position"); + sub_index_stats.remove("codebook"); + sub_index_stats.remove("codebook_tensor"); + Ok(serde_json::to_value(IvfIndexStatistics { index_type, uuid: self.uuid.clone(), - uri: self.uuid.clone(), + uri: self.uri.clone(), metric_type: self.distance_type.to_string(), num_partitions: self.ivf.num_partitions(), - sub_index: sub_index_stats, + sub_index: serde_json::Value::Object(sub_index_stats), partitions: partitions_statistics, centroids: centroid_vecs, + loss: self.ivf.loss(), })?) } @@ -366,10 +421,13 @@ impl Index for IVFIndex VectorIndex - for IVFIndex -{ - async fn search(&self, query: &Query, pre_filter: Arc) -> Result { +impl VectorIndex for IVFIndex { + async fn search( + &self, + query: &Query, + pre_filter: Arc, + metrics: &dyn MetricsCollector, + ) -> Result { let mut query = query.clone(); if self.distance_type == DistanceType::Cosine { let key = normalize_arrow(&query.key)?; @@ -380,7 +438,9 @@ impl>() .await?; @@ -414,39 +474,45 @@ impl) -> Result<()> { - // IvfIndexBuilder::new( - // dataset, - // column, - // index_dir, - // distance_type, - // shuffler, - // ivf_params, - // sub_index_params, - // quantizer_params, - // ) - // } - - #[instrument(level = "debug", skip(self, pre_filter))] + #[instrument(level = "debug", skip(self, pre_filter, metrics))] async fn search_in_partition( &self, partition_id: usize, query: &Query, pre_filter: Arc, + metrics: &dyn MetricsCollector, ) -> Result { - let part_entry = self.load_partition(partition_id, true).await?; + let part_entry = self.load_partition(partition_id, true, metrics).await?; pre_filter.wait_for_ready().await?; let query = self.preprocess_query(partition_id, query)?; - spawn_cpu(move || { + let (batch, local_metrics) = spawn_cpu(move || { let param = (&query).into(); let refine_factor = query.refine_factor.unwrap_or(1) as usize; let k = query.k * refine_factor; - part_entry - .index - .search(query.key, k, param, &part_entry.storage, pre_filter) + let local_metrics = LocalMetricsCollector::default(); + let part = part_entry + .as_any() + .downcast_ref::>() + .ok_or(Error::Internal { + message: "failed to downcast partition entry".to_string(), + location: location!(), + })?; + let batch = part.index.search( + query.key, + k, + param, + &part.storage, + pre_filter, + &local_metrics, + )?; + Ok((batch, local_metrics)) }) - .await + .await?; + + local_metrics.dump_into(metrics); + + Ok(batch) } fn is_loadable(&self) -> bool { @@ -473,25 +539,84 @@ impl Result { + let partition = self.load_partition(partition_id, false, metrics).await?; + let partition = partition + .as_any() + .downcast_ref::>() + .ok_or(Error::Internal { + message: "failed to downcast partition entry".to_string(), + location: location!(), + })?; + let store = &partition.storage; + let schema = if with_vector { + store.schema().clone() + } else { + let schema = store.schema(); + let row_id_idx = schema.index_of(ROW_ID)?; + Arc::new(store.schema().project(&[row_id_idx])?) + }; + + let batches = store + .to_batches()? + .map(|b| { + let batch = b.project_by_schema(&schema)?; + Ok(batch) + }) + .collect::>(); + let stream = RecordBatchStreamAdapter::new(schema, stream::iter(batches)); + Ok(Box::pin(stream)) + } + + async fn to_batch_stream(&self, _with_vector: bool) -> Result { + unimplemented!("this method is for only sub index"); + } + + fn num_rows(&self) -> u64 { + self.storage.num_rows() + } + fn row_ids(&self) -> Box + '_> { todo!("this method is for only IVF_HNSW_* index"); } - fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { - // This will be needed if we want to clean up IVF to allow more than just - // one layer (e.g. IVF -> IVF -> PQ). We need to pass on the call to - // remap to the lower layers. - - // Currently, remapping for IVF is implemented in remap_index_file which - // mirrors some of the other IVF routines like build_ivf_pq_index + async fn remap(&mut self, _mapping: &HashMap>) -> Result<()> { Err(Error::Index { message: "Remapping IVF in this way not supported".to_string(), location: location!(), }) } - fn ivf_model(&self) -> IvfModel { - self.ivf.clone() + async fn remap_to( + self: Arc, + store: ObjectStore, + mapping: &HashMap>, + column: String, + index_dir: Path, + ) -> Result<()> { + match self.sub_index_type() { + (SubIndexType::Flat, _) => { + let mut remapper = + IvfIndexBuilder::::new_remapper(store, column, index_dir, self)?; + remapper.remap(mapping).await + } + _ => Err(Error::Index { + message: format!( + "Remapping is not supported for index type {}", + self.index_type(), + ), + location: location!(), + }), + } + } + + fn ivf_model(&self) -> &IvfModel { + &self.ivf } fn quantizer(&self) -> Quantizer { @@ -516,86 +641,252 @@ pub type IvfHnswPqIndex = IVFIndex; #[cfg(test)] mod tests { use std::collections::HashSet; - use std::{collections::HashMap, ops::Range, sync::Arc}; + use std::{ops::Range, sync::Arc}; - use arrow::datatypes::UInt64Type; + use all_asserts::{assert_ge, assert_lt}; + use arrow::datatypes::{UInt64Type, UInt8Type}; use arrow::{array::AsArray, datatypes::Float32Type}; - use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchIterator}; - use arrow_schema::{DataType, Field, Schema}; + use arrow_array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, FixedSizeListArray, Float32Array, + ListArray, RecordBatch, RecordBatchIterator, UInt64Array, + }; + use arrow_buffer::OffsetBuffer; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use itertools::Itertools; use lance_arrow::FixedSizeListArrayExt; use lance_core::ROW_ID; + use lance_index::metrics::NoOpMetricsCollector; + use lance_index::optimize::OptimizeOptions; use lance_index::vector::hnsw::builder::HnswBuildParams; + use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; use lance_index::vector::sq::builder::SQBuildParams; use lance_index::vector::DIST_COL; use lance_index::{DatasetIndexExt, IndexType}; - use lance_linalg::distance::DistanceType; + use lance_linalg::distance::{multivec_distance, DistanceType}; + use lance_linalg::kernels::normalize_fsl; use lance_testing::datagen::generate_random_array_with_range; + use rand::distributions::uniform::SampleUniform; use rstest::rstest; use tempfile::tempdir; + use crate::dataset::optimize::{compact_files, CompactionOptions}; + use crate::dataset::{UpdateBuilder, WriteParams}; + use crate::index::DatasetIndexInternalExt; use crate::{index::vector::VectorIndexParams, Dataset}; + const NUM_ROWS: usize = 500; const DIM: usize = 32; - async fn generate_test_dataset( + async fn generate_test_dataset( test_uri: &str, - range: Range, - ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range::(1000 * DIM, range); - let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] - .into_iter() - .collect(); - - let schema: Arc<_> = Schema::new(vec![Field::new( - "vector", - DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float32, true)), - DIM as i32, - ), - true, - )]) - .with_metadata(metadata) - .into(); - let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); - let fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap(); - let array = Arc::new(fsl); - let batch = RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap(); - - let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); + range: Range, + ) -> (Dataset, Arc) + where + T::Native: SampleUniform, + { + let (batch, schema) = generate_batch::(NUM_ROWS, None, range, false); + let vectors = batch.column_by_name("vector").unwrap().clone(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let dataset = Dataset::write( + batches, + test_uri, + Some(WriteParams { + mode: crate::dataset::WriteMode::Overwrite, + ..Default::default() + }), + ) + .await + .unwrap(); + (dataset, Arc::new(vectors.as_fixed_size_list().clone())) + } + + async fn generate_multivec_test_dataset( + test_uri: &str, + range: Range, + ) -> (Dataset, Arc) + where + T::Native: SampleUniform, + { + let (batch, schema) = generate_batch::(NUM_ROWS, None, range, true); + let vectors = batch.column_by_name("vector").unwrap().clone(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); let dataset = Dataset::write(batches, test_uri, None).await.unwrap(); - (dataset, array) + (dataset, Arc::new(vectors.as_list::().clone())) + } + + async fn append_dataset( + dataset: &mut Dataset, + num_rows: usize, + range: Range, + ) -> ArrayRef + where + T::Native: SampleUniform, + { + let is_multivector = matches!( + dataset.schema().field("vector").unwrap().data_type(), + DataType::List(_) + ); + let row_count = dataset.count_all_rows().await.unwrap(); + let (batch, schema) = + generate_batch::(num_rows, Some(row_count as u64), range, is_multivector); + let vectors = batch["vector"].clone(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + dataset.append(batches, None).await.unwrap(); + vectors + } + + fn generate_batch( + num_rows: usize, + start_id: Option, + range: Range, + is_multivector: bool, + ) -> (RecordBatch, SchemaRef) + where + T::Native: SampleUniform, + { + const VECTOR_NUM_PER_ROW: usize = 3; + let start_id = start_id.unwrap_or(0); + let ids = Arc::new(UInt64Array::from_iter_values( + start_id..start_id + num_rows as u64, + )); + let total_floats = match is_multivector { + true => num_rows * VECTOR_NUM_PER_ROW * DIM, + false => num_rows * DIM, + }; + let vectors = generate_random_array_with_range::(total_floats, range); + let data_type = vectors.data_type().clone(); + let mut fields = vec![Field::new("id", DataType::UInt64, false)]; + let mut arrays: Vec = vec![ids]; + let mut fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + if fsl.value_type() != DataType::UInt8 { + fsl = normalize_fsl(&fsl).unwrap(); + } + if is_multivector { + let vector_field = Arc::new(Field::new( + "item", + DataType::FixedSizeList(Arc::new(Field::new("item", data_type, true)), DIM as i32), + true, + )); + fields.push(Field::new( + "vector", + DataType::List(vector_field.clone()), + true, + )); + let array = Arc::new(ListArray::new( + vector_field, + OffsetBuffer::from_lengths(std::iter::repeat_n(VECTOR_NUM_PER_ROW, num_rows)), + Arc::new(fsl), + None, + )); + arrays.push(array); + } else { + fields.push(Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", data_type, true)), DIM as i32), + true, + )); + let array = Arc::new(fsl); + arrays.push(array); + } + let schema: Arc<_> = Schema::new(fields).into(); + let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); + (batch, schema) } #[allow(dead_code)] - fn ground_truth( - vectors: &FixedSizeListArray, - query: &[f32], + async fn ground_truth( + dataset: &Dataset, + column: &str, + query: &dyn Array, + k: usize, + distance_type: DistanceType, + ) -> HashSet { + let batch = dataset + .scan() + .with_row_id() + .nearest(column, query, k) + .unwrap() + .distance_metric(distance_type) + .use_index(false) + .try_into_batch() + .await + .unwrap(); + batch[ROW_ID] + .as_primitive::() + .values() + .iter() + .copied() + .collect() + } + + #[allow(dead_code)] + fn multivec_ground_truth( + vectors: &ListArray, + query: &dyn Array, k: usize, distance_type: DistanceType, ) -> Vec<(f32, u64)> { - let mut dists = vec![]; - for i in 0..vectors.len() { - let dist = distance_type.func()( - query, - vectors.value(i).as_primitive::().values(), - ); - dists.push((dist, i as u64)); + let query = if let Some(list_array) = query.as_list_opt::() { + list_array.values().clone() + } else { + query.as_fixed_size_list().values().clone() + }; + multivec_distance(&query, vectors, distance_type) + .unwrap() + .into_iter() + .enumerate() + .map(|(i, dist)| (dist, i as u64)) + .sorted_by(|a, b| a.0.total_cmp(&b.0)) + .take(k) + .collect() + } + + async fn test_index( + params: VectorIndexParams, + nlist: usize, + recall_requirement: f32, + dataset: Option<(Dataset, Arc)>, + ) { + match params.metric_type { + DistanceType::Hamming => { + test_index_impl::(params, nlist, recall_requirement, 0..4, dataset) + .await; + } + _ => { + test_index_impl::( + params, + nlist, + recall_requirement, + 0.0..1.0, + dataset, + ) + .await; + } } - dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); - dists.truncate(k); - dists } - async fn test_index(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { + async fn test_index_impl( + params: VectorIndexParams, + nlist: usize, + recall_requirement: f32, + range: Range, + dataset: Option<(Dataset, Arc)>, + ) where + T::Native: SampleUniform, + { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, vectors) = match dataset { + Some((dataset, vectors)) => (dataset, vectors), + None => generate_test_dataset::(test_uri, range).await, + }; + let vector_column = "vector"; dataset - .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) .await .unwrap(); @@ -603,7 +894,7 @@ mod tests { let k = 100; let result = dataset .scan() - .nearest("vector", query.as_primitive::(), k) + .nearest(vector_column, query.as_primitive::(), k) .unwrap() .nprobs(nlist) .with_row_id() @@ -624,16 +915,11 @@ mod tests { .zip(row_ids.into_iter()) .collect::>(); let row_ids = results.iter().map(|(_, id)| *id).collect::>(); + assert!(row_ids.len() == k); - let gt = ground_truth( - &vectors, - query.as_primitive::().values(), - k, - params.metric_type, - ); - let gt_set = gt.iter().map(|r| r.1).collect::>(); + let gt = ground_truth(&dataset, vector_column, &query, k, params.metric_type).await; - let recall = row_ids.intersection(>_set).count() as f32 / k as f32; + let recall = row_ids.intersection(>).count() as f32 / k as f32; assert!( recall >= recall_requirement, "recall: {}\n results: {:?}\n\ngt: {:?}", @@ -643,10 +929,254 @@ mod tests { ); } + async fn test_remap(params: VectorIndexParams, nlist: usize) { + match params.metric_type { + DistanceType::Hamming => { + test_remap_impl::(params, nlist, 0..4).await; + } + _ => { + test_remap_impl::(params, nlist, 0.0..1.0).await; + } + } + } + + async fn test_remap_impl( + params: VectorIndexParams, + nlist: usize, + range: Range, + ) where + T::Native: SampleUniform, + { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (mut dataset, vectors) = generate_test_dataset::(test_uri, range.clone()).await; + + let vector_column = "vector"; + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + let query = vectors.value(0); + // delete half rows to trigger compact + let half_rows = NUM_ROWS / 2; + dataset + .delete(&format!("id < {}", half_rows)) + .await + .unwrap(); + // update the other half rows + let update_result = UpdateBuilder::new(Arc::new(dataset)) + .update_where(&format!("id >= {} and id<{}", half_rows, half_rows + 50)) + .unwrap() + .set("id", &format!("{}+id", NUM_ROWS)) + .unwrap() + .build() + .unwrap() + .execute() + .await + .unwrap(); + let mut dataset = Dataset::open(update_result.new_dataset.uri()) + .await + .unwrap(); + let num_rows = dataset.count_rows(None).await.unwrap(); + assert_eq!(num_rows, half_rows); + compact_files(&mut dataset, CompactionOptions::default(), None) + .await + .unwrap(); + // query again, the result should not include the deleted row + let result = dataset.scan().try_into_batch().await.unwrap(); + let ids = result["id"].as_primitive::(); + assert_eq!(ids.len(), half_rows); + ids.values().iter().for_each(|id| { + assert!(*id >= half_rows as u64 + 50); + }); + + // make sure we can still hit the recall + let gt = ground_truth(&dataset, vector_column, &query, 100, params.metric_type).await; + let results = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), 100) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + let row_ids = results[ROW_ID] + .as_primitive::() + .values() + .iter() + .copied() + .collect::>(); + let recall = row_ids.intersection(>).count() as f32 / 100.0; + assert_ge!(recall, 0.8, "{}", recall); + + // delete so that only one row left, to trigger remap and there must be some empty partitions + let (mut dataset, _) = generate_test_dataset::(test_uri, range).await; + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + assert_eq!(dataset.load_indices().await.unwrap().len(), 1); + dataset.delete("id > 0").await.unwrap(); + assert_eq!(dataset.count_rows(None).await.unwrap(), 1); + assert_eq!(dataset.load_indices().await.unwrap().len(), 1); + compact_files(&mut dataset, CompactionOptions::default(), None) + .await + .unwrap(); + let results = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), 100) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + assert_eq!(results.num_rows(), 1); + } + + async fn test_optimize_strategy(params: VectorIndexParams) { + match params.metric_type { + DistanceType::Hamming => { + test_optimize_strategy_impl::(params, 0..4).await; + } + _ => { + test_optimize_strategy_impl::(params, 0.0..1.0).await; + } + } + } + + async fn test_optimize_strategy_impl( + params: VectorIndexParams, + range: Range, + ) where + T::Native: SampleUniform, + { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (mut dataset, _) = generate_test_dataset::(test_uri, range.clone()).await; + + let vector_column = "vector"; + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + async fn get_ivf_models(dataset: &Dataset) -> Vec { + let indices = dataset.load_indices_by_name("vector_idx").await.unwrap(); + let mut ivf_models = vec![]; + for idx in indices { + let index = dataset + .open_vector_index( + "vector", + idx.uuid.to_string().as_str(), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + ivf_models.push(index.ivf_model().clone()); + } + ivf_models + } + + async fn get_losses(dataset: &Dataset) -> Vec> { + let stats = dataset.index_statistics("vector_idx").await.unwrap(); + let stats: serde_json::Value = serde_json::from_str(&stats).unwrap(); + stats["indices"] + .as_array() + .unwrap() + .iter() + .flat_map(|s| s.get("loss").map(|l| l.as_f64())) + .collect() + } + + async fn get_avg_loss(dataset: &Dataset) -> f64 { + let losses = get_losses(dataset).await; + let total_loss = losses.iter().filter_map(|l| *l).sum::(); + let num_rows = dataset.count_rows(None).await.unwrap(); + total_loss / num_rows as f64 + } + + const AVG_LOSS_RETRAIN_THRESHOLD: f64 = 1.1; + let original_ivfs = get_ivf_models(&dataset).await; + let original_avg_loss = get_avg_loss(&dataset).await; + let original_ivf = &original_ivfs[0]; + let mut count = 0; + #[allow(unused_assignments)] + let mut last_avg_loss = original_avg_loss; + // append more rows and make delta index until hitting the retrain threshold + loop { + let range = match count { + 0 => range.clone(), + _ => match params.metric_type { + DistanceType::Hamming => range.start..range.end.add_wrapping(range.end), + _ => range.end.neg_wrapping()..range.start, + }, + }; + append_dataset::(&mut dataset, NUM_ROWS / 5, range).await; + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .unwrap(); + count += 1; + + last_avg_loss = get_avg_loss(&dataset).await; + if last_avg_loss / original_avg_loss >= AVG_LOSS_RETRAIN_THRESHOLD { + if count <= 1 { + // the first append is with the same data distribution, so the loss should be + // very close to the original loss, then it shouldn't hit the retrain threshold + panic!( + "retrain threshold {} should not be hit", + AVG_LOSS_RETRAIN_THRESHOLD + ); + } + + break; + } + if count >= 10 { + panic!( + "failed to hit the retrain threshold {} < {}", + last_avg_loss / original_avg_loss, + AVG_LOSS_RETRAIN_THRESHOLD + ); + } + + // all delta indices should have the same centroids as the original index + let ivf_models = get_ivf_models(&dataset).await; + assert_eq!(ivf_models.len(), count + 1); + for ivf in ivf_models { + assert_eq!(original_ivf.centroids, ivf.centroids); + } + } + + // this optimize would merge all indices and retrain the IVF + dataset + .optimize_indices(&OptimizeOptions::retrain()) + .await + .unwrap(); + let stats = dataset.index_statistics("vector_idx").await.unwrap(); + let stats: serde_json::Value = serde_json::from_str(&stats).unwrap(); + assert_eq!(stats["num_indices"], 1); + + let ivf_models = get_ivf_models(&dataset).await; + let ivf = &ivf_models[0]; + assert_ne!(original_ivf.centroids, ivf.centroids); + if ivf.num_partitions() > 1 && params.metric_type != DistanceType::Hamming { + assert_lt!(get_avg_loss(&dataset).await, last_avg_loss); + } + } + + #[tokio::test] + async fn test_flat_knn() { + test_distance_range(None, 4).await; + } + #[rstest] #[case(4, DistanceType::L2, 1.0)] #[case(4, DistanceType::Cosine, 1.0)] #[case(4, DistanceType::Dot, 1.0)] + #[case(4, DistanceType::Hamming, 0.9)] #[tokio::test] async fn test_build_ivf_flat( #[case] nlist: usize, @@ -654,13 +1184,19 @@ mod tests { #[case] recall_requirement: f32, ) { let params = VectorIndexParams::ivf_flat(nlist, distance_type); - test_index(params, nlist, recall_requirement).await; + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_distance_range(Some(params.clone()), nlist).await; + test_remap(params.clone(), nlist).await; + test_optimize_strategy(params).await; } #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_build_ivf_pq( #[case] nlist: usize, @@ -669,14 +1205,24 @@ mod tests { ) { let ivf_params = IvfBuildParams::new(nlist); let pq_params = PQBuildParams::default(); - let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params); - test_index(params, nlist, recall_requirement).await; + let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params) + .version(crate::index::vector::IndexFileVersion::Legacy) + .clone(); + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_distance_range(Some(params.clone()), nlist).await; + test_remap(params, nlist).await; } #[rstest] + #[case(1, DistanceType::L2, 0.9)] + #[case(1, DistanceType::Cosine, 0.9)] + #[case(1, DistanceType::Dot, 0.85)] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_build_ivf_pq_v3( #[case] nlist: usize, @@ -685,16 +1231,20 @@ mod tests { ) { let ivf_params = IvfBuildParams::new(nlist); let pq_params = PQBuildParams::default(); - let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params) - .version(crate::index::vector::IndexFileVersion::V3) - .clone(); - test_index(params, nlist, recall_requirement).await; + let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params); + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_distance_range(Some(params.clone()), nlist).await; + test_remap(params.clone(), nlist).await; + test_optimize_strategy(params).await; } #[rstest] - #[case(4, DistanceType::L2, 0.9)] - #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.8)] + #[case(4, DistanceType::L2, 0.85)] + #[case(4, DistanceType::Cosine, 0.85)] + #[case(4, DistanceType::Dot, 0.75)] #[tokio::test] async fn test_build_ivf_pq_4bit( #[case] nlist: usize, @@ -703,16 +1253,19 @@ mod tests { ) { let ivf_params = IvfBuildParams::new(nlist); let pq_params = PQBuildParams::new(32, 4); - let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params) - .version(crate::index::vector::IndexFileVersion::V3) - .clone(); - test_index(params, nlist, recall_requirement).await; + let params = VectorIndexParams::with_ivf_pq_params(distance_type, ivf_params, pq_params); + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_remap(params.clone(), nlist).await; + test_optimize_strategy(params).await; } #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_create_ivf_hnsw_sq( #[case] nlist: usize, @@ -728,13 +1281,17 @@ mod tests { hnsw_params, sq_params, ); - test_index(params, nlist, recall_requirement).await; + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_optimize_strategy(params).await; } #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] - #[case(4, DistanceType::Dot, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] #[tokio::test] async fn test_create_ivf_hnsw_pq( #[case] nlist: usize, @@ -750,12 +1307,16 @@ mod tests { hnsw_params, pq_params, ); - test_index(params, nlist, recall_requirement).await; + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_optimize_strategy(params).await; } #[rstest] - #[case(4, DistanceType::L2, 0.9)] - #[case(4, DistanceType::Cosine, 0.9)] + #[case(4, DistanceType::L2, 0.85)] + #[case(4, DistanceType::Cosine, 0.85)] #[case(4, DistanceType::Dot, 0.8)] #[tokio::test] async fn test_create_ivf_hnsw_pq_4bit( @@ -772,27 +1333,173 @@ mod tests { hnsw_params, pq_params, ); - test_index(params, nlist, recall_requirement).await; + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_optimize_strategy(params).await; } - #[tokio::test] - async fn test_index_stats() { + async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { + // we introduce XTR for performance, which would reduce the recall a little bit + let recall_requirement = recall_requirement * 0.9; + match params.metric_type { + DistanceType::Hamming => { + test_index_multivec_impl::(params, nlist, recall_requirement, 0..4) + .await; + } + _ => { + test_index_multivec_impl::( + params, + nlist, + recall_requirement, + 0.0..1.0, + ) + .await; + } + } + } + + async fn test_index_multivec_impl( + params: VectorIndexParams, + nlist: usize, + recall_requirement: f32, + range: Range, + ) where + T::Native: SampleUniform, + { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let nlist = 4; - let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, vectors) = generate_multivec_test_dataset::(test_uri, range).await; - let ivf_params = IvfBuildParams::new(nlist); - let sq_params = SQBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let params = VectorIndexParams::with_ivf_hnsw_sq_params( - DistanceType::L2, - ivf_params, - hnsw_params, - sq_params, + dataset + .create_index( + &["vector"], + IndexType::Vector, + Some("test_index".to_owned()), + ¶ms, + true, + ) + .await + .unwrap(); + + let query = vectors.value(0); + let k = 100; + + let result = dataset + .scan() + .nearest("vector", &query, k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + let row_ids = result[ROW_ID] + .as_primitive::() + .values() + .to_vec(); + let dists = result[DIST_COL] + .as_primitive::() + .values() + .to_vec(); + let results = dists + .into_iter() + .zip(row_ids.clone().into_iter()) + .collect::>(); + let row_ids = row_ids.into_iter().collect::>(); + + let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type); + let gt_set = gt.iter().map(|r| r.1).collect::>(); + + let recall = row_ids.intersection(>_set).count() as f32 / 100.0; + assert!( + recall >= recall_requirement, + "recall: {}\n results: {:?}\n\ngt: {:?}", + recall, + results, + gt ); + } + #[rstest] + #[tokio::test] + async fn test_migrate_v1_to_v3() { + // only test the case of IVF_PQ + // because only IVF_PQ is supported in v1 + let nlist = 4; + let recall_requirement = 0.9; + let ivf_params = IvfBuildParams::new(nlist); + let pq_params = PQBuildParams::default(); + let v1_params = + VectorIndexParams::with_ivf_pq_params(DistanceType::Cosine, ivf_params, pq_params) + .version(crate::index::vector::IndexFileVersion::Legacy) + .clone(); + + let v3_params = v1_params + .clone() + .version(crate::index::vector::IndexFileVersion::V3) + .clone(); + + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (mut dataset, vectors) = generate_test_dataset::(test_uri, 0.0..1.0).await; + test_index( + v1_params, + nlist, + recall_requirement, + Some((dataset.clone(), vectors.clone())), + ) + .await; + // retest with v3 params on the same dataset + test_index( + v3_params, + nlist, + recall_requirement, + Some((dataset.clone(), vectors)), + ) + .await; + + dataset.checkout_latest().await.unwrap(); + let indices = dataset.load_indices_by_name("vector_idx").await.unwrap(); + assert_eq!(indices.len(), 1); // v1 index should be replaced by v3 index + let index = dataset + .open_vector_index( + "vector", + indices[0].uuid.to_string().as_str(), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let v3_index = index.as_any().downcast_ref::(); + assert!(v3_index.is_some()); + } + + #[rstest] + #[tokio::test] + async fn test_index_stats( + #[values( + (VectorIndexParams::ivf_flat(4, DistanceType::Hamming), IndexType::IvfFlat), + (VectorIndexParams::ivf_pq(4, 8, 8, DistanceType::L2, 10), IndexType::IvfPq), + (VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::Cosine, + IvfBuildParams::new(4), + Default::default(), + Default::default() + ), IndexType::IvfHnswSq), + )] + index: (VectorIndexParams, IndexType), + ) { + let (params, index_type) = index; + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let nlist = 4; + let (mut dataset, _) = match params.metric_type { + DistanceType::Hamming => generate_test_dataset::(test_uri, 0..2).await, + _ => generate_test_dataset::(test_uri, 0.0..1.0).await, + }; dataset .create_index( &["vector"], @@ -807,14 +1514,29 @@ mod tests { let stats = dataset.index_statistics("test_index").await.unwrap(); let stats: serde_json::Value = serde_json::from_str(stats.as_str()).unwrap(); - assert_eq!(stats["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + assert_eq!( + stats["index_type"].as_str().unwrap(), + index_type.to_string() + ); for index in stats["indices"].as_array().unwrap() { - assert_eq!(index["index_type"].as_str().unwrap(), "IVF_HNSW_SQ"); + assert_eq!( + index["index_type"].as_str().unwrap(), + index_type.to_string() + ); assert_eq!( index["num_partitions"].as_number().unwrap(), &serde_json::Number::from(nlist) ); - assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); + + let sub_index = match index_type { + IndexType::IvfHnswPq | IndexType::IvfHnswSq => "HNSW", + IndexType::IvfPq => "PQ", + _ => "FLAT", + }; + assert_eq!( + index["sub_index"]["index_type"].as_str().unwrap(), + sub_index + ); } } @@ -823,8 +1545,8 @@ mod tests { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let nlist = 1000; - let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + let nlist = 500; + let (mut dataset, _) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(nlist); let sq_params = SQBuildParams::default(); @@ -860,4 +1582,167 @@ mod tests { assert_eq!(index["sub_index"]["index_type"].as_str().unwrap(), "HNSW"); } } + + async fn test_distance_range(params: Option, nlist: usize) { + match params.as_ref().map_or(DistanceType::L2, |p| p.metric_type) { + DistanceType::Hamming => { + test_distance_range_impl::(params, nlist, 0..255).await; + } + _ => { + test_distance_range_impl::(params, nlist, 0.0..1.0).await; + } + } + } + + async fn test_distance_range_impl( + params: Option, + nlist: usize, + range: Range, + ) where + T::Native: SampleUniform, + { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (mut dataset, vectors) = generate_test_dataset::(test_uri, range).await; + + let vector_column = "vector"; + let dist_type = params.as_ref().map_or(DistanceType::L2, |p| p.metric_type); + if let Some(params) = params { + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + } + + let query = vectors.value(0); + let k = 10; + let result = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + assert_eq!(result.num_rows(), k); + let row_ids = result[ROW_ID].as_primitive::().values(); + let dists = result[DIST_COL].as_primitive::().values(); + + let part_idx = k / 2; + let part_dist = dists[part_idx]; + + let left_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), part_idx) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(None, Some(part_dist)) + .try_into_batch() + .await + .unwrap(); + let right_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k - part_idx) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(Some(part_dist), None) + .try_into_batch() + .await + .unwrap(); + // don't verify the number of results and row ids for hamming distance, + // because there are many vectors with the same distance + if dist_type != DistanceType::Hamming { + assert_eq!(left_res.num_rows(), part_idx); + assert_eq!(right_res.num_rows(), k - part_idx); + let left_row_ids = left_res[ROW_ID].as_primitive::().values(); + let right_row_ids = right_res[ROW_ID].as_primitive::().values(); + row_ids.iter().enumerate().for_each(|(i, id)| { + if i < part_idx { + assert_eq!(left_row_ids[i], *id); + } else { + assert_eq!(right_row_ids[i - part_idx], *id, "{:?}", right_row_ids); + } + }); + } + let left_dists = left_res[DIST_COL].as_primitive::().values(); + let right_dists = right_res[DIST_COL].as_primitive::().values(); + left_dists.iter().for_each(|d| { + assert!(d < &part_dist); + }); + right_dists.iter().for_each(|d| { + assert!(d >= &part_dist); + }); + + let exclude_last_res = dataset + .scan() + .nearest(vector_column, query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .distance_range(dists.first().copied(), dists.last().copied()) + .try_into_batch() + .await + .unwrap(); + if dist_type != DistanceType::Hamming { + assert_eq!(exclude_last_res.num_rows(), k - 1); + let res_row_ids = exclude_last_res[ROW_ID] + .as_primitive::() + .values(); + row_ids.iter().enumerate().for_each(|(i, id)| { + if i < k - 1 { + assert_eq!(res_row_ids[i], *id); + } + }); + } + let res_dists = exclude_last_res[DIST_COL] + .as_primitive::() + .values(); + res_dists.iter().for_each(|d| { + assert_ge!(*d, dists[0]); + assert_lt!(*d, dists[k - 1]); + }); + } + + #[tokio::test] + async fn test_index_with_zero_vectors() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let (batch, schema) = generate_batch::(256, None, 0.0..1.0, false); + let vector_field = schema.field(1).clone(); + let zero_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt64Array::from(vec![256])), + Arc::new( + FixedSizeListArray::try_new_from_values( + Float32Array::from(vec![0.0; DIM]), + DIM as i32, + ) + .unwrap(), + ), + ], + ) + .unwrap(); + let batches = RecordBatchIterator::new(vec![batch, zero_batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write( + batches, + test_uri, + Some(WriteParams { + mode: crate::dataset::WriteMode::Overwrite, + ..Default::default() + }), + ) + .await + .unwrap(); + + let vector_column = vector_field.name(); + let params = VectorIndexParams::ivf_pq(4, 8, DIM / 8, DistanceType::Cosine, 50); + dataset + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + } } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 91973d4a350..20bccbf93c1 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -5,25 +5,28 @@ use std::sync::Arc; use std::{any::Any, collections::HashMap}; use arrow::compute::concat; -use arrow_array::UInt32Array; use arrow_array::{ cast::{as_primitive_array, AsArray}, Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, }; +use arrow_array::{ArrayRef, Float32Array, UInt32Array}; use arrow_ord::sort::sort_to_indices; -use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; +use arrow_schema::{DataType, Field, Schema}; use arrow_select::take::take; use async_trait::async_trait; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use deepsize::DeepSizeOf; +use lance_core::utils::address::RowAddress; use lance_core::utils::tokio::spawn_cpu; -use lance_core::ROW_ID; -use lance_core::{utils::address::RowAddress, ROW_ID_FIELD}; +use lance_core::{ROW_ID, ROW_ID_FIELD}; +use lance_index::metrics::MetricsCollector; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::storage::{transpose, ProductQuantizationStorage}; use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer}; use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ - vector::{pq::ProductQuantizer, Query, DIST_COL}, + vector::{pq::ProductQuantizer, Query}, Index, IndexType, }; use lance_io::{traits::Reader, utils::read_fixed_stride_array}; @@ -31,7 +34,7 @@ use lance_linalg::distance::{DistanceType, MetricType}; use log::{info, warn}; use roaring::RoaringBitmap; use serde_json::json; -use snafu::{location, Location}; +use snafu::location; use tracing::{instrument, span, Level}; // Re-export @@ -41,6 +44,7 @@ use lance_linalg::kernels::normalize_fsl; use super::VectorIndex; use crate::index::prefilter::PreFilter; use crate::index::vector::utils::maybe_sample_training_data; +use crate::io::exec::knn::KNN_INDEX_SCHEMA; use crate::{arrow::*, Dataset}; use crate::{Error, Result}; @@ -164,6 +168,11 @@ impl Index for PQIndex { IndexType::Vector } + async fn prewarm(&self) -> Result<()> { + // TODO: Investigate + Ok(()) + } + fn statistics(&self) -> Result { Ok(json!({ "index_type": "PQ", @@ -198,7 +207,12 @@ impl VectorIndex for PQIndex { /// Search top-k nearest neighbors for `key` within one PQ partition. /// #[instrument(level = "debug", skip_all, name = "PQIndex::search")] - async fn search(&self, query: &Query, pre_filter: Arc) -> Result { + async fn search( + &self, + query: &Query, + pre_filter: Arc, + metrics: &dyn MetricsCollector, + ) -> Result { if self.code.is_none() || self.row_ids.is_none() { return Err(Error::Index { message: "PQIndex::search: PQ is not initialized".to_string(), @@ -210,6 +224,8 @@ impl VectorIndex for PQIndex { let code = self.code.as_ref().unwrap().clone(); let row_ids = self.row_ids.as_ref().unwrap().clone(); + metrics.record_comparisons(row_ids.len()); + let pq = self.pq.clone(); let query = query.clone(); let num_sub_vectors = self.pq.code_dim() as i32; @@ -226,15 +242,42 @@ impl VectorIndex for PQIndex { debug_assert_eq!(distances.len(), row_ids.len()); let limit = query.k * query.refine_factor.unwrap_or(1) as usize; - let indices = sort_to_indices(&distances, None, Some(limit))?; - let distances = take(&distances, &indices, None)?; - let row_ids = take(row_ids.as_ref(), &indices, None)?; - - let schema = Arc::new(ArrowSchema::new(vec![ - ArrowField::new(DIST_COL, DataType::Float32, true), - ROW_ID_FIELD.clone(), - ])); - Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?) + if query.lower_bound.is_none() && query.upper_bound.is_none() { + let indices = sort_to_indices(&distances, None, Some(limit))?; + let distances = take(&distances, &indices, None)?; + let row_ids = take(row_ids.as_ref(), &indices, None)?; + Ok(RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![distances, row_ids], + )?) + } else { + let indices = sort_to_indices(&distances, None, None)?; + let mut dists = Vec::with_capacity(limit); + let mut ids = Vec::with_capacity(limit); + for idx in indices.values().iter() { + let dist = distances.value(*idx as usize); + let id = row_ids.value(*idx as usize); + if query.lower_bound.is_some_and(|lb| dist < lb) { + continue; + } + if query.upper_bound.is_some_and(|ub| dist >= ub) { + break; + } + + dists.push(dist); + ids.push(id); + + if dists.len() >= limit { + break; + } + } + let dists = Arc::new(Float32Array::from(dists)); + let ids = Arc::new(UInt64Array::from(ids)); + Ok(RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![dists, ids], + )?) + } }) .await } @@ -248,6 +291,7 @@ impl VectorIndex for PQIndex { _: usize, _: &Query, _: Arc, + _: &dyn MetricsCollector, ) -> Result { unimplemented!("only for IVF") } @@ -301,6 +345,49 @@ impl VectorIndex for PQIndex { })) } + async fn to_batch_stream(&self, with_vector: bool) -> Result { + let row_ids = self.row_ids.clone().ok_or(Error::Index { + message: "PQIndex::to_batch_stream: row ids not loaded for PQ".to_string(), + location: location!(), + })?; + + let num_rows = row_ids.len(); + let mut fields = vec![ROW_ID_FIELD.clone()]; + let mut columns: Vec = vec![row_ids]; + if with_vector { + let transposed_codes = self.code.clone().ok_or(Error::Index { + message: "PQIndex::to_batch_stream: PQ codes not loaded for PQ".to_string(), + location: location!(), + })?; + let original_codes = transpose(&transposed_codes, self.pq.num_sub_vectors, num_rows); + fields.push(Field::new( + self.pq.column(), + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::UInt8, true)), + self.pq.code_dim() as i32, + ), + true, + )); + columns.push(Arc::new(FixedSizeListArray::try_new_from_values( + original_codes, + self.pq.code_dim() as i32, + )?)); + } + + let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?; + let stream = RecordBatchStreamAdapter::new( + batch.schema(), + futures::stream::once(futures::future::ready(Ok(batch))), + ); + Ok(Box::pin(stream)) + } + + fn num_rows(&self) -> u64 { + self.row_ids + .as_ref() + .map_or(0, |row_ids| row_ids.len() as u64) + } + fn row_ids(&self) -> Box> { todo!("this method is for only IVF_HNSW_* index"); } @@ -309,7 +396,7 @@ impl VectorIndex for PQIndex { Ok(()) } - fn remap(&mut self, mapping: &HashMap>) -> Result<()> { + async fn remap(&mut self, mapping: &HashMap>) -> Result<()> { let num_vectors = self.row_ids.as_ref().unwrap().len(); let row_ids = self.row_ids.as_ref().unwrap().values().iter(); let transposed_codes = self.code.as_ref().unwrap(); @@ -343,7 +430,7 @@ impl VectorIndex for PQIndex { Ok(()) } - fn ivf_model(&self) -> IvfModel { + fn ivf_model(&self) -> &IvfModel { unimplemented!("only for IVF") } fn quantizer(&self) -> Quantizer { @@ -419,10 +506,11 @@ pub async fn build_pq_model( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() ); + assert_eq!(training_data.logical_null_count(), 0); info!( "starting to compute partitions for PQ training, sample size: {}", - training_data.value_length() + training_data.len() ); if metric_type == MetricType::Cosine { @@ -542,7 +630,7 @@ mod tests { let centroids = generate_random_array_with_range::(4 * DIM, -1.0..1.0); let fsl = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap(); - let ivf = IvfModel::new(fsl); + let ivf = IvfModel::new(fsl, None); let params = PQBuildParams::new(16, 8); let pq = build_pq_model(&dataset, "vector", DIM, MetricType::L2, ¶ms, Some(&ivf)) .await diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 661877ed539..660e7415be2 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -4,32 +4,102 @@ use std::sync::Arc; use arrow_array::{cast::AsArray, FixedSizeListArray}; -use arrow_schema::Schema as ArrowSchema; -use arrow_select::concat::concat_batches; -use futures::stream::TryStreamExt; -use snafu::{location, Location}; +use futures::StreamExt; +use lance_arrow::{interleave_batches, DataTypeExt}; +use lance_core::datatypes::Schema; +use log::info; +use rand::rngs::SmallRng; +use rand::seq::{IteratorRandom, SliceRandom}; +use rand::SeedableRng; +use snafu::location; use tokio::sync::Mutex; use crate::dataset::Dataset; use crate::{Error, Result}; -pub fn get_vector_dim(dataset: &Dataset, column: &str) -> Result { - let schema = dataset.schema(); +/// Get the vector dimension of the given column in the schema. +pub fn get_vector_dim(schema: &Schema, column: &str) -> Result { let field = schema.field(column).ok_or(Error::Index { message: format!("Column {} does not exist in schema {}", column, schema), location: location!(), })?; - let data_type = field.data_type(); - if let arrow_schema::DataType::FixedSizeList(_, dim) = data_type { - Ok(dim as usize) - } else { - Err(Error::Index { + infer_vector_dim(&field.data_type()) +} + +/// Infer the vector dimension from the given data type. +pub fn infer_vector_dim(data_type: &arrow::datatypes::DataType) -> Result { + infer_vector_dim_impl(data_type, false) +} + +fn infer_vector_dim_impl(data_type: &arrow::datatypes::DataType, in_list: bool) -> Result { + match (data_type,in_list) { + (arrow::datatypes::DataType::FixedSizeList(_, dim),_) => Ok(*dim as usize), + (arrow::datatypes::DataType::List(inner), false) => infer_vector_dim_impl(inner.data_type(),true), + _ => Err(Error::Index { + message: format!("Data type is not a vector (FixedSizeListArray or List), but {:?}", data_type), + location: location!(), + }), + } +} + +/// Checks whether the given column is with a valid vector type +/// returns the vector type (FixedSizeList for vectors, or List for multivectors), +/// and element type (Float16/Float32/Float64 or UInt8 for binary vectors). +pub fn get_vector_type( + schema: &Schema, + column: &str, +) -> Result<(arrow_schema::DataType, arrow_schema::DataType)> { + let field = schema.field(column).ok_or(Error::Index { + message: format!("column {} does not exist in schema {}", column, schema), + location: location!(), + })?; + Ok(( + field.data_type(), + infer_vector_element_type(&field.data_type())?, + )) +} + +/// If the data type is a fixed size list or list of fixed size list return the inner element type +/// and verify it is a type we can create a vector index on. +/// +/// Return an error if the data type is any other type +pub fn infer_vector_element_type( + data_type: &arrow::datatypes::DataType, +) -> Result { + infer_vector_element_type_impl(data_type, false) +} + +fn infer_vector_element_type_impl( + data_type: &arrow::datatypes::DataType, + in_list: bool, +) -> Result { + match (data_type, in_list) { + (arrow::datatypes::DataType::FixedSizeList(element_field, _), _) => { + match element_field.data_type() { + arrow::datatypes::DataType::Float16 + | arrow::datatypes::DataType::Float32 + | arrow::datatypes::DataType::Float64 + | arrow::datatypes::DataType::UInt8 + | arrow::datatypes::DataType::Int8 => Ok(element_field.data_type().clone()), + _ => Err(Error::Index { + message: format!( + "vector element is not expected type (Float16/Float32/Float64 or UInt8): {:?}", + element_field.data_type() + ), + location: location!(), + }), + } + } + (arrow::datatypes::DataType::List(inner), false) => { + infer_vector_element_type_impl(inner.data_type(), true) + } + _ => Err(Error::Index { message: format!( - "Column {} is not a FixedSizeListArray, but {:?}", - column, data_type + "Data type is not a vector (FixedSizeListArray or List), but {:?}", + data_type ), location: location!(), - }) + }), } } @@ -43,18 +113,99 @@ pub async fn maybe_sample_training_data( sample_size_hint: usize, ) -> Result { let num_rows = dataset.count_rows(None).await?; - let projection = dataset.schema().project(&[column])?; - let batch = if num_rows > sample_size_hint { - dataset.sample(sample_size_hint, &projection).await? + + let vector_field = dataset.schema().field(column).ok_or(Error::Index { + message: format!( + "Sample training data: column {} does not exist in schema", + column + ), + location: location!(), + })?; + let is_nullable = vector_field.nullable; + + let batch = if num_rows > sample_size_hint && !is_nullable { + let projection = dataset.schema().project(&[column])?; + let batch = dataset.sample(sample_size_hint, &projection).await?; + info!( + "Sample training data: retrieved {} rows by sampling", + batch.num_rows() + ); + batch + } else if num_rows > sample_size_hint && is_nullable { + // Use min block size + vector size to determine sample granularity + // For example, on object storage, block size is 64 KB. A 768-dim 32-bit + // vector is 3 KB. So we can sample every 64 KB / 3 KB = 21 vectors. + let block_size = dataset.object_store().block_size(); + // We provide a fallback in case of multi-vector, which will have + // a variable size. We use 4 KB as a fallback. + let byte_width = vector_field + .data_type() + .byte_width_opt() + .unwrap_or(4 * 1024); + + let ranges = random_ranges(num_rows, sample_size_hint, block_size, byte_width); + + let mut collected = Vec::with_capacity(ranges.size_hint().0); + let mut indices = Vec::with_capacity(sample_size_hint); + let mut num_non_null = 0; + + let mut scan = dataset.take_scan( + Box::pin(futures::stream::iter(ranges).map(Ok)), + Arc::new(dataset.schema().project(&[column])?), + dataset.object_store().io_parallelism(), + ); + + while let Some(batch) = scan.next().await { + let batch = batch?; + + let array = batch.column_by_name(column).ok_or(Error::Index { + message: format!( + "Sample training data: column {} does not exist in return", + column + ), + location: location!(), + })?; + let null_count = array.logical_null_count(); + if null_count < array.len() { + num_non_null += array.len() - null_count; + + let batch_i = collected.len(); + if let Some(null_buffer) = array.nulls() { + for i in null_buffer.valid_indices() { + indices.push((batch_i, i)); + } + } else { + indices.extend((0..array.len()).map(|i| (batch_i, i))); + } + + collected.push(batch); + } + if num_non_null >= sample_size_hint { + break; + } + } + + let batch = interleave_batches(&collected, &indices).map_err(|err| Error::Index { + message: format!("Sample training data: {}", err), + location: location!(), + })?; + info!( + "Sample training data: retrieved {} rows by sampling after filtering out nulls", + batch.num_rows() + ); + batch } else { let mut scanner = dataset.scan(); scanner.project(&[column])?; - let batches = scanner - .try_into_stream() - .await? - .try_collect::>() - .await?; - concat_batches(&Arc::new(ArrowSchema::from(&projection)), &batches)? + if is_nullable { + scanner.filter_expr(datafusion_expr::col(column).is_not_null()); + } + let batch = scanner.try_into_batch().await?; + info!( + "Sample training data: retrieved {} rows scanning full datasets", + batch.num_rows() + ); + batch }; let array = batch.column_by_name(column).ok_or(Error::Index { @@ -64,7 +215,23 @@ pub async fn maybe_sample_training_data( ), location: location!(), })?; - Ok(array.as_fixed_size_list().clone()) + + match array.data_type() { + arrow::datatypes::DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().clone()), + // for multivector, flatten the vectors into a FixedSizeListArray + arrow::datatypes::DataType::List(_) => { + let list_array = array.as_list::(); + let vectors = list_array.values().as_fixed_size_list(); + Ok(vectors.clone()) + } + _ => Err(Error::Index { + message: format!( + "Sample training data: column {} is not a FixedSizeListArray", + column + ), + location: location!(), + }), + } } #[derive(Debug)] @@ -87,3 +254,98 @@ impl PartitionLoadLock { mtx.clone() } } + +/// Generate random ranges to sample from a dataset. +/// +/// This will return an iterator of ranges that cover the whole dataset. It +/// provides an unbound iterator so that the caller can decide when to stop. +/// This is useful when the caller wants to sample a fixed number of rows, but +/// has an additional filter that must be applied. +/// +/// Parameters: +/// * `num_rows`: number of rows in the dataset +/// * `sample_size_hint`: the target number of rows to be sampled in the end. +/// This is a hint for the minimum number of rows that will be consumed, but +/// the caller may consume more than this. +/// * `block_size`: the byte size of ranges that should be used. +/// * `byte_width`: the byte width of the vectors that will be sampled. +fn random_ranges( + num_rows: usize, + sample_size_hint: usize, + block_size: usize, + byte_width: usize, +) -> impl Iterator> + Send { + let rows_per_batch = 1.max(block_size / byte_width); + let mut rng = SmallRng::from_entropy(); + let num_bins = num_rows.div_ceil(rows_per_batch); + + let bins_iter: Box + Send> = if sample_size_hint * 5 >= num_rows { + // It's faster to just allocate and shuffle + let mut indices = (0..num_bins).collect::>(); + indices.shuffle(&mut rng); + Box::new(indices.into_iter()) + } else { + // If the sample is a small proportion, then we can instead use a set + // to track which bins we have seen. We start by using the sample_size_hint + // to provide an efficient start, and from there we randomly choose bins + // one by one. + let num_bins = num_rows.div_ceil(rows_per_batch); + // Start with the minimum number we will need. + let min_sample_size = sample_size_hint / rows_per_batch; + let starting_bins = (0..num_bins).choose_multiple(&mut rng, min_sample_size); + let mut seen = starting_bins + .iter() + .cloned() + .collect::>(); + + let additional = std::iter::from_fn(move || loop { + if seen.len() >= num_bins { + break None; + } + let next = (0..num_bins).choose(&mut rng).unwrap(); + if seen.contains(&next) { + continue; + } else { + seen.insert(next); + return Some(next); + } + }); + + Box::new(starting_bins.into_iter().chain(additional)) + }; + + bins_iter.map(move |i| { + let start = (i * rows_per_batch) as u64; + let end = ((i + 1) * rows_per_batch) as u64; + let end = std::cmp::min(end, num_rows as u64); + start..end + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[rstest::rstest] + #[test] + fn test_random_ranges( + #[values(99, 100, 102)] num_rows: usize, + #[values(10, 100)] sample_size: usize, + ) { + // We can just assert that the output when sorted is the same as the input + let block_size = 100; + let byte_width = 10; + + let bin_size = block_size / byte_width; + assert_eq!(bin_size, 10); + + let mut ranges = + random_ranges(num_rows, sample_size, block_size, byte_width).collect::>(); + ranges.sort_by_key(|r| r.start); + let expected = (0..num_rows as u64).step_by(bin_size).map(|start| { + let end = std::cmp::min(start + bin_size as u64, num_rows as u64); + start..end + }); + assert_eq!(ranges, expected.collect::>()); + } +} diff --git a/rust/lance/src/io.rs b/rust/lance/src/io.rs index 94ba83897c0..89c4fabd76f 100644 --- a/rust/lance/src/io.rs +++ b/rust/lance/src/io.rs @@ -7,6 +7,7 @@ pub mod commit; pub mod exec; pub use lance_io::{ + bytes_read_counter, iops_counter, object_store::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry, WrappingObjectStore}, stream::RecordBatchStream, }; diff --git a/rust/lance/src/io/commit.rs b/rust/lance/src/io/commit.rs index e8e77e4b411..d96283ee80b 100644 --- a/rust/lance/src/io/commit.rs +++ b/rust/lance/src/io/commit.rs @@ -11,12 +11,9 @@ //! //! The trait [CommitHandler] can be implemented to provide different commit //! strategies. The default implementation for most object stores is -//! [RenameCommitHandler], which writes the manifest to a temporary path, then +//! [ConditionalPutCommitHandler], which writes the manifest to a temporary path, then //! renames the temporary path to the final path if no object already exists -//! at the final path. This is an atomic operation in most object stores, but -//! not in AWS S3. So for AWS S3, the default commit handler is -//! [UnsafeCommitHandler], which writes the manifest to the final path without -//! any checks. +//! at the final path. //! //! When providing your own commit handler, most often you are implementing in //! terms of a lock. The trait [CommitLock] can be implemented as a simpler @@ -25,7 +22,9 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use lance_core::utils::backoff::Backoff; use lance_file::version::LanceFileVersion; +use lance_index::metrics::NoOpMetricsCollector; use lance_table::format::{ is_detached_version, pb, DataStorageFormat, DeletionFile, Fragment, Index, Manifest, WriterVersion, DETACHED_VERSION_MASK, @@ -33,27 +32,31 @@ use lance_table::format::{ use lance_table::io::commit::{CommitConfig, CommitError, CommitHandler, ManifestNamingScheme}; use lance_table::io::deletion::read_deletion_file; use rand::{thread_rng, Rng}; -use snafu::{location, Location}; +use snafu::location; use futures::future::Either; use futures::{FutureExt, StreamExt, TryStreamExt}; use lance_core::{Error, Result}; use lance_index::DatasetIndexExt; +use log; use object_store::path::Path; use prost::Message; use super::ObjectStore; +use crate::dataset::cleanup::auto_cleanup_hook; use crate::dataset::fragment::FileFragment; -use crate::dataset::transaction::{Operation, Transaction}; +use crate::dataset::transaction::{ConflictResult, Operation, Transaction}; use crate::dataset::{write_manifest_file, ManifestWriteConfig, BLOB_DIR}; use crate::index::DatasetIndexInternalExt; use crate::session::Session; use crate::Dataset; -#[cfg(all(feature = "dynamodb", test))] +#[cfg(all(feature = "dynamodb_tests", test))] mod dynamodb; #[cfg(test)] mod external_manifest; +#[cfg(all(feature = "dynamodb_tests", test))] +mod s3_test; /// Read the transaction data from a transaction file. async fn read_transaction_file( @@ -125,7 +128,7 @@ fn check_transaction( other_version: u64, other_transaction: Option<&Transaction>, ) -> Result<()> { - if other_transaction.is_none() { + let Some(other_transaction) = other_transaction else { return Err(crate::Error::Internal { message: format!( "There was a conflicting transaction at version {}, \ @@ -134,23 +137,28 @@ fn check_transaction( ), location: location!(), }); - } + }; - if transaction.conflicts_with(other_transaction.as_ref().unwrap()) { - return Err(crate::Error::CommitConflict { - version: other_version, - source: format!( - "There was a concurrent commit that conflicts with this one and it \ - cannot be automatically resolved. Please rerun the operation off the latest version \ - of the table.\n Transaction: {:?}\n Conflicting Transaction: {:?}", - transaction, other_transaction - ) - .into(), - location: location!(), - }); + match transaction.conflicts_with(other_transaction) { + ConflictResult::Compatible => Ok(()), + ConflictResult::NotCompatible => { + Err(crate::Error::CommitConflict { + version: other_version, + source: format!( + "This {} transaction is incompatible with concurrent transaction {} at version {}.", + transaction.operation, other_transaction.operation, other_version).into(), + location: location!(), + }) + }, + ConflictResult::Retryable => { + Err(crate::Error::RetryableCommitConflict { + version: other_version, + source: format!( + "This {} transaction was preempted by concurrent transaction {} at version {}. Please retry.", + transaction.operation, other_transaction.operation, other_version).into(), + location: location!() }) + } } - - Ok(()) } #[allow(clippy::too_many_arguments)] @@ -163,7 +171,7 @@ async fn do_commit_new_dataset( manifest_naming_scheme: ManifestNamingScheme, blob_version: Option, session: &Session, -) -> Result<(Manifest, Path)> { +) -> Result<(Manifest, Path, Option)> { let transaction_file = write_transaction_file(object_store, base_path, transaction).await?; let (mut manifest, indices) = @@ -189,12 +197,12 @@ async fn do_commit_new_dataset( // TODO: Allow Append or Overwrite mode to retry using `commit_transaction` // if there is a conflict. match result { - Ok(manifest_path) => { + Ok(manifest_location) => { session.file_metadata_cache.insert( transaction_file_cache_path(base_path, manifest.version), Arc::new(transaction.clone()), ); - Ok((manifest, manifest_path)) + Ok((manifest, manifest_location.path, manifest_location.e_tag)) } Err(CommitError::CommitConflict) => Err(crate::Error::DatasetAlreadyExists { uri: base_path.to_string(), @@ -212,11 +220,11 @@ pub(crate) async fn commit_new_dataset( write_config: &ManifestWriteConfig, manifest_naming_scheme: ManifestNamingScheme, session: &Session, -) -> Result<(Manifest, Path)> { +) -> Result<(Manifest, Path, Option)> { let blob_version = if let Some(blob_op) = transaction.blobs_op.as_ref() { let blob_path = base_path.child(BLOB_DIR); let blob_tx = Transaction::new(0, blob_op.clone(), None, None); - let (blob_manifest, _) = do_commit_new_dataset( + let (blob_manifest, _, _) = do_commit_new_dataset( object_store, commit_handler, &blob_path, @@ -497,13 +505,25 @@ fn must_recalculate_fragment_bitmap(index: &Index, version: Option<&WriterVersio /// /// Indices might be missing `fragment_bitmap`, so this function will add it. async fn migrate_indices(dataset: &Dataset, indices: &mut [Index]) -> Result<()> { + let needs_recalculating = match detect_overlapping_fragments(indices) { + Ok(()) => vec![], + Err(BadFragmentBitmapError { bad_indices }) => { + bad_indices.into_iter().map(|(name, _)| name).collect() + } + }; for index in indices { - if must_recalculate_fragment_bitmap(index, dataset.manifest.writer_version.as_ref()) { + if needs_recalculating.contains(&index.name) + || must_recalculate_fragment_bitmap(index, dataset.manifest.writer_version.as_ref()) + { debug_assert_eq!(index.fields.len(), 1); let idx_field = dataset.schema().field_by_id(index.fields[0]).ok_or_else(|| Error::Internal { message: format!("Index with uuid {} referred to field with id {} which did not exist in dataset", index.uuid, index.fields[0]), location: location!() })?; // We need to calculate the fragments covered by the index let idx = dataset - .open_generic_index(&idx_field.name, &index.uuid.to_string()) + .open_generic_index( + &idx_field.name, + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) .await?; index.fragment_bitmap = Some(idx.calculate_included_frags().await?); } @@ -517,6 +537,40 @@ async fn migrate_indices(dataset: &Dataset, indices: &mut [Index]) -> Result<()> Ok(()) } +pub(crate) struct BadFragmentBitmapError { + pub bad_indices: Vec<(String, Vec)>, +} + +/// Detect whether a given index has overlapping fragment bitmaps in it's index +/// segments. +pub(crate) fn detect_overlapping_fragments( + indices: &[Index], +) -> std::result::Result<(), BadFragmentBitmapError> { + let index_names: HashSet<&str> = indices.iter().map(|i| i.name.as_str()).collect(); + let mut bad_indices = Vec::new(); // (index_name, overlapping_fragments) + for name in index_names { + let mut seen_fragment_ids = HashSet::new(); + let mut overlap = Vec::new(); + for index in indices.iter().filter(|i| i.name == name) { + if let Some(fragment_bitmap) = index.fragment_bitmap.as_ref() { + for fragment in fragment_bitmap { + if !seen_fragment_ids.insert(fragment) { + overlap.push(fragment); + } + } + } + } + if !overlap.is_empty() { + bad_indices.push((name.to_string(), overlap)); + } + } + if bad_indices.is_empty() { + Ok(()) + } else { + Err(BadFragmentBitmapError { bad_indices }) + } +} + pub(crate) async fn do_commit_detached_transaction( dataset: &Dataset, object_store: &ObjectStore, @@ -525,13 +579,14 @@ pub(crate) async fn do_commit_detached_transaction( write_config: &ManifestWriteConfig, commit_config: &CommitConfig, new_blob_version: Option, -) -> Result<(Manifest, Path)> { +) -> Result<(Manifest, Path, Option)> { // We don't strictly need a transaction file but we go ahead and create one for // record-keeping if nothing else. let transaction_file = write_transaction_file(object_store, &dataset.base, transaction).await?; // We still do a loop since we may have conflicts in the random version we pick - for attempt_i in 0..commit_config.num_retries { + let mut backoff = Backoff::default(); + while backoff.attempt() < commit_config.num_retries { // Pick a random u64 with the highest bit set to indicate it is detached let random_version = thread_rng().gen::() | DETACHED_VERSION_MASK; @@ -583,15 +638,13 @@ pub(crate) async fn do_commit_detached_transaction( .await; match result { - Ok(path) => { - return Ok((manifest, path)); + Ok(location) => { + return Ok((manifest, location.path, location.e_tag)); } Err(CommitError::CommitConflict) => { // We pick a random u64 for the version, so it's possible (though extremely unlikely) // that we have a conflict. In that case, we just try again. - - let backoff_time = backoff_time(attempt_i); - tokio::time::sleep(backoff_time).await; + tokio::time::sleep(backoff.next_backoff()).await; } Err(CommitError::OtherError(err)) => { // If other error, return @@ -620,12 +673,12 @@ pub(crate) async fn commit_detached_transaction( transaction: &Transaction, write_config: &ManifestWriteConfig, commit_config: &CommitConfig, -) -> Result<(Manifest, Path)> { +) -> Result<(Manifest, Path, Option)> { let new_blob_version = if let Some(blob_op) = transaction.blobs_op.as_ref() { let blobs_dataset = dataset.blobs_dataset().await?.unwrap(); let blobs_tx = Transaction::new(blobs_dataset.version().version, blob_op.clone(), None, None); - let (blobs_manifest, _) = do_commit_detached_transaction( + let (blobs_manifest, _, _) = do_commit_detached_transaction( blobs_dataset.as_ref(), object_store, commit_handler, @@ -661,12 +714,12 @@ pub(crate) async fn commit_transaction( write_config: &ManifestWriteConfig, commit_config: &CommitConfig, manifest_naming_scheme: ManifestNamingScheme, -) -> Result<(Manifest, Path)> { +) -> Result<(Manifest, Path, Option)> { let new_blob_version = if let Some(blob_op) = transaction.blobs_op.as_ref() { let blobs_dataset = dataset.blobs_dataset().await?.unwrap(); let blobs_tx = Transaction::new(blobs_dataset.version().version, blob_op.clone(), None, None); - let (blobs_manifest, _) = do_commit_detached_transaction( + let (blobs_manifest, _, _) = do_commit_detached_transaction( blobs_dataset.as_ref(), object_store, commit_handler, @@ -684,46 +737,52 @@ pub(crate) async fn commit_transaction( // Note: object_store has been configured with WriteParams, but dataset.object_store() // has not necessarily. So for anything involving writing, use `object_store`. let transaction_file = write_transaction_file(object_store, &dataset.base, transaction).await?; - - // First, get all transactions since read_version let read_version = transaction.read_version; + let mut target_version = read_version + 1; let mut dataset = dataset.clone(); - // We need to checkout the latest version, because any fixes we apply - // (like computing the new row ids) needs to be done based on the most - // recent manifest. - dataset.checkout_latest().await?; - let latest_version = dataset.manifest.version; - let other_transactions = futures::stream::iter((read_version + 1)..=latest_version) - .map(|version| { - read_dataset_transaction_file(&dataset, version) - .map(move |res| res.map(|tx| (version, tx))) - }) - .buffer_unordered(dataset.object_store().io_parallelism()) - .take_while(|res| { - futures::future::ready(!matches!( - res, - Err(crate::Error::NotFound { .. }) | Err(crate::Error::DatasetNotFound { .. }) - )) - }) - .try_collect::>() - .await?; - - let mut target_version = latest_version + 1; - - if is_detached_version(target_version) { - return Err(Error::Internal { message: "more than 2^65 versions have been created and so regular version numbers are appearing as 'detached' versions.".into(), location: location!() }); - } - - // If any of them conflict with the transaction, return an error - for (other_version, other_transaction) in other_transactions.iter() { - check_transaction( - transaction, - *other_version, - Some(other_transaction.as_ref()), - )?; + if matches!(transaction.operation, Operation::Overwrite { .. }) + && commit_config.num_retries == 0 + { + dataset.checkout_version(transaction.read_version).await?; + } else { + // We need to checkout the latest version, because any fixes we apply + // (like computing the new row ids) needs to be done based on the most + // recent manifest. + dataset.checkout_latest().await?; } - - for attempt_i in 0..commit_config.num_retries { + let num_attempts = std::cmp::max(commit_config.num_retries, 1); + let mut backoff = Backoff::default(); + while backoff.attempt() < num_attempts { + // See if we can retry the commit. Try to account for all + // transactions that have been committed since the read_version. + // Use small amount of backoff to handle transactions that all + // started at exact same time better. + futures::stream::iter(target_version..=dataset.manifest.version) + .map(|version| { + read_dataset_transaction_file(&dataset, version) + .map(move |res| res.map(|tx| (version, tx))) + }) + .buffer_unordered(dataset.object_store().io_parallelism()) + .take_while(|res| { + futures::future::ready( + backoff.attempt() > 0 + || !matches!( + res, + Err(crate::Error::NotFound { .. }) + | Err(crate::Error::DatasetNotFound { .. }) + ), + ) + }) + .try_for_each(|(other_version, other_transaction)| { + let res = + check_transaction(transaction, other_version, Some(other_transaction.as_ref())); + futures::future::ready(res) + }) + .await?; + target_version = dataset.manifest.version + 1; + if is_detached_version(target_version) { + return Err(Error::Internal { message: "more than 2^65 versions have been created and so regular version numbers are appearing as 'detached' versions.".into(), location: location!() }); + } // Build an up-to-date manifest from the transaction and current manifest let (mut manifest, mut indices) = match transaction.operation { Operation::Restore { version } => { @@ -780,42 +839,26 @@ pub(crate) async fn commit_transaction( .await; match result { - Ok(manifest_path) => { + Ok(manifest_location) => { let cache_path = transaction_file_cache_path(&dataset.base, target_version); dataset .session() .file_metadata_cache .insert(cache_path, Arc::new(transaction.clone())); - return Ok((manifest, manifest_path)); + match auto_cleanup_hook(&dataset, &manifest).await { + Ok(Some(stats)) => log::info!("Auto cleanup triggered: {:?}", stats), + Err(e) => log::error!("Error encountered during auto_cleanup_hook: {}", e), + _ => {} + }; + return Ok((manifest, manifest_location.path, manifest_location.e_tag)); } Err(CommitError::CommitConflict) => { - // See if we can retry the commit. Try to account for all - // transactions that have been committed since the read_version. - // Use small amount of backoff to handle transactions that all - // started at exact same time better. - - let backoff_time = backoff_time(attempt_i); - tokio::time::sleep(backoff_time).await; - - dataset.checkout_latest().await?; - let latest_version = dataset.manifest.version; - futures::stream::iter(target_version..=latest_version) - .map(|version| { - read_dataset_transaction_file(&dataset, version) - .map(move |res| res.map(|tx| (version, tx))) - }) - .buffer_unordered(dataset.object_store().io_parallelism()) - .try_for_each(|(version, other_transaction)| { - let res = check_transaction( - transaction, - version, - Some(other_transaction.as_ref()), - ); - futures::future::ready(res) - }) - .await?; - target_version = latest_version + 1; + let next_attempt_i = backoff.attempt() + 1; + if next_attempt_i < num_attempts { + tokio::time::sleep(backoff.next_backoff()).await; + dataset.checkout_latest().await?; + } } Err(CommitError::OtherError(err)) => { // If other error, return @@ -835,18 +878,6 @@ pub(crate) async fn commit_transaction( }) } -fn backoff_time(attempt_i: u32) -> std::time::Duration { - // Exponential base: - // 100ms, 200ms, 400ms, 800ms, 1600ms, 3200ms, 6400ms - let backoff = 2_i32.pow(attempt_i) * 100; - // With +-100ms jitter - let jitter = rand::thread_rng().gen_range(-100..100); - let backoff = backoff + jitter; - // No more than 5 seconds and less than 10ms. - let backoff = backoff.clamp(10, 5_000) as u64; - std::time::Duration::from_millis(backoff) -} - #[cfg(test)] mod tests { use std::sync::Mutex; @@ -1210,6 +1241,7 @@ mod tests { #[tokio::test] async fn test_good_concurrent_config_writes() { let (_tmpdir, dataset) = get_empty_dataset().await; + let original_num_config_keys = dataset.manifest.config.len(); // Test successful concurrent insert config operations let futures: Vec<_> = ["key1", "key2", "key3", "key4", "key5"] @@ -1231,7 +1263,7 @@ mod tests { } let dataset = dataset.checkout_version(6).await.unwrap(); - assert_eq!(dataset.manifest.config.len(), 5); + assert_eq!(dataset.manifest.config.len(), 5 + original_num_config_keys); dataset.validate().await.unwrap(); @@ -1254,7 +1286,7 @@ mod tests { let dataset = dataset.checkout_version(11).await.unwrap(); // There are now two fewer keys - assert_eq!(dataset.manifest.config.len(), 3); + assert_eq!(dataset.manifest.config.len(), 3 + original_num_config_keys); dataset.validate().await.unwrap() } diff --git a/rust/lance/src/io/commit/dynamodb.rs b/rust/lance/src/io/commit/dynamodb.rs index 357e055d98f..ec4ca2a24e1 100644 --- a/rust/lance/src/io/commit/dynamodb.rs +++ b/rust/lance/src/io/commit/dynamodb.rs @@ -7,11 +7,8 @@ // since these tests applies to all external manifest stores, // we should move them to a common place // https://github.com/lancedb/lance/issues/1208 -// -// The tests are linux only because -// GHA Mac runner doesn't have docker, which is required to run dynamodb-local // Windows FS can't handle concurrent copy -#[cfg(all(test, target_os = "linux", feature = "dynamodb_tests"))] +#[cfg(all(test, not(target_os = "windows")))] mod test { macro_rules! base_uri { () => { @@ -69,10 +66,10 @@ mod test { .behavior_version_latest() .endpoint_url( // url for dynamodb-local - "http://localhost:8000", + "http://localhost:4566", ) .region(Some(Region::new("us-east-1"))) - .credentials_provider(Credentials::new("DUMMYKEY", "DUMMYKEY", None, None, "")) + .credentials_provider(Credentials::new("ACCESS_KEY", "SECRET_KEY", None, None, "")) .build(); let table_name = uuid::Uuid::new_v4().to_string(); @@ -138,16 +135,19 @@ mod test { .to_string() .starts_with("Not found: dynamodb not found: base_uri: test; version: 1")); // try to use the API for finalizing should return err when the version is DNE - assert!(store.put_if_exists("test", 1, "test").await.is_err()); + assert!(store + .put_if_exists("test", 1, "test", 4, None) + .await + .is_err()); // Put a new version should work assert!(store - .put_if_not_exists("test", 1, "test.unfinalized") + .put_if_not_exists("test", 1, "test.unfinalized", 4, None) .await .is_ok()); // put again should get err assert!(store - .put_if_not_exists("test", 1, "test.unfinalized_1") + .put_if_not_exists("test", 1, "test.unfinalized_1", 4, None) .await .is_err()); @@ -160,7 +160,7 @@ mod test { // Put a new version should work again assert!(store - .put_if_not_exists("test", 2, "test.unfinalized_2") + .put_if_not_exists("test", 2, "test.unfinalized_2", 4, None) .await .is_ok()); // latest should see update @@ -170,7 +170,10 @@ mod test { ); // try to finalize should work on existing version - assert!(store.put_if_exists("test", 2, "test").await.is_ok()); + assert!(store + .put_if_exists("test", 2, "test", 4, None) + .await + .is_ok()); // latest should see update assert_eq!( @@ -322,8 +325,19 @@ mod test { ) .await .unwrap(); + let size = localfs + .head(&version_six_staging_location) + .await + .unwrap() + .size as u64; store - .put_if_exists(ds.base.as_ref(), 6, version_six_staging_location.as_ref()) + .put_if_exists( + ds.base.as_ref(), + 6, + version_six_staging_location.as_ref(), + size, + None, + ) .await .unwrap(); diff --git a/rust/lance/src/io/commit/external_manifest.rs b/rust/lance/src/io/commit/external_manifest.rs index 14bd842f373..44cc7fa7aca 100644 --- a/rust/lance/src/io/commit/external_manifest.rs +++ b/rust/lance/src/io/commit/external_manifest.rs @@ -17,7 +17,7 @@ mod test { use lance_testing::datagen::{BatchGenerator, IncrementingInt32}; use object_store::local::LocalFileSystem; use object_store::path::Path; - use snafu::{location, Location}; + use snafu::location; use tokio::sync::Mutex; use crate::dataset::builder::DatasetBuilder; @@ -72,7 +72,14 @@ mod test { } /// Put the manifest path for a given uri and version, should fail if the version already exists - async fn put_if_not_exists(&self, uri: &str, version: u64, path: &str) -> Result<()> { + async fn put_if_not_exists( + &self, + uri: &str, + version: u64, + path: &str, + _size: u64, + _e_tag: Option, + ) -> Result<()> { tokio::time::sleep(Duration::from_millis(100)).await; let mut store = self.store.lock().await; @@ -92,7 +99,14 @@ mod test { } /// Put the manifest path for a given uri and version, should fail if the version already exists - async fn put_if_exists(&self, uri: &str, version: u64, path: &str) -> Result<()> { + async fn put_if_exists( + &self, + uri: &str, + version: u64, + path: &str, + _size: u64, + _e_tag: Option, + ) -> Result<()> { tokio::time::sleep(Duration::from_millis(100)).await; let mut store = self.store.lock().await; @@ -162,6 +176,7 @@ mod test { } #[tokio::test] + #[cfg(not(windows))] async fn test_can_create_dataset_with_external_store() { let sleepy_store = SleepyExternalManifestStore::new(); let handler = ExternalManifestCommitHandler { @@ -268,6 +283,7 @@ mod test { } #[tokio::test] + #[cfg(not(windows))] async fn test_out_of_sync_dataset_can_recover() { let sleepy_store = SleepyExternalManifestStore::new(); let inner_store = sleepy_store.store.clone(); diff --git a/rust/lance/src/io/commit/s3_test.rs b/rust/lance/src/io/commit/s3_test.rs new file mode 100644 index 00000000000..599022bf6ed --- /dev/null +++ b/rust/lance/src/io/commit/s3_test.rs @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors +use std::{ops::DerefMut, sync::Arc}; + +use arrow::datatypes::Int32Type; + +use crate::{ + dataset::{ + builder::DatasetBuilder, CommitBuilder, InsertBuilder, ReadParams, WriteMode, WriteParams, + }, + io::ObjectStoreParams, +}; +use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig}; +use aws_sdk_s3::{config::Credentials, Client as S3Client}; +use futures::future::try_join_all; +use lance_datagen::{array, gen, RowCount}; + +const CONFIG: &[(&str, &str)] = &[ + ("access_key_id", "ACCESS_KEY"), + ("secret_access_key", "SECRET_KEY"), + ("endpoint", "http://127.0.0.1:4566"), + ("dynamodb_endpoint", "http://127.0.0.1:4566"), + ("allow_http", "true"), + ("region", "us-east-1"), +]; + +async fn aws_config() -> SdkConfig { + let credentials = Credentials::new(CONFIG[0].1, CONFIG[1].1, None, None, "static"); + ConfigLoader::default() + .credentials_provider(credentials) + .endpoint_url(CONFIG[2].1) + .behavior_version(BehaviorVersion::latest()) + .region(Region::new(CONFIG[5].1)) + .load() + .await +} + +struct S3Bucket(String); + +impl S3Bucket { + async fn new(bucket: &str) -> Self { + let config = aws_config().await; + let client = S3Client::new(&config); + + // In case it wasn't deleted earlier + Self::delete_bucket(client.clone(), bucket).await; + + client.create_bucket().bucket(bucket).send().await.unwrap(); + + Self(bucket.to_string()) + } + + async fn delete_bucket(client: S3Client, bucket: &str) { + // Before we delete the bucket, we need to delete all objects in it + let res = client + .list_objects_v2() + .bucket(bucket) + .send() + .await + .map_err(|err| err.into_service_error()); + match res { + Err(e) if e.is_no_such_bucket() => return, + Err(e) => panic!("Failed to list objects in bucket: {}", e), + _ => {} + } + let objects = res.unwrap().contents.unwrap_or_default(); + for object in objects { + client + .delete_object() + .bucket(bucket) + .key(object.key.unwrap()) + .send() + .await + .unwrap(); + } + client.delete_bucket().bucket(bucket).send().await.unwrap(); + } +} + +impl Drop for S3Bucket { + fn drop(&mut self) { + let bucket_name = self.0.clone(); + tokio::task::spawn(async move { + let config = aws_config().await; + let client = S3Client::new(&config); + Self::delete_bucket(client, &bucket_name).await; + }); + } +} + +struct DynamoDBCommitTable(String); + +impl DynamoDBCommitTable { + async fn new(name: &str) -> Self { + let config = aws_config().await; + let client = aws_sdk_dynamodb::Client::new(&config); + + // In case it wasn't deleted earlier + Self::delete_table(client.clone(), name).await; + // Dynamodb table drop is async, so we need to wait a bit + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + + use aws_sdk_dynamodb::types::*; + + client + .create_table() + .table_name(name) + .attribute_definitions( + AttributeDefinition::builder() + .attribute_name("base_uri") + .attribute_type(ScalarAttributeType::S) + .build() + .unwrap(), + ) + .attribute_definitions( + AttributeDefinition::builder() + .attribute_name("version") + .attribute_type(ScalarAttributeType::N) + .build() + .unwrap(), + ) + .key_schema( + KeySchemaElement::builder() + .attribute_name("base_uri") + .key_type(KeyType::Hash) + .build() + .unwrap(), + ) + .key_schema( + KeySchemaElement::builder() + .attribute_name("version") + .key_type(KeyType::Range) + .build() + .unwrap(), + ) + .provisioned_throughput( + ProvisionedThroughput::builder() + .read_capacity_units(1) + .write_capacity_units(1) + .build() + .unwrap(), + ) + .send() + .await + .unwrap(); + + Self(name.to_string()) + } + + async fn delete_table(client: aws_sdk_dynamodb::Client, name: &str) { + match client + .delete_table() + .table_name(name) + .send() + .await + .map_err(|err| err.into_service_error()) + { + Ok(_) => {} + Err(e) if e.is_resource_not_found_exception() => {} + Err(e) => panic!("Failed to delete table: {}", e), + }; + } +} + +impl Drop for DynamoDBCommitTable { + fn drop(&mut self) { + let table_name = self.0.clone(); + tokio::task::spawn(async move { + let config = aws_config().await; + let client = aws_sdk_dynamodb::Client::new(&config); + Self::delete_table(client, &table_name).await; + }); + } +} + +#[tokio::test] +async fn test_concurrent_writers() { + use crate::utils::test::IoTrackingStore; + + let datagen = gen().col("values", array::step::()); + let data = datagen.into_batch_rows(RowCount::from(100)).unwrap(); + + let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + + // Create a table + let store_params = ObjectStoreParams { + object_store_wrapper: Some(io_stats_wrapper), + storage_options: Some( + CONFIG + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + ), + ..Default::default() + }; + let write_params = WriteParams { + store_params: Some(store_params.clone()), + mode: WriteMode::Append, + ..Default::default() + }; + let bucket = S3Bucket::new("test-concurrent-writers").await; + let uri = format!("s3://{}/test", bucket.0); + let transaction = InsertBuilder::new(&uri) + .with_params(&write_params) + .execute_uncommitted(vec![data.clone()]) + .await + .unwrap(); + + // 1 IOPS for uncommitted write + let incremental_stats = || { + let mut stats = io_stats.as_ref().lock().unwrap(); + std::mem::take(stats.deref_mut()) + }; + assert_eq!(incremental_stats().write_iops, 1); + + let dataset = CommitBuilder::new(&uri) + .with_store_params(store_params.clone()) + .execute(transaction) + .await + .unwrap(); + // Commit: 2 IOPs. 1 for transaction file, 1 for manifest file + assert_eq!(incremental_stats().write_iops, 2); + let dataset = Arc::new(dataset); + let old_version = dataset.manifest().version; + + let concurrency = 10; + let mut tasks = Vec::with_capacity(concurrency); + for _ in 0..concurrency { + let ds_ref = dataset.clone(); + let data_ref = data.clone(); + let task = tokio::spawn(async move { + InsertBuilder::new(ds_ref) + .with_params(&WriteParams { + mode: WriteMode::Append, + ..Default::default() + }) + .execute(vec![data_ref]) + .await + .unwrap(); + }); + tasks.push(task); + } + try_join_all(tasks).await.unwrap(); + + let mut dataset = dataset.as_ref().clone(); + dataset.checkout_latest().await.unwrap(); + assert_eq!(old_version + concurrency as u64, dataset.manifest().version); + + let num_rows = dataset.count_rows(None).await.unwrap(); + assert_eq!(num_rows, data.num_rows() * (concurrency + 1)); + + dataset.validate().await.unwrap(); + let half_rows = dataset + .count_rows(Some("values >= 50".into())) + .await + .unwrap(); + assert_eq!(half_rows, num_rows / 2); +} + +#[tokio::test] +async fn test_ddb_open_iops() { + use crate::utils::test::IoTrackingStore; + + let bucket = S3Bucket::new("test-ddb-iops").await; + let ddb_table = DynamoDBCommitTable::new("test-ddb-iops").await; + let uri = format!("s3+ddb://{}/test?ddbTableName={}", bucket.0, ddb_table.0); + + let datagen = gen().col("values", array::step::()); + let data = datagen.into_batch_rows(RowCount::from(100)).unwrap(); + + let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + + // Create a table + let store_params = ObjectStoreParams { + object_store_wrapper: Some(io_stats_wrapper), + storage_options: Some( + CONFIG + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + ), + ..Default::default() + }; + let write_params = WriteParams { + store_params: Some(store_params.clone()), + mode: WriteMode::Append, + ..Default::default() + }; + let transaction = InsertBuilder::new(&uri) + .with_params(&write_params) + .execute_uncommitted(vec![data.clone()]) + .await + .unwrap(); + + // 1 IOPS for uncommitted write + let incremental_stats = || { + let mut stats = io_stats.as_ref().lock().unwrap(); + std::mem::take(stats.deref_mut()) + }; + assert_eq!(incremental_stats().write_iops, 1); + + let _ = CommitBuilder::new(&uri) + .with_store_params(store_params.clone()) + .execute(transaction) + .await + .unwrap(); + // Commit: 4 write IOPs: + // * 1 for transaction file + // * 3 for manifest file + // * write staged file + // * copy to final file + // * delete staged file + let stats = incremental_stats(); + + assert_eq!(stats.write_iops, 4); + assert_eq!(stats.read_iops, 1); + + let dataset = DatasetBuilder::from_uri(&uri) + .with_read_params(ReadParams { + store_options: Some(store_params.clone()), + ..Default::default() + }) + .load() + .await + .unwrap(); + let stats = incremental_stats(); + // Open dataset can be read with 1 IOP, just to read the manifest. + // Looking up latest manifest is handled in dynamodb. + assert_eq!(stats.read_iops, 1); + assert_eq!(stats.write_iops, 0); + + // Append + let dataset = InsertBuilder::new(Arc::new(dataset)) + .with_params(&WriteParams { + mode: WriteMode::Append, + ..Default::default() + }) + .execute(vec![data.clone()]) + .await + .unwrap(); + let stats = incremental_stats(); + // Append: 5 IOPS: data file, transaction file, 3x manifest file + assert_eq!(stats.write_iops, 5); + assert_eq!(stats.read_iops, 0); + + // Checkout original version + dataset.checkout_version(1).await.unwrap(); + let stats = incremental_stats(); + // Checkout: 1 IOPS: manifest file + assert_eq!(stats.read_iops, 1); + assert_eq!(stats.write_iops, 0); +} diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index 6984045d4de..f3b75564334 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -4,91 +4,90 @@ use std::collections::HashMap; use std::sync::Arc; +use arrow::array::AsArray; +use arrow::datatypes::{Float32Type, UInt64Type}; use arrow_array::{Float32Array, RecordBatch, UInt64Array}; -use arrow_schema::SchemaRef; use datafusion::common::Statistics; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties, -}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; +use datafusion_physical_expr::{Distribution, EquivalenceProperties, Partitioning}; use futures::stream::{self}; use futures::{StreamExt, TryStreamExt}; -use lance_index::prefilter::{FilterLoader, PreFilter}; -use lance_index::scalar::inverted::{flat_bm25_search_stream, InvertedIndex, FTS_SCHEMA}; -use lance_index::scalar::FullTextSearchQuery; -use lance_table::format::Index; +use itertools::Itertools; +use lance_core::ROW_ID; +use lance_index::prefilter::PreFilter; +use lance_index::scalar::inverted::query::{ + collect_tokens, BoostQuery, FtsSearchParams, MatchQuery, PhraseQuery, +}; +use lance_index::scalar::inverted::{ + flat_bm25_search_stream, InvertedIndex, FTS_SCHEMA, SCORE_COL, +}; +use lance_index::DatasetIndexExt; use tracing::instrument; -use crate::index::prefilter::DatasetPreFilter; use crate::{index::DatasetIndexInternalExt, Dataset}; -use super::utils::{FilteredRowIdsToPrefilter, SelectionVectorToPrefilter}; +use super::utils::{build_prefilter, IndexMetrics, InstrumentedRecordBatchStreamAdapter}; use super::PreFilterSource; -/// An execution node that performs full text search -/// -/// This node would perform full text search with inverted index on the dataset. -/// The result is a stream of record batches containing the row ids that match the search query, -/// and scores of the matched rows. #[derive(Debug)] -pub struct FtsExec { +pub struct MatchQueryExec { dataset: Arc, - // column -> (indices, unindexed input stream) - indices: HashMap>, - query: FullTextSearchQuery, - /// Prefiltering input + query: MatchQuery, + params: FtsSearchParams, prefilter_source: PreFilterSource, + properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } -impl DisplayAs for FtsExec { +impl DisplayAs for MatchQueryExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "Fts: query={}", self.query.query) + write!(f, "MatchQuery: query={}", self.query.terms) } } } } -impl FtsExec { +impl MatchQueryExec { pub fn new( dataset: Arc, - indices: HashMap>, - query: FullTextSearchQuery, + query: MatchQuery, + params: FtsSearchParams, prefilter_source: PreFilterSource, ) -> Self { let properties = PlanProperties::new( EquivalenceProperties::new(FTS_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Final, + Boundedness::Bounded, ); Self { dataset, - indices, query, + params, prefilter_source, properties, + metrics: ExecutionPlanMetricsSet::new(), } } } -impl ExecutionPlan for FtsExec { +impl ExecutionPlan for MatchQueryExec { fn name(&self) -> &str { - "FtsExec" + "MatchQueryExec" } fn as_any(&self) -> &dyn std::any::Any { self } - fn schema(&self) -> SchemaRef { - FTS_SCHEMA.clone() - } - fn children(&self) -> Vec<&Arc> { match &self.prefilter_source { PreFilterSource::None => vec![], @@ -97,218 +96,758 @@ impl ExecutionPlan for FtsExec { } } + fn required_input_distribution(&self) -> Vec { + // Prefilter inputs must be a single partition + self.children() + .iter() + .map(|_| Distribution::SinglePartition) + .collect() + } + fn with_new_children( self: Arc, - _children: Vec>, + mut children: Vec>, ) -> DataFusionResult> { - todo!() + let plan = match children.len() { + 0 => { + if !matches!(self.prefilter_source, PreFilterSource::None) { + return Err(DataFusionError::Internal( + "Unexpected prefilter source".to_string(), + )); + } + + Self { + dataset: self.dataset.clone(), + query: self.query.clone(), + params: self.params.clone(), + prefilter_source: PreFilterSource::None, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + } + } + 1 => { + let src = children.pop().unwrap(); + let prefilter_source = match &self.prefilter_source { + PreFilterSource::FilteredRowIds(_) => { + PreFilterSource::FilteredRowIds(src.clone()) + } + PreFilterSource::ScalarIndexQuery(_) => { + PreFilterSource::ScalarIndexQuery(src.clone()) + } + PreFilterSource::None => { + return Err(DataFusionError::Internal( + "Unexpected prefilter source".to_string(), + )); + } + }; + + Self { + dataset: self.dataset.clone(), + query: self.query.clone(), + params: self.params.clone(), + prefilter_source, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + } + } + _ => { + return Err(DataFusionError::Internal( + "Unexpected number of children".to_string(), + )); + } + }; + Ok(Arc::new(plan)) } - #[instrument(name = "fts_exec", level = "debug", skip_all)] + #[instrument(name = "match_query_exec", level = "debug", skip_all)] fn execute( &self, partition: usize, - context: Arc, + context: Arc, ) -> DataFusionResult { let query = self.query.clone(); + let params = self.params.clone(); let ds = self.dataset.clone(); let prefilter_source = self.prefilter_source.clone(); + let metrics = Arc::new(IndexMetrics::new(&self.metrics, partition)); + let column = query.column.ok_or(DataFusionError::Execution(format!( + "column not set for MatchQuery {}", + query.terms + )))?; + + let stream = stream::once(async move { + let index_meta = ds.load_scalar_index_for_column(&column).await?.ok_or( + DataFusionError::Execution(format!("No index found for column {}", column,)), + )?; + let uuid = index_meta.uuid.to_string(); + let index = ds + .open_generic_index(&column, &uuid, metrics.as_ref()) + .await?; + + let pre_filter = build_prefilter( + context.clone(), + partition, + &prefilter_source, + ds, + &[index_meta], + )?; + + let inverted_idx = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Index for column {} is not an inverted index", + column, + )) + })?; + + let is_fuzzy = matches!(query.fuzziness, Some(n) if n != 0); + let mut tokenizer = match is_fuzzy { + false => inverted_idx.tokenizer(), + true => tantivy::tokenizer::TextAnalyzer::from( + tantivy::tokenizer::SimpleTokenizer::default(), + ), + }; + let mut tokens = collect_tokens(&query.terms, &mut tokenizer, None); + if is_fuzzy { + tokens = + inverted_idx.expand_fuzzy(tokens, query.fuzziness, query.max_expansions)?; + } - let indices = self.indices.clone(); - let stream = stream::iter(indices) - .map(move |(column, indices)| { - let index_meta = indices[0].clone(); - let uuid = index_meta.uuid.to_string(); - let query = query.clone(); - let ds = ds.clone(); - let context = context.clone(); - let prefilter_source = prefilter_source.clone(); - - async move { - let prefilter_loader = match &prefilter_source { - PreFilterSource::FilteredRowIds(src_node) => { - let stream = src_node.execute(partition, context.clone())?; - Some(Box::new(FilteredRowIdsToPrefilter(stream)) - as Box) - } - PreFilterSource::ScalarIndexQuery(src_node) => { - let stream = src_node.execute(partition, context.clone())?; - Some(Box::new(SelectionVectorToPrefilter(stream)) - as Box) - } - PreFilterSource::None => None, - }; - let pre_filter = Arc::new(DatasetPreFilter::new( - ds.clone(), - &[index_meta], - prefilter_loader, - )); - - let index = ds.open_generic_index(&column, &uuid).await?; - let index = - index - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Index {} is not an inverted index", - uuid, - )) - })?; - pre_filter.wait_for_ready().await?; - let results = index.full_text_search(&query, pre_filter).await?; - - let (row_ids, scores): (Vec, Vec) = results.into_iter().unzip(); - let batch = RecordBatch::try_new( - FTS_SCHEMA.clone(), - vec![ - Arc::new(UInt64Array::from(row_ids)), - Arc::new(Float32Array::from(scores)), - ], - )?; - Ok::<_, DataFusionError>(batch) - } - }) - .buffered(self.indices.len()); - let schema = self.schema(); - Ok( - Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed())) - as SendableRecordBatchStream, - ) + pre_filter.wait_for_ready().await?; + let (doc_ids, mut scores) = inverted_idx + .bm25_search( + &tokens, + ¶ms, + query.operator, + false, + pre_filter, + metrics.as_ref(), + ) + .await?; + scores.iter_mut().for_each(|s| { + *s *= query.boost; + }); + + let batch = RecordBatch::try_new( + FTS_SCHEMA.clone(), + vec![ + Arc::new(UInt64Array::from(doc_ids)), + Arc::new(Float32Array::from(scores)), + ], + )?; + Ok::<_, DataFusionError>(batch) + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + ))) } fn statistics(&self) -> DataFusionResult { Ok(Statistics::new_unknown(&FTS_SCHEMA)) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn properties(&self) -> &PlanProperties { &self.properties } } -/// An execution node that performs flat full text search -/// -/// This node would perform flat full text search on unindexed rows. -/// The result is a stream of record batches containing the row ids that match the search query, -/// and scores of the matched rows. +/// Calculates the FTS score for each row in the input #[derive(Debug)] -pub struct FlatFtsExec { +pub struct FlatMatchQueryExec { dataset: Arc, - // column -> (indices, unindexed input stream) - column_inputs: HashMap, Arc)>, - query: FullTextSearchQuery, + query: MatchQuery, + params: FtsSearchParams, + unindexed_input: Arc, + properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } -impl DisplayAs for FlatFtsExec { +impl DisplayAs for FlatMatchQueryExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "FlatFts: query={}", self.query.query) + write!(f, "FlatMatchQuery: query={}", self.query.terms) } } } } -impl FlatFtsExec { +impl FlatMatchQueryExec { pub fn new( dataset: Arc, - column_inputs: HashMap, Arc)>, - query: FullTextSearchQuery, + query: MatchQuery, + params: FtsSearchParams, + unindexed_input: Arc, ) -> Self { let properties = PlanProperties::new( EquivalenceProperties::new(FTS_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { dataset, - column_inputs, query, + params, + unindexed_input, properties, + metrics: ExecutionPlanMetricsSet::new(), } } } -impl ExecutionPlan for FlatFtsExec { +impl ExecutionPlan for FlatMatchQueryExec { fn name(&self) -> &str { - "FlatFtsExec" + "FlatMatchQueryExec" } fn as_any(&self) -> &dyn std::any::Any { self } - fn schema(&self) -> SchemaRef { - FTS_SCHEMA.clone() + fn children(&self) -> Vec<&Arc> { + vec![&self.unindexed_input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> DataFusionResult> { + if children.len() != 1 { + return Err(DataFusionError::Internal( + "Unexpected number of children".to_string(), + )); + } + let unindexed_input = children.pop().unwrap(); + Ok(Arc::new(Self { + dataset: self.dataset.clone(), + query: self.query.clone(), + params: self.params.clone(), + unindexed_input, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + })) + } + + #[instrument(name = "flat_match_query_exec", level = "debug", skip_all)] + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DataFusionResult { + let query = self.query.clone(); + let ds = self.dataset.clone(); + let metrics = Arc::new(IndexMetrics::new(&self.metrics, partition)); + let unindexed_input = self.unindexed_input.execute(partition, context)?; + + let column = query.column.ok_or(DataFusionError::Execution(format!( + "column not set for MatchQuery {}", + query.terms + )))?; + + let stream = stream::once(async move { + let index_meta = ds.load_scalar_index_for_column(&column).await?.ok_or( + DataFusionError::Execution(format!("No index found for column {}", column,)), + )?; + let uuid = index_meta.uuid.to_string(); + let index = ds + .open_generic_index(&column, &uuid, metrics.as_ref()) + .await?; + let inverted_idx = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Index for column {} is not an inverted index", + column, + )) + })?; + Ok::<_, DataFusionError>(flat_bm25_search_stream( + unindexed_input, + column, + query.terms, + inverted_idx, + )) + }) + .try_flatten_unordered(None); + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + partition, + &self.metrics, + ))) + } + + fn statistics(&self) -> DataFusionResult { + Ok(Statistics::new_unknown(&FTS_SCHEMA)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } +} + +#[derive(Debug)] +pub struct PhraseQueryExec { + dataset: Arc, + query: PhraseQuery, + params: FtsSearchParams, + prefilter_source: PreFilterSource, + properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl DisplayAs for PhraseQueryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PhraseQuery: query={}", self.query.terms) + } + } + } +} + +impl PhraseQueryExec { + pub fn new( + dataset: Arc, + query: PhraseQuery, + params: FtsSearchParams, + prefilter_source: PreFilterSource, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(FTS_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Final, + Boundedness::Bounded, + ); + Self { + dataset, + query, + params, + prefilter_source, + properties, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl ExecutionPlan for PhraseQueryExec { + fn name(&self) -> &str { + "PhraseQueryExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self } fn children(&self) -> Vec<&Arc> { - self.column_inputs - .values() - .map(|(_, input)| input) + match &self.prefilter_source { + PreFilterSource::None => vec![], + PreFilterSource::FilteredRowIds(src) => vec![&src], + PreFilterSource::ScalarIndexQuery(src) => vec![&src], + } + } + + fn required_input_distribution(&self) -> Vec { + // Prefilter inputs must be a single partition + self.children() + .iter() + .map(|_| Distribution::SinglePartition) .collect() } fn with_new_children( self: Arc, - _children: Vec>, + mut children: Vec>, ) -> DataFusionResult> { - todo!() + let plan = match children.len() { + 0 => Self { + dataset: self.dataset.clone(), + query: self.query.clone(), + params: self.params.clone(), + prefilter_source: PreFilterSource::None, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + }, + 1 => { + let src = children.pop().unwrap(); + let prefilter_source = match &self.prefilter_source { + PreFilterSource::FilteredRowIds(_) => { + PreFilterSource::FilteredRowIds(src.clone()) + } + PreFilterSource::ScalarIndexQuery(_) => { + PreFilterSource::ScalarIndexQuery(src.clone()) + } + PreFilterSource::None => { + return Err(DataFusionError::Internal( + "Unexpected prefilter source".to_string(), + )); + } + }; + Self { + dataset: self.dataset.clone(), + query: self.query.clone(), + params: self.params.clone(), + prefilter_source, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + } + } + _ => { + return Err(DataFusionError::Internal( + "Unexpected number of children".to_string(), + )); + } + }; + Ok(Arc::new(plan)) } - #[instrument(name = "flat_fts_exec", level = "debug", skip_all)] + #[instrument(name = "phrase_query_exec", level = "debug", skip_all)] fn execute( &self, partition: usize, - context: Arc, + context: Arc, ) -> DataFusionResult { let query = self.query.clone(); + let params = self.params.clone(); let ds = self.dataset.clone(); - let column_inputs = self.column_inputs.clone(); - - let stream = stream::iter(column_inputs) - .map(move |(column, (indices, input))| { - let index_meta = indices[0].clone(); - let uuid = index_meta.uuid.to_string(); - let query = query.clone(); - let ds = ds.clone(); - let context = context.clone(); - - async move { - let index = ds.open_generic_index(&column, &uuid).await?; - let index = - index - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Index {} is not an inverted index", - uuid, - )) - })?; - - let unindexed_stream = input.execute(partition, context)?; - let unindexed_result_stream = - flat_bm25_search_stream(unindexed_stream, column, query, index); - - Ok::<_, DataFusionError>(unindexed_result_stream) + let prefilter_source = self.prefilter_source.clone(); + let metrics = Arc::new(IndexMetrics::new(&self.metrics, partition)); + let stream = stream::once(async move { + let column = query.column.ok_or(DataFusionError::Execution(format!( + "column not set for PhraseQuery {}", + query.terms + )))?; + let index_meta = ds.load_scalar_index_for_column(&column).await?.ok_or( + DataFusionError::Execution(format!("No index found for column {}", column,)), + )?; + let uuid = index_meta.uuid.to_string(); + let index = ds + .open_generic_index(&column, &uuid, metrics.as_ref()) + .await?; + + let pre_filter = build_prefilter( + context.clone(), + partition, + &prefilter_source, + ds, + &[index_meta], + )?; + + let index = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Index for column {} is not an inverted index", + column, + )) + })?; + + let mut tokenizer = index.tokenizer(); + let tokens = collect_tokens(&query.terms, &mut tokenizer, None); + + pre_filter.wait_for_ready().await?; + let (doc_ids, scores) = index + .bm25_search( + &tokens, + ¶ms, + lance_index::scalar::inverted::query::Operator::And, + true, + pre_filter, + metrics.as_ref(), + ) + .await?; + let batch = RecordBatch::try_new( + FTS_SCHEMA.clone(), + vec![ + Arc::new(UInt64Array::from(doc_ids)), + Arc::new(Float32Array::from(scores)), + ], + )?; + Ok::<_, DataFusionError>(batch) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + ))) + } + + fn statistics(&self) -> DataFusionResult { + Ok(Statistics::new_unknown(&FTS_SCHEMA)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } +} + +#[derive(Debug)] +pub struct BoostQueryExec { + query: BoostQuery, + params: FtsSearchParams, + positive: Arc, + negative: Arc, + + properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl DisplayAs for BoostQueryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "BoostQuery: negative_boost={}", + self.query.negative_boost + ) + } + } + } +} + +impl BoostQueryExec { + pub fn new( + query: BoostQuery, + params: FtsSearchParams, + positive: Arc, + negative: Arc, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(FTS_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Final, + Boundedness::Bounded, + ); + Self { + query, + params, + positive, + negative, + properties, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl ExecutionPlan for BoostQueryExec { + fn name(&self) -> &str { + "BoostQueryExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.positive, &self.negative] + } + + fn required_input_distribution(&self) -> Vec { + // This node fully consumes and re-orders the input rows. + // It must be run on a single partition. + self.children() + .iter() + .map(|_| Distribution::SinglePartition) + .collect() + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> DataFusionResult> { + if children.len() != 2 { + return Err(DataFusionError::Internal( + "Unexpected number of children".to_string(), + )); + } + + let negative = children.pop().unwrap(); + let positive = children.pop().unwrap(); + Ok(Arc::new(Self { + query: self.query.clone(), + params: self.params.clone(), + positive, + negative, + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + })) + } + + #[instrument(name = "boost_query_exec", level = "debug", skip_all)] + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DataFusionResult { + let query = self.query.clone(); + let params = self.params.clone(); + let positive = self.positive.execute(partition, context.clone())?; + let negative = self.negative.execute(partition, context)?; + let stream = stream::once(async move { + let positive = positive.try_collect::>().await?; + let negative = negative.try_collect::>().await?; + + let mut res = HashMap::new(); + for batch in positive { + let doc_ids = batch[ROW_ID].as_primitive::().values(); + let scores = batch[SCORE_COL].as_primitive::().values(); + + for (doc_id, score) in std::iter::zip(doc_ids, scores) { + res.insert(*doc_id, *score); } - }) - .buffered(self.column_inputs.len()) - .try_flatten(); - let schema = self.schema(); - Ok( - Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed())) - as SendableRecordBatchStream, - ) + } + for batch in negative { + let doc_ids = batch[ROW_ID].as_primitive::().values(); + let scores = batch[SCORE_COL].as_primitive::().values(); + + for (doc_id, neg_score) in std::iter::zip(doc_ids, scores) { + if let Some(score) = res.get_mut(doc_id) { + *score -= query.negative_boost * neg_score; + } + } + } + + let (doc_ids, scores): (Vec<_>, Vec<_>) = res + .into_iter() + .sorted_unstable_by(|(_, a), (_, b)| b.total_cmp(a)) + .take(params.limit.unwrap_or(usize::MAX)) + .unzip(); + + let batch = RecordBatch::try_new( + FTS_SCHEMA.clone(), + vec![ + Arc::new(UInt64Array::from(doc_ids)), + Arc::new(Float32Array::from(scores)), + ], + )?; + Ok::<_, DataFusionError>(batch) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + ))) } fn statistics(&self) -> DataFusionResult { Ok(Statistics::new_unknown(&FTS_SCHEMA)) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn properties(&self) -> &PlanProperties { &self.properties } } + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use datafusion::{execution::TaskContext, physical_plan::ExecutionPlan}; + use lance_datafusion::datagen::DatafusionDatagenExt; + use lance_datagen::{BatchCount, ByteCount, RowCount}; + use lance_index::scalar::inverted::query::{ + BoostQuery, FtsQuery, FtsSearchParams, MatchQuery, PhraseQuery, + }; + + use crate::{io::exec::PreFilterSource, utils::test::NoContextTestFixture}; + + use super::{BoostQueryExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec}; + + #[test] + fn execute_without_context() { + // These tests ensure we can create nodes and call execute without a tokio Runtime + // being active. This is a requirement for proper implementation of a Datafusion foreign + // table provider. + let fixture = NoContextTestFixture::new(); + let match_query = MatchQueryExec::new( + Arc::new(fixture.dataset.clone()), + MatchQuery::new("blah".to_string()).with_column(Some("text".to_string())), + FtsSearchParams::default(), + PreFilterSource::None, + ); + match_query + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + let flat_input = lance_datagen::gen() + .col( + "text", + lance_datagen::array::rand_utf8(ByteCount::from(10), false), + ) + .into_df_exec(RowCount::from(15), BatchCount::from(2)); + + let flat_match_query = FlatMatchQueryExec::new( + Arc::new(fixture.dataset.clone()), + MatchQuery::new("blah".to_string()).with_column(Some("text".to_string())), + FtsSearchParams::default(), + flat_input, + ); + flat_match_query + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + let phrase_query = PhraseQueryExec::new( + Arc::new(fixture.dataset.clone()), + PhraseQuery::new("blah".to_string()), + FtsSearchParams::default(), + PreFilterSource::None, + ); + phrase_query + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + let boost_input_one = MatchQueryExec::new( + Arc::new(fixture.dataset.clone()), + MatchQuery::new("blah".to_string()).with_column(Some("text".to_string())), + FtsSearchParams::default(), + PreFilterSource::None, + ); + + let boost_input_two = MatchQueryExec::new( + Arc::new(fixture.dataset), + MatchQuery::new("blah".to_string()).with_column(Some("text".to_string())), + FtsSearchParams::default(), + PreFilterSource::None, + ); + + let boost_query = BoostQueryExec::new( + BoostQuery::new( + FtsQuery::Match( + MatchQuery::new("blah".to_string()).with_column(Some("text".to_string())), + ), + FtsQuery::Match( + MatchQuery::new("test".to_string()).with_column(Some("text".to_string())), + ), + Some(1.0), + ), + FtsSearchParams::default(), + Arc::new(boost_input_one), + Arc::new(boost_input_two), + ); + boost_query + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + } +} diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index a09aa2f1331..839ae993156 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -2,26 +2,37 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::any::Any; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use arrow::datatypes::UInt32Type; +use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{ builder::{ListBuilder, UInt32Builder}, cast::AsArray, ArrayRef, RecordBatch, StringArray, }; +use arrow_array::{Array, Float32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use datafusion::common::stats::Precision; -use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::PlanProperties; use datafusion::physical_plan::{ - stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }; -use datafusion::physical_plan::{ExecutionMode, PlanProperties}; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion::{ + common::stats::Precision, + physical_plan::execution_plan::{Boundedness, EmissionType}, +}; +use datafusion::{common::ColumnStatistics, physical_plan::metrics::ExecutionPlanMetricsSet}; +use datafusion::{ + error::{DataFusionError, Result as DataFusionResult}, + physical_plan::metrics::MetricsSet, +}; +use datafusion_physical_expr::{Distribution, EquivalenceProperties}; use futures::stream::repeat_with; use futures::{future, stream, StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; +use lance_core::ROW_ID; use lance_core::{utils::tokio::get_num_compute_intensive_cpus, ROW_ID_FIELD}; use lance_index::vector::{ flat::compute_distance, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN, @@ -29,38 +40,28 @@ use lance_index::vector::{ use lance_linalg::distance::DistanceType; use lance_linalg::kernels::normalize_arrow; use lance_table::format::Index; -use snafu::{location, Location}; +use snafu::location; use crate::dataset::Dataset; use crate::index::prefilter::{DatasetPreFilter, FilterLoader}; +use crate::index::vector::utils::get_vector_type; use crate::index::DatasetIndexInternalExt; use crate::{Error, Result}; use lance_arrow::*; -use super::utils::{FilteredRowIdsToPrefilter, PreFilterSource, SelectionVectorToPrefilter}; +use super::utils::{ + FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedRecordBatchStreamAdapter, PreFilterSource, + SelectionVectorToPrefilter, +}; -/// Check vector column exists and has the correct data type. -fn check_vector_column(schema: &Schema, column: &str) -> Result<()> { - let field = schema.field_with_name(column).map_err(|_| { - Error::io( - format!("Query column '{}' not found in input schema", column), - location!(), - ) - })?; - match field.data_type() { - DataType::FixedSizeList(list_field, _) - if matches!( - list_field.data_type(), - DataType::UInt8 | DataType::Float16 | DataType::Float32 | DataType::Float64 - ) => Ok(()), - _ => { - Err(Error::io( - format!( - "KNNFlatExec node: query column {} is not a vector. Expect FixedSizeList, got {}", - column, field.data_type() - ), - location!(), - )) +pub struct AnnMetrics { + index_metrics: IndexMetrics, +} + +impl AnnMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + index_metrics: IndexMetrics::new(metrics, partition), } } } @@ -84,6 +85,8 @@ pub struct KNNVectorDistanceExec { output_schema: SchemaRef, properties: PlanProperties, + + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for KNNVectorDistanceExec { @@ -107,7 +110,7 @@ impl KNNVectorDistanceExec { distance_type: DistanceType, ) -> Result { let mut output_schema = input.schema().as_ref().clone(); - check_vector_column(&output_schema, column)?; + get_vector_type(&(&output_schema).try_into()?, column)?; // FlatExec appends a distance column to the input schema. The input // may already have a distance column (possibly in the wrong position), so @@ -135,6 +138,7 @@ impl KNNVectorDistanceExec { distance_type, output_schema, properties, + metrics: ExecutionPlanMetricsSet::new(), }) } } @@ -181,7 +185,6 @@ impl ExecutionPlan for KNNVectorDistanceExec { context: Arc, ) -> DataFusionResult { let input_stream = self.input.execute(partition, context)?; - let key = self.query.clone(); let column = self.column.clone(); let dt = self.distance_type; @@ -198,26 +201,44 @@ impl ExecutionPlan for KNNVectorDistanceExec { }) .buffer_unordered(get_num_compute_intensive_cpus()); let schema = self.schema(); - Ok( - Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed())) - as SendableRecordBatchStream, - ) + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( + schema, + stream.boxed(), + partition, + &self.metrics, + )) as SendableRecordBatchStream) } fn statistics(&self) -> DataFusionResult { let inner_stats = self.input.statistics()?; - let dist_col_stats = inner_stats.column_statistics[0].clone(); + let schema = self.input.schema(); + let dist_stats = inner_stats + .column_statistics + .iter() + .zip(schema.fields()) + .find(|(_, field)| field.name() == &self.column) + .map(|(stats, _)| ColumnStatistics { + null_count: stats.null_count, + ..Default::default() + }) + .unwrap_or_default(); let column_statistics = inner_stats .column_statistics .into_iter() - .chain([dist_col_stats]) + .zip(schema.fields()) + .filter(|(_, field)| field.name() != DIST_COL) + .map(|(stats, _)| stats) + .chain(std::iter::once(dist_stats)) .collect::>(); Ok(Statistics { num_rows: inner_stats.num_rows, column_statistics, ..Statistics::new_unknown(self.schema().as_ref()) }) - // self.input.statistics() + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) } fn properties(&self) -> &PlanProperties { @@ -285,12 +306,14 @@ pub struct ANNIvfPartitionExec { pub index_uuids: Vec, pub properties: PlanProperties, + + pub metrics: ExecutionPlanMetricsSet, } impl ANNIvfPartitionExec { pub fn try_new(dataset: Arc, index_uuids: Vec, query: Query) -> Result { let dataset_schema = dataset.schema(); - check_vector_column(&dataset_schema.into(), &query.column)?; + get_vector_type(dataset_schema, &query.column)?; if index_uuids.is_empty() { return Err(Error::Execution { message: "ANNIVFPartitionExec node: no index found for query".to_string(), @@ -302,7 +325,8 @@ impl ANNIvfPartitionExec { let properties = PlanProperties::new( EquivalenceProperties::new(schema), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Ok(Self { @@ -310,6 +334,7 @@ impl ANNIvfPartitionExec { query, index_uuids, properties, + metrics: ExecutionPlanMetricsSet::new(), }) } } @@ -360,28 +385,35 @@ impl ExecutionPlan for ANNIvfPartitionExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> DataFusionResult> { - Err(DataFusionError::Internal( - "ANNIVFPartitionExec: with_new_children called, but no children to replace".to_string(), - )) + if !children.is_empty() { + Err(DataFusionError::Internal( + "ANNIVFPartitionExec node does not accept children".to_string(), + )) + } else { + Ok(self) + } } fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> DataFusionResult { let query = self.query.clone(); let ds = self.dataset.clone(); - + let metrics = Arc::new(AnnMetrics::new(&self.metrics, partition)); let stream = stream::iter(self.index_uuids.clone()) .map(move |uuid| { let query = query.clone(); let ds = ds.clone(); + let metrics = metrics.clone(); async move { - let index = ds.open_vector_index(&query.column, &uuid).await?; + let index = ds + .open_vector_index(&query.column, &uuid, &metrics.index_metrics) + .await?; let mut query = query.clone(); if index.metric_type() == DistanceType::Cosine { @@ -407,10 +439,12 @@ impl ExecutionPlan for ANNIvfPartitionExec { }) .buffered(self.index_uuids.len()); let schema = self.schema(); - Ok( - Box::pin(RecordBatchStreamAdapter::new(schema, stream.boxed())) - as SendableRecordBatchStream, - ) + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( + schema, + stream.boxed(), + partition, + &self.metrics, + )) as SendableRecordBatchStream) } } @@ -439,6 +473,8 @@ pub struct ANNIvfSubIndexExec { /// Datafusion Plan Properties properties: PlanProperties, + + metrics: ExecutionPlanMetricsSet, } impl ANNIvfSubIndexExec { @@ -461,7 +497,8 @@ impl ANNIvfSubIndexExec { let properties = PlanProperties::new( EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Final, + Boundedness::Bounded, ); Ok(Self { input, @@ -470,6 +507,7 @@ impl ANNIvfSubIndexExec { query, prefilter_source, properties, + metrics: ExecutionPlanMetricsSet::new(), }) } } @@ -511,26 +549,38 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } } + fn required_input_distribution(&self) -> Vec { + // Prefilter inputs must be a single partition + self.children() + .iter() + .map(|_| Distribution::SinglePartition) + .collect() + } + fn with_new_children( self: Arc, mut children: Vec>, ) -> DataFusionResult> { - if children.len() != 1 { + let plan = if children.len() == 1 || children.len() == 2 { + if children.len() == 2 { + let _prefilter = children.pop().expect("length checked"); + } + // NOTE!!!! Prefilter transformation is ignored. + Self { + input: children.pop().expect("length checked"), + dataset: self.dataset.clone(), + indices: self.indices.clone(), + query: self.query.clone(), + prefilter_source: self.prefilter_source.clone(), + properties: self.properties.clone(), + metrics: ExecutionPlanMetricsSet::new(), + } + } else { return Err(DataFusionError::Internal( - "ANNSubIndexExec node must have exactly one child".to_string(), + "ANNSubIndexExec node must have exactly one or two (prefilter) child".to_string(), )); - } - - let new_plan = Self { - input: children.pop().expect("length checked"), - dataset: self.dataset.clone(), - indices: self.indices.clone(), - query: self.query.clone(), - prefilter_source: self.prefilter_source.clone(), - properties: self.properties.clone(), }; - - Ok(Arc::new(new_plan)) + Ok(Arc::new(plan)) } fn execute( @@ -539,14 +589,14 @@ impl ExecutionPlan for ANNIvfSubIndexExec { context: Arc, ) -> DataFusionResult { let input_stream = self.input.execute(partition, context.clone())?; - let schema = self.schema(); let query = self.query.clone(); let ds = self.dataset.clone(); let column = self.query.column.clone(); let indices = self.indices.clone(); let prefilter_source = self.prefilter_source.clone(); - + let metrics = Arc::new(AnnMetrics::new(&self.metrics, partition)); + let metrics_clone = metrics.clone(); // Per-delta-index stream: // Stream<(parttitions, index uuid)> let per_index_stream = input_stream @@ -578,7 +628,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec { }) .try_flatten(); - Ok(Box::pin(RecordBatchStreamAdapter::new( + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( schema, per_index_stream .and_then(move |(part_ids, index_uuid)| { @@ -587,7 +637,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec { let indices = indices.clone(); let context = context.clone(); let prefilter_source = prefilter_source.clone(); - + let metrics = metrics.clone(); let index_meta = indices .iter() .find(|idx| idx.uuid.to_string() == index_uuid) @@ -614,7 +664,9 @@ impl ExecutionPlan for ANNIvfSubIndexExec { prefilter_loader, )); - let raw_index = ds.open_vector_index(&column, &index_uuid).await?; + let raw_index = ds + .open_vector_index(&column, &index_uuid, &metrics.index_metrics) + .await?; Ok::<_, DataFusionError>( stream::iter(part_ids) @@ -626,6 +678,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec { .try_flatten() .map(move |result| { let query = query.clone(); + let metrics = metrics_clone.clone(); async move { let (part_id, (index, pre_filter)) = result?; @@ -636,7 +689,12 @@ impl ExecutionPlan for ANNIvfSubIndexExec { }; index - .search_in_partition(part_id as usize, &query, pre_filter) + .search_in_partition( + part_id as usize, + &query, + pre_filter, + &metrics.index_metrics, + ) .map_err(|e| { DataFusionError::Execution(format!( "Failed to calculate KNN: {}", @@ -648,6 +706,8 @@ impl ExecutionPlan for ANNIvfSubIndexExec { }) .buffered(get_num_compute_intensive_cpus()) .boxed(), + partition, + &self.metrics, ))) } @@ -662,6 +722,202 @@ impl ExecutionPlan for ANNIvfSubIndexExec { }) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } +} + +#[derive(Debug)] +pub struct MultivectorScoringExec { + // the inputs are sorted ANN search results + inputs: Vec>, + query: Query, + properties: PlanProperties, +} + +impl MultivectorScoringExec { + pub fn try_new(inputs: Vec>, query: Query) -> Result { + let properties = PlanProperties::new( + EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(Self { + inputs, + query, + properties, + }) + } +} + +impl DisplayAs for MultivectorScoringExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "MultivectorScoring: k={}", self.query.k) + } + } + } +} + +impl ExecutionPlan for MultivectorScoringExec { + fn name(&self) -> &str { + "MultivectorScoringExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + KNN_INDEX_SCHEMA.clone() + } + + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() + } + + fn required_input_distribution(&self) -> Vec { + // This node fully consumes and re-orders the input rows. It must be + // run on a single partition. + self.children() + .iter() + .map(|_| Distribution::SinglePartition) + .collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + let plan = Self::try_new(children, self.query.clone())?; + Ok(Arc::new(plan)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DataFusionResult { + let inputs = self + .inputs + .iter() + .map(|input| input.execute(partition, context.clone())) + .collect::>>()?; + + // collect the top k results from each stream, + // and max-reduce for each query, + // records the minimum distance for each query as estimation. + let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| { + stream.map(|batch| { + let batch = batch?; + let row_ids = batch[ROW_ID].as_primitive::(); + let dists = batch[DIST_COL].as_primitive::(); + debug_assert_eq!(dists.null_count(), 0); + + // max-reduce for the same row id + let min_sim = dists + .values() + .last() + .map(|dist| 1.0 - *dist) + .unwrap_or_default(); + let mut new_row_ids = Vec::with_capacity(row_ids.len()); + let mut new_sims = Vec::with_capacity(row_ids.len()); + let mut visited_row_ids = HashSet::with_capacity(row_ids.len()); + + for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { + // the results are sorted by distance, so we can skip if we have seen this row id before + if visited_row_ids.contains(row_id) { + continue; + } + visited_row_ids.insert(row_id); + new_row_ids.push(*row_id); + // it's cosine distance, so we need to convert it to similarity + new_sims.push(1.0 - *dist); + } + let new_row_ids = UInt64Array::from(new_row_ids); + let new_dists = Float32Array::from(new_sims); + + let batch = RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![Arc::new(new_dists), Arc::new(new_row_ids)], + )?; + + Ok::<_, DataFusionError>((min_sim, batch)) + }) + })); + + let k = self.query.k; + let refactor = self.query.refine_factor.unwrap_or(1) as usize; + let num_queries = self.inputs.len() as f32; + let stream = stream::once(async move { + // at most, we will have k * refine_factor results for each query + let mut results = HashMap::with_capacity(k * refactor); + let mut missed_sim_sum = 0.0; + while let Some((min_sim, batch)) = reduced_inputs.try_next().await? { + let row_ids = batch[ROW_ID].as_primitive::(); + let sims = batch[DIST_COL].as_primitive::(); + + let query_results = row_ids + .values() + .iter() + .copied() + .zip(sims.values().iter().copied()) + .collect::>(); + + // for a row `r`: + // if `r` is in only `results``, then `results[r] += min_sim` + // if `r` is in only `query_results`, then `results[r] = query_results[r] + missed_similarities`, + // here `missed_similarities` is the sum of `min_sim` from previous iterations + // if `r` is in both, then `results[r] += query_results[r]` + results.iter_mut().for_each(|(row_id, sim)| { + if let Some(new_dist) = query_results.get(row_id) { + *sim += new_dist; + } else { + *sim += min_sim; + } + }); + query_results.into_iter().for_each(|(row_id, sim)| { + results.entry(row_id).or_insert(sim + missed_sim_sum); + }); + missed_sim_sum += min_sim; + } + + let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip(); + let dists = sims + .into_iter() + // it's similarity, so we need to convert it back to distance + .map(|sim| num_queries - sim) + .collect::>(); + let row_ids = UInt64Array::from(row_ids); + let dists = Float32Array::from(dists); + let batch = RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![Arc::new(dists), Arc::new(row_ids)], + )?; + Ok::<_, DataFusionError>(batch) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + ))) + } + + fn statistics(&self) -> DataFusionResult { + Ok(Statistics { + num_rows: Precision::Inexact( + self.query.k * self.query.refine_factor.unwrap_or(1) as usize, + ), + ..Statistics::new_unknown(self.schema().as_ref()) + }) + } + fn properties(&self) -> &PlanProperties { &self.properties } @@ -737,7 +993,7 @@ mod tests { let dataset = Dataset::open(test_uri).await.unwrap(); let stream = dataset .scan() - .nearest("vector", q.as_primitive(), 10) + .nearest("vector", q.as_primitive::(), 10) .unwrap() .try_into_stream() .await @@ -806,4 +1062,75 @@ mod tests { ]) ); } + + #[tokio::test] + async fn test_multivector_score() { + let query = Query { + column: "vector".to_string(), + key: Arc::new(generate_random_array(1)), + k: 10, + lower_bound: None, + upper_bound: None, + nprobes: 1, + ef: None, + refine_factor: None, + metric_type: DistanceType::Cosine, + use_index: true, + }; + + async fn multivector_scoring( + inputs: Vec>, + query: Query, + ) -> Result> { + let ctx = Arc::new(datafusion::execution::context::TaskContext::default()); + let plan = MultivectorScoringExec::try_new(inputs, query.clone())?; + let batches = plan + .execute(0, ctx.clone()) + .unwrap() + .try_collect::>() + .await?; + let mut results = HashMap::new(); + for batch in batches { + let row_ids = batch[ROW_ID].as_primitive::(); + let dists = batch[DIST_COL].as_primitive::(); + for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { + results.insert(*row_id, *dist); + } + } + Ok(results) + } + + let batches = (0..3) + .map(|i| { + RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![ + Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])), + Arc::new(UInt64Array::from(vec![i + 1, i + 2])), + ], + ) + .unwrap() + }) + .collect::>(); + + let mut res: Option> = None; + for perm in batches.into_iter().permutations(3) { + let inputs = perm + .into_iter() + .map(|batch| { + let input: Arc = Arc::new(TestingExec::new(vec![batch])); + input + }) + .collect::>(); + let new_res = multivector_scoring(inputs, query.clone()).await.unwrap(); + assert_eq!(new_res.len(), 4); + if let Some(res) = &res { + for (row_id, dist) in new_res.iter() { + assert_eq!(res.get(row_id).unwrap(), dist) + } + } else { + res = Some(new_res); + } + } + } } diff --git a/rust/lance/src/io/exec/optimizer.rs b/rust/lance/src/io/exec/optimizer.rs index b05e5f5feb9..79dddcb1bfb 100644 --- a/rust/lance/src/io/exec/optimizer.rs +++ b/rust/lance/src/io/exec/optimizer.rs @@ -6,18 +6,69 @@ use std::sync::Arc; use super::TakeExec; +use arrow_schema::Schema as ArrowSchema; use datafusion::{ common::tree_node::{Transformed, TreeNode}, config::ConfigOptions, error::Result as DFResult, physical_optimizer::{optimizer::PhysicalOptimizer, PhysicalOptimizerRule}, - physical_plan::{projection::ProjectionExec as DFProjectionExec, ExecutionPlan}, + physical_plan::{ + coalesce_batches::CoalesceBatchesExec, projection::ProjectionExec, ExecutionPlan, + }, }; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; /// Rule that eliminates [TakeExec] nodes that are immediately followed by another [TakeExec]. +#[derive(Debug)] pub struct CoalesceTake; +impl CoalesceTake { + fn field_order_differs(old_schema: &ArrowSchema, new_schema: &ArrowSchema) -> bool { + old_schema + .fields + .iter() + .zip(&new_schema.fields) + .any(|(old, new)| old.name() != new.name()) + } + + fn remap_collapsed_output( + old_schema: &ArrowSchema, + new_schema: &ArrowSchema, + plan: Arc, + ) -> Arc { + let mut project_exprs = Vec::with_capacity(old_schema.fields.len()); + for field in &old_schema.fields { + project_exprs.push(( + Arc::new(Column::new_with_schema(field.name(), new_schema).unwrap()) + as Arc, + field.name().clone(), + )); + } + Arc::new(ProjectionExec::try_new(project_exprs, plan).unwrap()) + } + + fn collapse_takes( + inner_take: &TakeExec, + outer_take: &TakeExec, + outer_exec: Arc, + ) -> Arc { + let inner_take_input = inner_take.children()[0].clone(); + let old_output_schema = outer_take.schema(); + let collapsed = outer_exec + .with_new_children(vec![inner_take_input]) + .unwrap(); + let new_output_schema = collapsed.schema(); + + // It's possible that collapsing the take can change the field order. This disturbs DF's planner and + // so we must restore it. + if Self::field_order_differs(&old_output_schema, &new_output_schema) { + Self::remap_collapsed_output(&old_output_schema, &new_output_schema, collapsed) + } else { + collapsed + } + } +} + impl PhysicalOptimizerRule for CoalesceTake { fn optimize( &self, @@ -26,11 +77,27 @@ impl PhysicalOptimizerRule for CoalesceTake { ) -> DFResult> { Ok(plan .transform_down(|plan| { - if let Some(take) = plan.as_any().downcast_ref::() { - let child = take.children()[0]; - if let Some(exec_child) = child.as_any().downcast_ref::() { + if let Some(outer_take) = plan.as_any().downcast_ref::() { + let child = outer_take.children()[0]; + // Case 1: TakeExec -> TakeExec + if let Some(inner_take) = child.as_any().downcast_ref::() { + return Ok(Transformed::yes(Self::collapse_takes( + inner_take, + outer_take, + plan.clone(), + ))); + // Case 2: TakeExec -> CoalesceBatchesExec -> TakeExec + } else if let Some(exec_child) = + child.as_any().downcast_ref::() + { let inner_child = exec_child.children()[0].clone(); - return Ok(Transformed::yes(plan.with_new_children(vec![inner_child])?)); + if let Some(inner_take) = inner_child.as_any().downcast_ref::() { + return Ok(Transformed::yes(Self::collapse_takes( + inner_take, + outer_take, + plan.clone(), + ))); + } } } Ok(Transformed::no(plan)) @@ -49,6 +116,7 @@ impl PhysicalOptimizerRule for CoalesceTake { /// Rule that eliminates [ProjectionExec] nodes that projects all columns /// from its input with no additional expressions. +#[derive(Debug)] pub struct SimplifyProjection; impl PhysicalOptimizerRule for SimplifyProjection { @@ -59,7 +127,7 @@ impl PhysicalOptimizerRule for SimplifyProjection { ) -> DFResult> { Ok(plan .transform_down(|plan| { - if let Some(proj) = plan.as_any().downcast_ref::() { + if let Some(proj) = plan.as_any().downcast_ref::() { let children = proj.children(); if children.len() != 1 { return Ok(Transformed::no(plan)); diff --git a/rust/lance/src/io/exec/projection.rs b/rust/lance/src/io/exec/projection.rs index 09b0d3fbcfe..1fb405024f6 100644 --- a/rust/lance/src/io/exec/projection.rs +++ b/rust/lance/src/io/exec/projection.rs @@ -203,9 +203,10 @@ pub fn compute_projection<'a>( #[cfg(test)] mod tests { use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray}; - use datafusion::{physical_plan::memory::MemoryExec, prelude::SessionContext}; + use datafusion::prelude::SessionContext; use futures::TryStreamExt; use lance_core::datatypes::Schema; + use lance_datafusion::exec::OneShotExec; use super::*; @@ -278,8 +279,7 @@ mod tests { } async fn apply_to_batch(batch: RecordBatch, projection: &ArrowSchema) -> Result { - let schema = batch.schema(); - let memory_exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); + let memory_exec = OneShotExec::from_batch(batch); let exec = project(Arc::new(memory_exec), projection)?; let claimed_schema = exec.schema(); let session = SessionContext::new(); diff --git a/rust/lance/src/io/exec/pushdown_scan.rs b/rust/lance/src/io/exec/pushdown_scan.rs index 92bdf481326..434939141a2 100644 --- a/rust/lance/src/io/exec/pushdown_scan.rs +++ b/rust/lance/src/io/exec/pushdown_scan.rs @@ -15,12 +15,13 @@ use datafusion::logical_expr::col; use datafusion::logical_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; -use datafusion::physical_plan::{ColumnarValue, ExecutionMode, PlanProperties}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::{ColumnarValue, PlanProperties}; use datafusion::scalar::ScalarValue; use datafusion::{ physical_plan::{ - stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, prelude::Expr, }; @@ -32,7 +33,7 @@ use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD}; use lance_io::ReadBatchParams; use lance_table::format::Fragment; -use snafu::{location, Location}; +use snafu::location; use crate::dataset::fragment::FragReadConfig; use crate::dataset::scanner::LEGACY_DEFAULT_FRAGMENT_READAHEAD; @@ -46,6 +47,7 @@ use crate::{ Dataset, }; +use super::utils::InstrumentedRecordBatchStreamAdapter; use super::Planner; #[derive(Debug, Clone)] @@ -93,6 +95,7 @@ pub struct LancePushdownScanExec { config: ScanConfig, output_schema: Arc, properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl LancePushdownScanExec { @@ -131,7 +134,8 @@ impl LancePushdownScanExec { let properties = PlanProperties::new( EquivalenceProperties::new(output_schema.clone()), Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Ok(Self { @@ -143,6 +147,7 @@ impl LancePushdownScanExec { config, output_schema, properties, + metrics: ExecutionPlanMetricsSet::new(), }) } } @@ -166,18 +171,28 @@ impl ExecutionPlan for LancePushdownScanExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> datafusion::error::Result> { - todo!() + if !children.is_empty() { + Err(DataFusionError::Internal( + "LancePushdownScanExec does not accept children".to_string(), + )) + } else { + Ok(self) + } } fn statistics(&self) -> datafusion::error::Result { Ok(Statistics::new_unknown(self.output_schema.as_ref())) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> Result { // To get a stream with a static lifetime, we clone self put it into @@ -213,9 +228,11 @@ impl ExecutionPlan for LancePushdownScanExec { .buffered(self.config.fragment_readahead) .try_flatten(); - Ok(Box::pin(RecordBatchStreamAdapter::new( + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( self.schema(), batch_stream, + partition, + &self.metrics, ))) } @@ -275,7 +292,7 @@ impl FragmentScanner { // We will call the reader with projections. In order for this to work // we must ensure that we open the fragment with the maximal schema. let mut reader = fragment - .open(dataset.schema(), FragReadConfig::default(), None) + .open(dataset.schema(), FragReadConfig::default()) .await?; if config.make_deletions_null { reader.with_make_deletions_null(); diff --git a/rust/lance/src/io/exec/rowids.rs b/rust/lance/src/io/exec/rowids.rs index 90d36532c74..31844acabee 100644 --- a/rust/lance/src/io/exec/rowids.rs +++ b/rust/lance/src/io/exec/rowids.rs @@ -9,7 +9,7 @@ use datafusion::common::stats::Precision; use datafusion::common::ColumnStatistics; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::SendableRecordBatchStream; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use datafusion_physical_expr::EquivalenceProperties; use futures::StreamExt; @@ -20,6 +20,8 @@ use crate::dataset::rowids::get_row_id_index; use crate::utils::future::SharedPrerequisite; use crate::Dataset; +use super::utils::InstrumentedRecordBatchStreamAdapter; + /// Add a `_rowaddr` column to a stream of record batches that have a `_rowid`. /// /// It's generally more efficient to scan the `_rowaddr` column, but this can be @@ -36,6 +38,8 @@ pub struct AddRowAddrExec { rowaddr_pos: usize, output_schema: SchemaRef, properties: PlanProperties, + + metrics: ExecutionPlanMetricsSet, } impl std::fmt::Debug for AddRowAddrExec { @@ -105,6 +109,7 @@ impl AddRowAddrExec { rowaddr_pos, output_schema, properties, + metrics: ExecutionPlanMetricsSet::new(), }) } @@ -182,11 +187,26 @@ impl ExecutionPlan for AddRowAddrExec { vec![&self.input] } + fn benefits_from_input_partitioning(&self) -> Vec { + // We aren't doing much work here, best to avoid the thread overhead + vec![false] + } + fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> Result> { - todo!() + if children.len() != 1 { + Err(DataFusionError::Internal( + "AddRowAddrExec: invalid number of children".into(), + )) + } else { + Ok(Arc::new(Self::try_new( + children.into_iter().next().unwrap(), + self.dataset.clone(), + self.rowaddr_pos, + )?)) + } } fn execute( @@ -229,7 +249,12 @@ impl ExecutionPlan for AddRowAddrExec { } }); - let stream = RecordBatchStreamAdapter::new(self.output_schema.clone(), stream.boxed()); + let stream = InstrumentedRecordBatchStreamAdapter::new( + self.output_schema.clone(), + stream.boxed(), + partition, + &self.metrics, + ); Ok(Box::pin(stream)) } @@ -240,8 +265,9 @@ impl ExecutionPlan for AddRowAddrExec { DataFusionError::Internal("RowAddrExec: rowid column stats not found".into()) })?; let row_addr_col_stats = ColumnStatistics { - null_count: row_id_col_stats.null_count.clone(), - distinct_count: row_id_col_stats.distinct_count.clone(), + null_count: row_id_col_stats.null_count, + distinct_count: row_id_col_stats.distinct_count, + sum_value: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, }; @@ -251,7 +277,6 @@ impl ExecutionPlan for AddRowAddrExec { // is a minimum size of 64 bytes. let mut added_byte_size = stats .num_rows - .clone() .map(|n| (n * 8).max(64)) .add(&Precision::Exact(base_size)); if row_id_col_stats @@ -261,8 +286,7 @@ impl ExecutionPlan for AddRowAddrExec { .unwrap_or_default() { // Account for null buffer. - added_byte_size = - added_byte_size.add(&stats.num_rows.clone().map(|n| n.div_ceil(8).max(64))); + added_byte_size = added_byte_size.add(&stats.num_rows.map(|n| n.div_ceil(8).max(64))); } stats.total_byte_size = stats.total_byte_size.add(&added_byte_size); stats @@ -272,6 +296,10 @@ impl ExecutionPlan for AddRowAddrExec { Ok(stats) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn properties(&self) -> &PlanProperties { &self.properties } @@ -281,17 +309,17 @@ impl ExecutionPlan for AddRowAddrExec { mod test { use arrow_array::{Int32Array, RecordBatchIterator}; use arrow_schema::{DataType, Field}; - use datafusion::{physical_plan::memory::MemoryExec, prelude::SessionContext}; + use datafusion::{datasource::memory::MemorySourceConfig, prelude::SessionContext}; use futures::TryStreamExt; use lance_core::{ROW_ADDR, ROW_ID_FIELD}; + use lance_datafusion::exec::OneShotExec; use crate::dataset::WriteParams; use super::*; async fn apply_to_batch(batch: RecordBatch, dataset: Arc) -> Result { - let schema = batch.schema(); - let memory_exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); + let memory_exec = OneShotExec::from_batch(batch); let exec = AddRowAddrExec::try_new(Arc::new(memory_exec), dataset, 0)?; let session = SessionContext::new(); let task_ctx = session.task_ctx(); @@ -412,12 +440,9 @@ mod test { let schema = Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])); let batch = RecordBatch::try_new(schema.clone(), vec![rowids.clone()]).unwrap(); - let exec = AddRowAddrExec::try_new( - Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], schema.clone(), None).unwrap()), - dataset.clone(), - 0, - ) - .unwrap(); + let memory_exec = + MemorySourceConfig::try_new_exec(&[vec![batch.clone()]], schema, None).unwrap(); + let exec = AddRowAddrExec::try_new(memory_exec, dataset.clone(), 0).unwrap(); let stats = exec.statistics().unwrap(); let result = apply_to_batch(batch, dataset).await.unwrap(); diff --git a/rust/lance/src/io/exec/scalar_index.rs b/rust/lance/src/io/exec/scalar_index.rs index 0f39ed61241..2a903926365 100644 --- a/rust/lance/src/io/exec/scalar_index.rs +++ b/rust/lance/src/io/exec/scalar_index.rs @@ -9,8 +9,10 @@ use async_trait::async_trait; use datafusion::{ common::{stats::Precision, Statistics}, physical_plan::{ - stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionMode, - ExecutionPlan, Partitioning, PlanProperties, + execution_plan::{Boundedness, EmissionType}, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, }, scalar::ScalarValue, }; @@ -25,15 +27,16 @@ use lance_core::{ }; use lance_datafusion::chunker::break_stream; use lance_index::{ + metrics::MetricsCollector, scalar::{ - expression::{ScalarIndexExpr, ScalarIndexLoader}, + expression::{IndexExprResult, ScalarIndexExpr, ScalarIndexLoader}, SargableQuery, ScalarIndex, }, DatasetIndexExt, }; use lance_table::format::Fragment; use roaring::RoaringBitmap; -use snafu::{location, Location}; +use snafu::location; use tracing::{debug_span, instrument}; use crate::{ @@ -42,13 +45,19 @@ use crate::{ Dataset, }; +use super::utils::{IndexMetrics, InstrumentedRecordBatchStreamAdapter}; + lazy_static::lazy_static! { pub static ref SCALAR_INDEX_SCHEMA: SchemaRef = Arc::new(Schema::new(vec![Field::new("result".to_string(), DataType::Binary, true)])); } #[async_trait] impl ScalarIndexLoader for Dataset { - async fn load_index(&self, name: &str) -> Result> { + async fn load_index( + &self, + name: &str, + metrics: &dyn MetricsCollector, + ) -> Result> { let idx = self .load_scalar_index_for_column(name) .await? @@ -56,7 +65,8 @@ impl ScalarIndexLoader for Dataset { message: format!("Scanner created plan for index query on {} but no index on dataset for that column", name), location: location!() })?; - self.open_scalar_index(name, &idx.uuid.to_string()).await + self.open_scalar_index(name, &idx.uuid.to_string(), metrics) + .await } } @@ -72,6 +82,7 @@ pub struct ScalarIndexExec { dataset: Arc, expr: ScalarIndexExpr, properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for ScalarIndexExec { @@ -89,21 +100,30 @@ impl ScalarIndexExec { let properties = PlanProperties::new( EquivalenceProperties::new(SCALAR_INDEX_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { dataset, expr, properties, + metrics: ExecutionPlanMetricsSet::new(), } } - async fn do_execute(expr: ScalarIndexExpr, dataset: Arc) -> Result { - let query_result = expr.evaluate(dataset.as_ref()).await?; - let query_result_arr = query_result.into_arrow()?; + async fn do_execute( + expr: ScalarIndexExpr, + dataset: Arc, + metrics: IndexMetrics, + ) -> Result { + let query_result = expr.evaluate(dataset.as_ref(), &metrics).await?; + let IndexExprResult::Exact(row_id_mask) = query_result else { + todo!("Support for non-exact query results as pre-filter for vector search") + }; + let row_id_mask_arr = row_id_mask.into_arrow()?; Ok(RecordBatch::try_new( SCALAR_INDEX_SCHEMA.clone(), - vec![Arc::new(query_result_arr)], + vec![Arc::new(row_id_mask_arr)], )?) } } @@ -127,24 +147,33 @@ impl ExecutionPlan for ScalarIndexExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> datafusion::error::Result> { - todo!() + if !children.is_empty() { + Err(datafusion::error::DataFusionError::Internal( + "ScalarIndexExec does not have children".to_string(), + )) + } else { + Ok(self) + } } fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> datafusion::error::Result { - let batch_fut = Self::do_execute(self.expr.clone(), self.dataset.clone()); + let metrics = IndexMetrics::new(&self.metrics, partition); + let batch_fut = Self::do_execute(self.expr.clone(), self.dataset.clone(), metrics); let stream = futures::stream::iter(vec![batch_fut]) .then(|batch_fut| batch_fut.map_err(|err| err.into())) .boxed() as BoxStream<'static, datafusion::common::Result>; - Ok(Box::pin(RecordBatchStreamAdapter::new( + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( SCALAR_INDEX_SCHEMA.clone(), stream, + partition, + &self.metrics, ))) } @@ -155,6 +184,10 @@ impl ExecutionPlan for ScalarIndexExec { }) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn properties(&self) -> &PlanProperties { &self.properties } @@ -173,6 +206,7 @@ pub struct MapIndexExec { column_name: String, input: Arc, properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for MapIndexExec { @@ -190,13 +224,15 @@ impl MapIndexExec { let properties = PlanProperties::new( EquivalenceProperties::new(INDEX_LOOKUP_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { dataset, column_name, input, properties, + metrics: ExecutionPlanMetricsSet::new(), } } @@ -205,6 +241,7 @@ impl MapIndexExec { dataset: Arc, deletion_mask: Option>, batch: RecordBatch, + metrics: Arc, ) -> datafusion::error::Result { let index_vals = batch.column(0); let index_vals = (0..index_vals.len()) @@ -214,15 +251,18 @@ impl MapIndexExec { column_name.clone(), Arc::new(SargableQuery::IsIn(index_vals)), ); - let mut row_addresses = query.evaluate(dataset.as_ref()).await?; + let query_result = query.evaluate(dataset.as_ref(), metrics.as_ref()).await?; + let IndexExprResult::Exact(mut row_id_mask) = query_result else { + todo!("Support for non-exact query results as input for merge_insert") + }; if let Some(deletion_mask) = deletion_mask.as_ref() { - row_addresses = row_addresses & deletion_mask.as_ref().clone(); + row_id_mask = row_id_mask & deletion_mask.as_ref().clone(); } - if let Some(mut allow_list) = row_addresses.allow_list { + if let Some(mut allow_list) = row_id_mask.allow_list { // Flatten the allow list - if let Some(block_list) = row_addresses.block_list { + if let Some(block_list) = row_id_mask.block_list { allow_list -= &block_list; } @@ -249,6 +289,7 @@ impl MapIndexExec { input: datafusion::physical_plan::SendableRecordBatchStream, dataset: Arc, column_name: String, + metrics: Arc, ) -> datafusion::error::Result< impl Stream> + Send + 'static, > { @@ -267,7 +308,8 @@ impl MapIndexExec { let column_name = column_name.clone(); let dataset = dataset.clone(); let deletion_mask = deletion_mask.clone(); - Self::map_batch(column_name, dataset, deletion_mask, res) + let metrics = metrics.clone(); + Self::map_batch(column_name, dataset, deletion_mask, res, metrics) })) } } @@ -291,9 +333,19 @@ impl ExecutionPlan for MapIndexExec { fn with_new_children( self: Arc, - _: Vec>, + children: Vec>, ) -> datafusion::error::Result> { - unimplemented!() + if children.len() != 1 { + Err(datafusion::error::DataFusionError::Internal( + "MapIndexExec requires exactly one child".to_string(), + )) + } else { + Ok(Arc::new(Self::new( + self.dataset.clone(), + self.column_name.clone(), + children.into_iter().next().unwrap(), + ))) + } } fn execute( @@ -302,15 +354,22 @@ impl ExecutionPlan for MapIndexExec { context: Arc, ) -> datafusion::error::Result { let index_vals = self.input.execute(partition, context)?; - let stream_fut = - Self::do_execute(index_vals, self.dataset.clone(), self.column_name.clone()); + let metrics = Arc::new(IndexMetrics::new(&self.metrics, partition)); + let stream_fut = Self::do_execute( + index_vals, + self.dataset.clone(), + self.column_name.clone(), + metrics, + ); let stream = futures::stream::iter(vec![stream_fut]) .then(|stream_fut| stream_fut) .try_flatten() .boxed(); - Ok(Box::pin(RecordBatchStreamAdapter::new( + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( INDEX_LOOKUP_SCHEMA.clone(), stream, + partition, + &self.metrics, ))) } @@ -334,6 +393,7 @@ pub struct MaterializeIndexExec { expr: ScalarIndexExpr, fragments: Arc>, properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for MaterializeIndexExec { @@ -362,7 +422,7 @@ impl<'a> FragIdIter<'a> { } } -impl<'a> Iterator for FragIdIter<'a> { +impl Iterator for FragIdIter<'_> { type Item = u64; fn next(&mut self) -> Option { @@ -394,13 +454,15 @@ impl MaterializeIndexExec { let properties = PlanProperties::new( EquivalenceProperties::new(MATERIALIZE_INDEX_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { dataset, expr, fragments, properties, + metrics: ExecutionPlanMetricsSet::new(), } } @@ -409,8 +471,9 @@ impl MaterializeIndexExec { expr: ScalarIndexExpr, dataset: Arc, fragments: Arc>, + metrics: Arc, ) -> Result { - let mask = expr.evaluate(dataset.as_ref()); + let expr_result = expr.evaluate(dataset.as_ref(), metrics.as_ref()); let span = debug_span!("create_prefilter"); let prefilter = span.in_scope(|| { let fragment_bitmap = @@ -421,10 +484,20 @@ impl MaterializeIndexExec { DatasetPreFilter::create_deletion_mask(dataset.clone(), fragment_bitmap) }); let mask = if let Some(prefilter) = prefilter { - let (mask, prefilter) = futures::try_join!(mask, prefilter)?; + let (expr_result, prefilter) = futures::try_join!(expr_result, prefilter)?; + let mask = match expr_result { + IndexExprResult::Exact(mask) => mask, + IndexExprResult::AtMost(mask) => mask, + IndexExprResult::AtLeast(_) => todo!("Support AtLeast in MaterializeIndexExec"), + }; mask & (*prefilter).clone() } else { - mask.await? + let expr_result = expr_result.await?; + match expr_result { + IndexExprResult::Exact(mask) => mask, + IndexExprResult::AtMost(mask) => mask, + IndexExprResult::AtLeast(_) => todo!("Support AtLeast in MaterializeIndexExec"), + } }; let ids = row_ids_for_mask(mask, &dataset, &fragments).await?; let ids = UInt64Array::from(ids); @@ -563,20 +636,28 @@ impl ExecutionPlan for MaterializeIndexExec { fn with_new_children( self: Arc, - _children: Vec>, + children: Vec>, ) -> datafusion::error::Result> { - todo!() + if !children.is_empty() { + Err(datafusion::error::DataFusionError::Internal( + "MaterializeIndexExec does not have children".to_string(), + )) + } else { + Ok(self) + } } fn execute( &self, - _partition: usize, - _context: Arc, + partition: usize, + context: Arc, ) -> datafusion::error::Result { + let metrics = Arc::new(IndexMetrics::new(&self.metrics, partition)); let batch_fut = Self::do_execute( self.expr.clone(), self.dataset.clone(), self.fragments.clone(), + metrics, ); let stream = futures::stream::iter(vec![batch_fut]) .then(|batch_fut| batch_fut.map_err(|err| err.into())) @@ -586,10 +667,12 @@ impl ExecutionPlan for MaterializeIndexExec { MATERIALIZE_INDEX_SCHEMA.clone(), stream, )); - let stream = break_stream(stream, 64 * 1024); - Ok(Box::pin(RecordBatchStreamAdapter::new( + let stream = break_stream(stream, context.session_config().batch_size()); + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( MATERIALIZE_INDEX_SCHEMA.clone(), stream.map_err(|err| err.into()), + partition, + &self.metrics, ))) } @@ -597,7 +680,137 @@ impl ExecutionPlan for MaterializeIndexExec { Ok(Statistics::new_unknown(&MATERIALIZE_INDEX_SCHEMA)) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn properties(&self) -> &PlanProperties { &self.properties } } + +#[cfg(test)] +mod tests { + use std::{ops::Bound, sync::Arc}; + + use arrow::datatypes::UInt64Type; + use datafusion::{ + execution::TaskContext, physical_plan::ExecutionPlan, prelude::SessionConfig, + scalar::ScalarValue, + }; + use futures::TryStreamExt; + use lance_datagen::gen; + use lance_index::{ + scalar::{expression::ScalarIndexExpr, SargableQuery, ScalarIndexParams}, + DatasetIndexExt, IndexType, + }; + use tempfile::{tempdir, TempDir}; + + use crate::{ + io::exec::scalar_index::MaterializeIndexExec, + utils::test::{DatagenExt, FragmentCount, FragmentRowCount, NoContextTestFixture}, + Dataset, + }; + + use super::{MapIndexExec, ScalarIndexExec}; + + struct TestFixture { + dataset: Arc, + _tmp_dir_guard: TempDir, + } + + async fn test_fixture() -> TestFixture { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let mut dataset = gen() + .col("ordered", lance_datagen::array::step::()) + .into_dataset( + test_uri, + FragmentCount::from(10), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + TestFixture { + dataset: Arc::new(dataset), + _tmp_dir_guard: test_dir, + } + } + + #[tokio::test] + async fn test_materialize_index_exec() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + let query = ScalarIndexExpr::Query( + "ordered".to_string(), + Arc::new(SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::UInt64(Some(47))), + )), + ); + + let fragments = dataset.fragments().clone(); + + let plan = MaterializeIndexExec::new(dataset, query, fragments); + + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + + let batches = stream.try_collect::>().await.unwrap(); + + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 47); + + let context = + TaskContext::default().with_session_config(SessionConfig::default().with_batch_size(5)); + let stream = plan.execute(0, Arc::new(context)).unwrap(); + let batches = stream.try_collect::>().await.unwrap(); + + assert_eq!(batches.len(), 10); + assert_eq!(batches[0].num_rows(), 5); + } + + #[test] + fn no_context_scalar_index() { + // These tests ensure we can create nodes and call execute without a tokio Runtime + // being active. This is a requirement for proper implementation of a Datafusion foreign + // table provider. + let fixture = NoContextTestFixture::new(); + let arc_dasaset = Arc::new(fixture.dataset); + + let query = ScalarIndexExpr::Query( + "ordered".to_string(), + Arc::new(SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::UInt64(Some(47))), + )), + ); + + // These plans aren't even valid but it appears we defer all work (even validation) until + // read time. + let plan = ScalarIndexExec::new(arc_dasaset.clone(), query.clone()); + plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + + let plan = MapIndexExec::new(arc_dasaset.clone(), "ordered".to_string(), Arc::new(plan)); + plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + + let plan = + MaterializeIndexExec::new(arc_dasaset.clone(), query, arc_dasaset.fragments().clone()); + plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + } +} diff --git a/rust/lance/src/io/exec/scan.rs b/rust/lance/src/io/exec/scan.rs index 5ec680c647a..dd922007f04 100644 --- a/rust/lance/src/io/exec/scan.rs +++ b/rust/lance/src/io/exec/scan.rs @@ -11,6 +11,9 @@ use arrow_array::RecordBatch; use arrow_schema::{Schema as ArrowSchema, SchemaRef}; use datafusion::common::stats::Precision; use datafusion::error::{DataFusionError, Result}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -27,7 +30,7 @@ use lance_core::{Error, ROW_ADDR_FIELD, ROW_ID_FIELD}; use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; use lance_table::format::Fragment; use log::debug; -use snafu::{location, Location}; +use snafu::location; use crate::dataset::fragment::{FileFragment, FragReadConfig, FragmentReader}; use crate::dataset::scanner::{ @@ -37,16 +40,22 @@ use crate::dataset::scanner::{ use crate::dataset::Dataset; use crate::datatypes::Schema; +use super::utils::IoMetrics; + async fn open_file( file_fragment: FileFragment, projection: Arc, - read_config: FragReadConfig, + mut read_config: FragReadConfig, with_make_deletions_null: bool, - scan_scheduler: Option<(Arc, u64)>, + scan_scheduler: Option<(Arc, u32)>, ) -> Result { - let mut reader = file_fragment - .open(projection.as_ref(), read_config, scan_scheduler) - .await?; + if let Some((scan_scheduler, reader_priority)) = scan_scheduler { + read_config = read_config + .with_scan_scheduler(scan_scheduler) + .with_reader_priority(reader_priority); + } + + let mut reader = file_fragment.open(projection.as_ref(), read_config).await?; if with_make_deletions_null { reader.with_make_deletions_null(); @@ -59,6 +68,20 @@ struct FragmentWithRange { range: Option>, } +struct ScanMetrics { + baseline_metrics: BaselineMetrics, + io_metrics: IoMetrics, +} + +impl ScanMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline_metrics: BaselineMetrics::new(metrics, partition), + io_metrics: IoMetrics::new(metrics, partition), + } + } +} + /// Dataset Scan Node. pub struct LanceStream { inner_stream: stream::BoxStream<'static, Result>, @@ -67,6 +90,13 @@ pub struct LanceStream { projection: Arc, config: LanceScanConfig, + + scan_metrics: ScanMetrics, + + /// Scan scheduler for the scan node. + /// + /// Only set on v2 scans. Used to record scan metrics. + scan_scheduler: Option>, } impl LanceStream { @@ -94,6 +124,8 @@ impl LanceStream { offsets: Option>, projection: Arc, config: LanceScanConfig, + metrics: &ExecutionPlanMetricsSet, + partition: usize, ) -> Result { let is_v2_scan = fragments .iter() @@ -101,9 +133,11 @@ impl LanceStream { .next() .unwrap_or(false); if is_v2_scan { - Self::try_new_v2(dataset, fragments, offsets, projection, config) + Self::try_new_v2( + dataset, fragments, offsets, projection, config, metrics, partition, + ) } else { - Self::try_new_v1(dataset, fragments, projection, config) + Self::try_new_v1(dataset, fragments, projection, config, metrics, partition) } } @@ -114,7 +148,11 @@ impl LanceStream { offsets: Option>, projection: Arc, config: LanceScanConfig, + metrics: &ExecutionPlanMetricsSet, + partition: usize, ) -> Result { + let scan_metrics = ScanMetrics::new(metrics, partition); + let timer = scan_metrics.baseline_metrics.elapsed_compute().timer(); let project_schema = projection.clone(); let io_parallelism = dataset.object_store.io_parallelism(); // First, use the value specified by the user in the call @@ -159,7 +197,7 @@ impl LanceStream { if let Some(next_frag) = frags_iter.next() { let num_rows_in_frag = next_frag .fragment - .count_rows() + .count_rows(None) // count_rows should be a fast operation in v2 files .now_or_never() .ok_or(Error::Internal { @@ -199,6 +237,8 @@ impl LanceStream { }, ); + let scan_scheduler_clone = scan_scheduler.clone(); + let batches = stream::iter(file_fragments.into_iter().enumerate()) .map(move |(priority, file_fragment)| { let project_schema = project_schema.clone(); @@ -214,7 +254,7 @@ impl LanceStream { .with_row_id(config.with_row_id) .with_row_address(config.with_row_address), config.with_make_deletions_null, - Some((scan_scheduler, priority as u64)), + Some((scan_scheduler, priority as u32)), ) .await?; let batch_stream = if let Some(range) = file_fragment.range { @@ -254,10 +294,13 @@ impl LanceStream { .stream_in_current_span() .boxed(); + timer.done(); Ok(Self { inner_stream: batches, projection, config, + scan_metrics, + scan_scheduler: Some(scan_scheduler_clone), }) } @@ -267,7 +310,11 @@ impl LanceStream { fragments: Arc>, projection: Arc, config: LanceScanConfig, + metrics: &ExecutionPlanMetricsSet, + partition: usize, ) -> Result { + let scan_metrics = ScanMetrics::new(metrics, partition); + let timer = scan_metrics.baseline_metrics.elapsed_compute().timer(); let project_schema = projection.clone(); let fragment_readahead = config .fragment_readahead @@ -347,10 +394,13 @@ impl LanceStream { .map(|batch| batch.map_err(DataFusionError::from)) .boxed(); + timer.done(); Ok(Self { inner_stream, projection, config, + scan_metrics, + scan_scheduler: None, }) } } @@ -382,7 +432,16 @@ impl Stream for LanceStream { type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::into_inner(self).inner_stream.poll_next_unpin(cx) + let this = self.get_mut(); + let timer = this.scan_metrics.baseline_metrics.elapsed_compute().timer(); + let poll = Pin::new(&mut this.inner_stream).poll_next(cx); + timer.done(); + if matches!(poll, Poll::Ready(None)) { + if let Some(scan_scheduler) = this.scan_scheduler.as_ref() { + this.scan_metrics.io_metrics.record_final(scan_scheduler); + } + } + this.scan_metrics.baseline_metrics.record_poll(poll) } } @@ -425,6 +484,7 @@ pub struct LanceScanExec { output_schema: Arc, properties: PlanProperties, config: LanceScanConfig, + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for LanceScanExec { @@ -476,7 +536,8 @@ impl LanceScanExec { let properties = PlanProperties::new( EquivalenceProperties::new(output_schema.clone()), Partitioning::RoundRobinBatch(1), - datafusion::physical_plan::ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { dataset, @@ -486,8 +547,29 @@ impl LanceScanExec { output_schema, properties, config, + metrics: ExecutionPlanMetricsSet::new(), } } + + /// Get the dataset for this scan. + pub fn dataset(&self) -> &Arc { + &self.dataset + } + + /// Get the fragments for this scan. + pub fn fragments(&self) -> &Arc> { + &self.fragments + } + + /// Get the range for this scan. + pub fn range(&self) -> &Option> { + &self.range + } + + /// Get the projection for this scan. + pub fn projection(&self) -> &Arc { + &self.projection + } } impl ExecutionPlan for LanceScanExec { @@ -523,16 +605,30 @@ impl ExecutionPlan for LanceScanExec { fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin(LanceStream::try_new( - self.dataset.clone(), - self.fragments.clone(), - self.range.clone(), - self.projection.clone(), - self.config.clone(), - )?)) + let dataset = self.dataset.clone(); + let fragments = self.fragments.clone(); + let range = self.range.clone(); + let projection = self.projection.clone(); + let config = self.config.clone(); + let metrics = self.metrics.clone(); + + let lance_fut_stream = stream::once(async move { + LanceStream::try_new( + dataset, fragments, range, projection, config, &metrics, partition, + ) + }); + let lance_stream = lance_fut_stream.try_flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + lance_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) } fn statistics(&self) -> datafusion::error::Result { @@ -562,3 +658,30 @@ impl ExecutionPlan for LanceScanExec { &self.properties } } + +#[cfg(test)] +mod tests { + use datafusion::execution::TaskContext; + + use crate::utils::test::NoContextTestFixture; + + use super::*; + + #[test] + fn no_context_scan() { + // These tests ensure we can create nodes and call execute without a tokio Runtime + // being active. This is a requirement for proper implementation of a Datafusion foreign + // table provider. + let fixture = NoContextTestFixture::new(); + + let scan = LanceScanExec::new( + Arc::new(fixture.dataset.clone()), + fixture.dataset.fragments().clone(), + None, + Arc::new(fixture.dataset.schema().clone()), + LanceScanConfig::default(), + ); + + scan.execute(0, Arc::new(TaskContext::default())).unwrap(); + } +} diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index bf7b808844f..b7593b782eb 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -1,196 +1,305 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::collections::HashSet; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow_array::{cast::as_primitive_array, RecordBatch, UInt64Array}; +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, Mutex}; + +use arrow::array::AsArray; +use arrow::compute::{concat_batches, TakeOptions}; +use arrow::datatypes::UInt64Type; +use arrow_array::{Array, UInt32Array}; +use arrow_array::{RecordBatch, UInt64Array}; use arrow_schema::{Schema as ArrowSchema, SchemaRef}; use datafusion::common::Statistics; use datafusion::error::{DataFusionError, Result}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet, +}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }; use datafusion_physical_expr::EquivalenceProperties; -use futures::stream::{self, Stream, StreamExt, TryStreamExt}; -use futures::{Future, FutureExt}; -use tokio::sync::mpsc::{self, Receiver}; -use tokio::task::JoinHandle; -use tracing::{instrument, Instrument}; - -use crate::dataset::{Dataset, ProjectionRequest, ROW_ID}; +use futures::stream::{FuturesOrdered, Stream, StreamExt, TryStreamExt}; +use futures::FutureExt; +use lance_arrow::RecordBatchExt; +use lance_core::datatypes::{Field, OnMissing, Projection}; +use lance_core::error::{DataFusionResult, LanceOptionExt}; +use lance_core::utils::address::RowAddress; +use lance_core::utils::futures::FinallyStreamExt; +use lance_core::utils::tokio::get_num_compute_intensive_cpus; +use lance_core::{ROW_ADDR, ROW_ID}; +use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; + +use crate::dataset::fragment::{FragReadConfig, FragmentReader}; +use crate::dataset::rowids::get_row_id_index; +use crate::dataset::Dataset; use crate::datatypes::Schema; -use crate::{arrow::*, Error}; -/// Dataset Take Node. -/// -/// [Take] node takes the filtered batch from the child node. -/// -/// It uses the `_rowid` to random access on [Dataset] to gather the final results. -pub struct Take { - rx: Receiver>, - bg_thread: Option>, +use super::utils::IoMetrics; - output_schema: SchemaRef, +#[derive(Debug, Clone)] +struct TakeStreamMetrics { + baseline_metrics: BaselineMetrics, + batches_processed: Count, + io_metrics: IoMetrics, +} + +impl TakeStreamMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let batches_processed = Count::new(); + MetricBuilder::new(metrics) + .with_partition(partition) + .build(MetricValue::Count { + name: Cow::Borrowed("batches_processed"), + count: batches_processed.clone(), + }); + Self { + baseline_metrics: BaselineMetrics::new(metrics, partition), + batches_processed, + io_metrics: IoMetrics::new(metrics, partition), + } + } } -impl Take { - /// Create a Take node with +struct TakeStream { + /// The dataset to take from + dataset: Arc, + /// The fields to take from the input stream + fields_to_take: Arc, + /// The output schema, needed for us to merge the new columns + /// into the input data in the correct order + output_schema: SchemaRef, + /// A cache of opened file readers /// - /// - Dataset: the dataset to read from - /// - projection: extra columns to take from the dataset. - /// - output_schema: the output schema of the take node. - /// - child: the upstream stream to feed data in. - /// - batch_readahead: max number of batches to readahead, potentially concurrently - #[instrument(level = "debug", skip_all, name = "Take::new")] + /// This is a map from fragment id to a reader. + readers_cache: Arc>>>, + /// The scan scheduler to use for reading fragments + scan_scheduler: Arc, + /// The metrics for the stream + metrics: TakeStreamMetrics, +} + +impl TakeStream { fn new( dataset: Arc, - projection: Arc, + fields_to_take: Arc, output_schema: SchemaRef, - child: SendableRecordBatchStream, - batch_readahead: usize, + scan_scheduler: Arc, + metrics: &ExecutionPlanMetricsSet, + partition: usize, ) -> Self { - let (tx, rx) = mpsc::channel(4); - - let bg_thread = tokio::spawn( - async move { - if let Err(e) = child - .zip(stream::repeat_with(|| { - (dataset.clone(), projection.clone()) - })) - .map(|(batch, (dataset, extra))| async move { - Self::take_batch(batch?, dataset, extra).await - }) - .buffered(batch_readahead) - .map(|r| r.map_err(|e| DataFusionError::Execution(e.to_string()))) - .try_for_each(|b| async { - if tx.send(Ok(b)).await.is_err() { - // If channel is closed, make sure we return an error to end the stream. - return Err(DataFusionError::Internal( - "ExecNode(Take): channel closed".to_string(), - )); - } - Ok(()) - }) - .await - { - if let Err(e) = tx.send(Err(e)).await { - if let Err(e) = e.0 { - // if channel was closed, it was cancelled by the receiver. - // But if there was a different error we should send it - // or log it. - if !e.to_string().contains("channel closed") { - log::error!("channel was closed by receiver, but error occurred in background thread: {:?}", e); - } - } - } - } - drop(tx) - } - .in_current_span(), - ); - Self { - rx, - bg_thread: Some(bg_thread), + dataset, + fields_to_take, output_schema, + readers_cache: Arc::new(Mutex::new(HashMap::new())), + scan_scheduler, + metrics: TakeStreamMetrics::new(metrics, partition), } } - /// Given a batch with a _rowid column, retrieve extra columns from dataset. - // This method mostly exists to annotate the Send bound so the compiler - // doesn't produce a higher-order lifetime error. - // manually implemented async for Send bound - #[allow(clippy::manual_async_fn)] - #[instrument(level = "debug", skip_all)] - fn take_batch( - batch: RecordBatch, - dataset: Arc, - extra: Arc, - ) -> impl Future> + Send { - async move { - let row_id_arr = batch.column_by_name(ROW_ID).unwrap(); - let row_ids: &UInt64Array = as_primitive_array(row_id_arr); - let rows = if extra.fields.is_empty() { - batch - } else { - let new_columns = dataset - .take_rows(row_ids.values(), ProjectionRequest::Schema(extra)) - .await?; - debug_assert_eq!(batch.num_rows(), new_columns.num_rows()); - batch.merge(&new_columns)? - }; - Ok::(rows) + async fn do_open_reader(&self, fragment_id: u32) -> DataFusionResult> { + let fragment = self + .dataset + .get_fragment(fragment_id as usize) + .ok_or_else(|| { + DataFusionError::Execution(format!("The input to a take operation specified fragment id {} but this fragment does not exist in the dataset", fragment_id)) + })?; + + let reader = Arc::new( + fragment + .open( + &self.fields_to_take, + FragReadConfig::default().with_scan_scheduler(self.scan_scheduler.clone()), + ) + .await?, + ); + + let mut readers = self.readers_cache.lock().unwrap(); + readers.insert(fragment_id, reader.clone()); + Ok(reader) + } + + async fn open_reader(&self, fragment_id: u32) -> DataFusionResult> { + if let Some(reader) = self + .readers_cache + .lock() + .unwrap() + .get(&fragment_id) + .cloned() + { + return Ok(reader); } - .in_current_span() + + self.do_open_reader(fragment_id).await } -} -impl Stream for Take { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - // We need to check the JoinHandle to make sure the thread hasn't panicked. - let bg_thread_completed = if let Some(bg_thread) = &mut this.bg_thread { - match bg_thread.poll_unpin(cx) { - Poll::Ready(Ok(())) => true, - Poll::Ready(Err(join_error)) => { - return Poll::Ready(Some(Err(DataFusionError::Execution(format!( - "ExecNode(Take): thread panicked: {}", - join_error - ))))); - } - Poll::Pending => false, + async fn get_row_addrs(&self, batch: &RecordBatch) -> Result> { + if let Some(row_addr_array) = batch.column_by_name(ROW_ADDR) { + Ok(row_addr_array.clone()) + } else { + let row_id_array = batch.column_by_name(ROW_ID).expect_ok()?; + + if let Some(row_id_index) = get_row_id_index(&self.dataset).await? { + let row_id_array = row_id_array.as_primitive::(); + let addresses = row_id_array + .values() + .iter() + .filter_map(|id| row_id_index.get(*id).map(|address| address.into())) + .collect::>(); + Ok(Arc::new(UInt64Array::from(addresses))) + } else { + Ok(row_id_array.clone()) } + } + } + + async fn map_batch( + self: Arc, + batch: RecordBatch, + batch_number: u32, + ) -> DataFusionResult { + let compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer(); + let row_addrs_arr = self.get_row_addrs(&batch).await?; + let row_addrs = row_addrs_arr.as_primitive::(); + + // Check if the row addresses are already sorted to avoid unnecessary reorders + let is_sorted = row_addrs.values().windows(2).all(|w| w[0] <= w[1]); + + let sorted_addrs: Arc; + let (sorted_addrs, permutation) = if is_sorted { + (row_addrs, None) } else { - false + let permutation = arrow::compute::sort_to_indices(&row_addrs_arr, None, None).unwrap(); + sorted_addrs = arrow::compute::take( + &row_addrs_arr, + &permutation, + Some(TakeOptions { + check_bounds: false, + }), + ) + .unwrap(); + // Calculate the inverse permutation to restore the original order + let mut inverse_permutation = vec![0; permutation.len()]; + for (i, p) in permutation.values().iter().enumerate() { + inverse_permutation[*p as usize] = i as u32; + } + ( + sorted_addrs.as_primitive::(), + Some(UInt32Array::from(inverse_permutation)), + ) }; - if bg_thread_completed { - // Need to take it, since we aren't allowed to poll if again after. - this.bg_thread.take(); + + let mut futures = FuturesOrdered::new(); + let mut current_offsets = Vec::new(); + let mut current_fragment_id = None; + + for row_addr in sorted_addrs.values() { + let addr = RowAddress::new_from_u64(*row_addr); + + if Some(addr.fragment_id()) != current_fragment_id { + // Start a new group + if let Some(fragment_id) = current_fragment_id { + let reader = self.open_reader(fragment_id).await?; + let offsets = std::mem::take(&mut current_offsets); + futures.push_back( + async move { reader.take_as_batch(&offsets, Some(batch_number)).await } + .boxed(), + ); + } + current_fragment_id = Some(addr.fragment_id()); + } + + current_offsets.push(addr.row_offset()); } - // this.rx. - this.rx.poll_recv(cx) + + // Handle the last group + if let Some(fragment_id) = current_fragment_id { + let reader = self.open_reader(fragment_id).await?; + futures.push_back( + async move { + reader + .take_as_batch(¤t_offsets, Some(batch_number)) + .await + } + .boxed(), + ); + } + + // Stop the compute timer, don't count I/O time + drop(compute_timer); + + let batches = futures.try_collect::>().await?; + + if batches.is_empty() { + return Ok(RecordBatch::new_empty(self.output_schema.clone())); + } + + let _compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer(); + let schema = batches.first().expect_ok()?.schema(); + let mut new_data = concat_batches(&schema, batches.iter())?; + + // Restore previous order (if addresses were out of order originally) + if let Some(permutation) = permutation { + new_data = arrow_select::take::take_record_batch(&new_data, &permutation).unwrap(); + } + + self.metrics + .baseline_metrics + .record_output(new_data.num_rows()); + self.metrics.batches_processed.add(1); + Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?) } -} -impl RecordBatchStream for Take { - fn schema(&self) -> SchemaRef { - self.output_schema.clone() + fn apply> + Send + 'static>( + self: Arc, + input: S, + ) -> impl Stream> { + let scan_scheduler = self.scan_scheduler.clone(); + let metrics = self.metrics.clone(); + let batches = input + .enumerate() + .map(move |(batch_index, batch)| { + let batch = batch?; + let this = self.clone(); + Ok( + tokio::task::spawn(this.map_batch(batch, batch_index as u32)) + .map(|res| res.unwrap()), + ) + }) + .boxed(); + batches + .try_buffered(get_num_compute_intensive_cpus()) + .finally(move || { + metrics.io_metrics.record_final(scan_scheduler.as_ref()); + }) } } -/// [`TakeExec`] is a [`ExecutionPlan`] that enriches the input [`RecordBatch`] -/// with extra columns from [`Dataset`]. -/// -/// The rows are identified by the inexplicit row IDs from `input` plan. -/// -/// The output schema will be the input schema, merged with extra schemas from the dataset. #[derive(Debug)] pub struct TakeExec { - /// Dataset to read from. + // The dataset to take from dataset: Arc, - - pub(crate) extra_schema: Arc, - + // The desired output projection of the relation (input schema + take schema) + // + // This is used to re-calculate output_projection and extra_schema when + // with_new_children is called. + output_projection: Projection, + // The schema of the extra columns to take from the dataset + schema_to_take: Arc, + // The schema of the output + output_schema: SchemaRef, input: Arc, - - /// Output schema is the merged schema between input schema and extra schema. - output_schema: Schema, - - batch_readahead: usize, - properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl DisplayAs for TakeExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { let extra_fields = self - .extra_schema + .schema_to_take .fields .iter() .map(|f| f.name.clone()) @@ -202,10 +311,11 @@ impl DisplayAs for TakeExec { .fields .iter() .map(|f| { - if extra_fields.contains(&f.name) { - format!("({})", f.name.as_str()) + let name = f.name(); + if extra_fields.contains(name) { + format!("({})", name) } else { - f.name.clone() + name.to_string() } }) .collect::>() @@ -221,43 +331,114 @@ impl TakeExec { /// /// - dataset: the dataset to read from /// - input: the upstream [`ExecutionPlan`] to feed data in. - /// - extra_schema: the extra schema to take / read from the dataset. + /// - projection: the desired output projection, can overlap with the input schema if desired + /// + /// Returns None if no extra columns are required (everything in the projection exists in the input schema). pub fn try_new( dataset: Arc, input: Arc, - extra_schema: Arc, - batch_readahead: usize, - ) -> Result { - if input.schema().column_with_name(ROW_ID).is_none() { - return Err(DataFusionError::Plan( - "TakeExec requires the input plan to have a column named '_rowid'".to_string(), - )); + projection: Projection, + ) -> Result> { + let original_projection = projection.clone(); + let projection = + projection.subtract_arrow_schema(input.schema().as_ref(), OnMissing::Ignore)?; + if projection.is_empty() { + return Ok(None); } - let input_schema = Schema::try_from(input.schema().as_ref())?; - let output_schema = input_schema.merge(extra_schema.as_ref())?; + // We actually need a take so lets make sure we have a row id + if input.schema().column_with_name(ROW_ADDR).is_none() + && input.schema().column_with_name(ROW_ID).is_none() + { + return Err(DataFusionError::Plan(format!( + "TakeExec requires the input plan to have a column named '{}' or '{}'", + ROW_ADDR, ROW_ID + ))); + } - let remaining_schema = extra_schema.exclude(&input_schema)?; + // Can't use take if we don't want any fields and we can't use take to add row_id or row_addr + assert!( + !projection.with_row_id && !projection.with_row_addr, + "Take should not be used to insert row_id / row_addr: {:#?}", + projection + ); - let output_arrow = Arc::new(ArrowSchema::from(&output_schema)); + let output_schema = Arc::new(Self::calculate_output_schema( + dataset.schema(), + &input.schema(), + &projection, + )); + let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref())); let properties = input .properties() .clone() - .with_eq_properties(EquivalenceProperties::new(output_arrow)); + .with_eq_properties(EquivalenceProperties::new(output_arrow.clone())); - Ok(Self { + Ok(Some(Self { dataset, - extra_schema: Arc::new(remaining_schema), + output_projection: original_projection, + schema_to_take: projection.into_schema_ref(), input, - output_schema, - batch_readahead, + output_schema: output_arrow, properties, - }) + metrics: ExecutionPlanMetricsSet::new(), + })) } - /// Get the dataset. + /// The output of a take operation will be all columns from the input schema followed + /// by any new columns from the dataset. + /// + /// The output fields will always be added in dataset schema order + /// + /// Nested columns in the input schema may have new fields inserted into them. /// - /// WARNING: Internal API with no stability guarantees. + /// If this happens the order of the new nested fields will match the order defined in + /// the dataset schema. + fn calculate_output_schema( + dataset_schema: &Schema, + input_schema: &ArrowSchema, + projection: &Projection, + ) -> Schema { + // TakeExec doesn't reorder top-level fields and so the first thing we need to do is determine the + // top-level field order. + let mut top_level_fields_added = HashSet::with_capacity(input_schema.fields.len()); + let projected_schema = projection.to_schema(); + + let mut output_fields = + Vec::with_capacity(input_schema.fields.len() + projected_schema.fields.len()); + // TakeExec always moves the _rowid to the start of the output schema + output_fields.extend(input_schema.fields.iter().map(|f| { + let f = Field::try_from(f.as_ref()).unwrap(); + if let Some(ds_field) = dataset_schema.field(&f.name) { + top_level_fields_added.insert(ds_field.id); + // Field is in the dataset, it might have new fields added to it + if let Some(projected_field) = ds_field.apply_projection(projection) { + f.merge_with_reference(&projected_field, ds_field) + } else { + // No new fields added, keep as-is + f + } + } else { + // Field not in dataset, not possible to add extra fields, use as-is + f + } + })); + + // Now we add to the end any brand new top-level fields. These will be added + // dataset schema order. + output_fields.extend( + projected_schema + .fields + .into_iter() + .filter(|f| !top_level_fields_added.contains(&f.id)), + ); + Schema { + fields: output_fields, + metadata: dataset_schema.metadata.clone(), + } + } + + /// Get the dataset. pub fn dataset(&self) -> &Arc { &self.dataset } @@ -273,13 +454,21 @@ impl ExecutionPlan for TakeExec { } fn schema(&self) -> SchemaRef { - ArrowSchema::from(&self.output_schema).into() + self.output_schema.clone() } fn children(&self) -> Vec<&Arc> { vec![&self.input] } + fn benefits_from_input_partitioning(&self) -> Vec { + // This is an I/O bound operation and wouldn't really benefit from partitioning + // + // Plus, if we did that, we would be creating multiple schedulers which could use + // a lot of RAM. + vec![false] + } + /// This preserves the output schema. fn with_new_children( self: Arc, @@ -291,18 +480,16 @@ impl ExecutionPlan for TakeExec { )); } - let child = &children[0]; - - let extra_schema = self.output_schema.exclude(child.schema().as_ref())?; + let projection = self.output_projection.clone(); - let plan = Self::try_new( - self.dataset.clone(), - children[0].clone(), - Arc::new(extra_schema), - self.batch_readahead, - )?; + let plan = Self::try_new(self.dataset.clone(), children[0].clone(), projection)?; - Ok(Arc::new(plan)) + if let Some(plan) = plan { + Ok(Arc::new(plan)) + } else { + // Is this legal or do we need to insert a no-op node? + Ok(children[0].clone()) + } } fn execute( @@ -311,15 +498,40 @@ impl ExecutionPlan for TakeExec { context: Arc, ) -> Result { let input_stream = self.input.execute(partition, context)?; - Ok(Box::pin(Take::new( - self.dataset.clone(), - self.extra_schema.clone(), - self.schema(), - input_stream, - self.batch_readahead, + let dataset = self.dataset.clone(); + let schema_to_take = self.schema_to_take.clone(); + let output_schema = self.output_schema.clone(); + let metrics = self.metrics.clone(); + + // ScanScheduler::new launches the I/O scheduler in the background. + // We aren't allowed to do work in `execute` and so we defer creation of the + // TakeStream until the stream is polled. + let lazy_take_stream = futures::stream::once(async move { + let obj_store = dataset.object_store.clone(); + let scheduler_config = SchedulerConfig::max_bandwidth(&obj_store); + let scan_scheduler = ScanScheduler::new(obj_store, scheduler_config); + + let take_stream = Arc::new(TakeStream::new( + dataset, + schema_to_take, + output_schema, + scan_scheduler, + &metrics, + partition, + )); + take_stream.apply(input_stream) + }); + let output_schema = self.output_schema.clone(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + output_schema, + lazy_take_stream.flatten(), ))) } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn statistics(&self) -> Result { Ok(Statistics { num_rows: self.input.statistics()?.num_rows, @@ -336,20 +548,40 @@ impl ExecutionPlan for TakeExec { mod tests { use super::*; - use arrow_array::{ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray}; - use arrow_schema::{DataType, Field}; - use tempfile::tempdir; + use arrow_array::{ + ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray, StructArray, + }; + use arrow_schema::{DataType, Field, Fields}; + use datafusion::execution::TaskContext; + use lance_arrow::SchemaExt; + use lance_core::{datatypes::OnMissing, ROW_ID}; + use lance_datafusion::{datagen::DatafusionDatagenExt, exec::OneShotExec, utils::MetricsExt}; + use lance_datagen::{BatchCount, RowCount}; + use rstest::rstest; + use tempfile::{tempdir, TempDir}; use crate::{ dataset::WriteParams, io::exec::{LanceScanConfig, LanceScanExec}, + utils::test::NoContextTestFixture, }; - async fn create_dataset() -> Arc { + struct TestFixture { + dataset: Arc, + _tmp_dir_guard: TempDir, + } + + async fn test_fixture() -> TestFixture { + let struct_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, false)), + Arc::new(Field::new("y", DataType::Int32, false)), + ]); + let schema = Arc::new(ArrowSchema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("f", DataType::Float32, false), Field::new("s", DataType::Utf8, false), + Field::new("struct", DataType::Struct(struct_fields.clone()), false), ])); // Write 3 batches. @@ -362,7 +594,15 @@ mod tests { value_range.clone().map(|v| v as f32), )), Arc::new(StringArray::from_iter_values( - value_range.map(|v| format!("str-{v}")), + value_range.clone().map(|v| format!("str-{v}")), + )), + Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from_iter(value_range.clone())), + Arc::new(Int32Array::from_iter(value_range)), + ], + None, )), ]; RecordBatch::try_new(schema.clone(), columns).unwrap() @@ -372,7 +612,7 @@ mod tests { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); let params = WriteParams { - max_rows_per_group: 10, + max_rows_per_file: 10, ..Default::default() }; let reader = @@ -381,19 +621,19 @@ mod tests { .await .unwrap(); - Arc::new(Dataset::open(test_uri).await.unwrap()) + TestFixture { + dataset: Arc::new(Dataset::open(test_uri).await.unwrap()), + _tmp_dir_guard: test_dir, + } } #[tokio::test] async fn test_take_schema() { - let dataset = create_dataset().await; + let TestFixture { dataset, .. } = test_fixture().await; let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]); let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap()); - let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]); - let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap()); - // With row id let config = LanceScanConfig { with_row_id: true, @@ -406,7 +646,14 @@ mod tests { scan_schema, config, )); - let take_exec = TakeExec::try_new(dataset, input, extra_schema, 10).unwrap(); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); let schema = take_exec.schema(); assert_eq!( schema.fields.iter().map(|f| f.name()).collect::>(), @@ -414,22 +661,144 @@ mod tests { ); } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum TakeInput { + Ids, + Addrs, + IdsAndAddrs, + } + + #[rstest] #[tokio::test] - async fn test_take_no_extra_columns() { - let dataset = create_dataset().await; + async fn test_simple_take( + #[values(TakeInput::Ids, TakeInput::Addrs, TakeInput::IdsAndAddrs)] take_input: TakeInput, + ) { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; - let scan_arrow_schema = ArrowSchema::new(vec![ - Field::new("i", DataType::Int32, false), - Field::new("s", DataType::Int32, false), - ]); - let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap()); + let scan_schema = Arc::new(dataset.schema().project(&["i"]).unwrap()); + let config = LanceScanConfig { + with_row_address: take_input != TakeInput::Ids, + with_row_id: take_input != TakeInput::Addrs, + ..Default::default() + }; + let input = Arc::new(LanceScanExec::new( + dataset.clone(), + dataset.fragments().clone(), + None, + scan_schema, + config, + )); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + let schema = take_exec.schema(); + + let mut expected_fields = vec!["i"]; + if take_input != TakeInput::Addrs { + expected_fields.push(ROW_ID); + } + if take_input != TakeInput::Ids { + expected_fields.push(ROW_ADDR); + } + expected_fields.push("s"); + assert_eq!(&schema.field_names(), &expected_fields); + + let mut stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + while let Some(batch) = stream.try_next().await.unwrap() { + assert_eq!(&batch.schema().field_names(), &expected_fields); + } + } + + #[tokio::test] + async fn test_take_order() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + // Grab all row addresses, shuffle them, and select the first 15 (half of the rows) + let data = dataset + .scan() + .project(&["s"]) + .unwrap() + .with_row_address() + .try_into_batch() + .await + .unwrap(); + let indices = UInt64Array::from(vec![8, 13, 1, 7, 4, 5, 12, 9, 10, 2, 11, 6, 3, 0, 28]); + let data = arrow_select::take::take_record_batch(&data, &indices).unwrap(); + + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + ROW_ADDR, + DataType::UInt64, + true, + )])); + let row_addrs = data.project_by_schema(&schema).unwrap(); + + // Split into 3 batches of 5 + let batches = (0..3) + .map(|i| { + let start = i * 5; + row_addrs.slice(start, 5) + }) + .collect::>(); + + let row_addr_stream = futures::stream::iter(batches.clone().into_iter().map(Ok)); + let row_addr_stream = Box::pin(RecordBatchStreamAdapter::new(schema, row_addr_stream)); + + let input = Arc::new(OneShotExec::new(row_addr_stream)); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + let expected = vec![data.slice(0, 5), data.slice(5, 5), data.slice(10, 5)]; + + let batches = stream.try_collect::>().await.unwrap(); + assert_eq!(batches.len(), 3); + for (batch, expected) in batches.into_iter().zip(expected) { + assert_eq!(batch.schema().field_names(), vec![ROW_ADDR, "s"]); + let expected = expected.project_by_schema(&batch.schema()).unwrap(); + assert_eq!(batch, expected); + } + + let metrics = take_exec.metrics().unwrap(); + assert_eq!(metrics.output_rows(), Some(15)); + assert_eq!(metrics.find_count("batches_processed").unwrap().value(), 3); + } + + #[tokio::test] + async fn test_take_struct() { + // When taking fields into an existing struct, the field order should be maintained + // according the the schema of the struct. + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; - // Extra column is already read. - let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]); - let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap()); + let scan_schema = Arc::new(dataset.schema().project(&["struct.y"]).unwrap()); let config = LanceScanConfig { - with_row_id: true, + with_row_address: true, ..Default::default() }; let input = Arc::new(LanceScanExec::new( @@ -439,28 +808,52 @@ mod tests { scan_schema, config, )); - let take_exec = TakeExec::try_new(dataset, input, extra_schema, 10).unwrap(); + + let projection = dataset + .empty_projection() + .union_column("struct.x", OnMissing::Error) + .unwrap(); + + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let expected_schema = ArrowSchema::new(vec![ + Field::new( + "struct", + DataType::Struct(Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, false)), + Arc::new(Field::new("y", DataType::Int32, false)), + ])), + false, + ), + Field::new(ROW_ADDR, DataType::UInt64, true), + ]); let schema = take_exec.schema(); - assert_eq!( - schema.fields.iter().map(|f| f.name()).collect::>(), - vec!["i", "s", ROW_ID] - ); + assert_eq!(schema.as_ref(), &expected_schema); + + let mut stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + + while let Some(batch) = stream.try_next().await.unwrap() { + assert_eq!(batch.schema().as_ref(), &expected_schema); + } } #[tokio::test] - async fn test_take_no_row_id() { - let dataset = create_dataset().await; + async fn test_take_no_row_addr() { + let TestFixture { dataset, .. } = test_fixture().await; - let scan_arrow_schema = ArrowSchema::new(vec![ - Field::new("i", DataType::Int32, false), - Field::new("s", DataType::Int32, false), - ]); + let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]); let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap()); - let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]); - let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap()); + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); - // No row ID + // No row address let input = Arc::new(LanceScanExec::new( dataset.clone(), dataset.fragments().clone(), @@ -468,38 +861,43 @@ mod tests { scan_schema, LanceScanConfig::default(), )); - assert!(TakeExec::try_new(dataset, input, extra_schema, 10).is_err()); + assert!(TakeExec::try_new(dataset, input, projection).is_err()); } #[tokio::test] async fn test_with_new_children() -> Result<()> { - let dataset = create_dataset().await; + let TestFixture { dataset, .. } = test_fixture().await; let config = LanceScanConfig { with_row_id: true, ..Default::default() }; + + let input_schema = Arc::new(dataset.schema().project(&["i"])?); + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let input = Arc::new(LanceScanExec::new( dataset.clone(), dataset.fragments().clone(), None, - Arc::new(dataset.schema().project(&["i"])?), + input_schema, config, )); + assert_eq!(input.schema().field_names(), vec!["i", ROW_ID],); - let take_exec = TakeExec::try_new( - dataset.clone(), - input.clone(), - Arc::new(dataset.schema().project(&["s"])?), - 10, - )?; + let take_exec = TakeExec::try_new(dataset.clone(), input.clone(), projection)?.unwrap(); assert_eq!(take_exec.schema().field_names(), vec!["i", ROW_ID, "s"],); - let outer_take = Arc::new(TakeExec::try_new( - dataset.clone(), - Arc::new(take_exec), - Arc::new(dataset.schema().project(&["f"])?), - 10, - )?); + + let projection = dataset + .empty_projection() + .union_columns(["s", "f"], OnMissing::Error) + .unwrap(); + + let outer_take = + Arc::new(TakeExec::try_new(dataset, Arc::new(take_exec), projection)?.unwrap()); assert_eq!( outer_take.schema().field_names(), vec!["i", ROW_ID, "s", "f"], @@ -507,7 +905,33 @@ mod tests { // with_new_children should preserve the output schema. let edited = outer_take.with_new_children(vec![input])?; - assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "s", "f"],); + assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "f", "s"],); Ok(()) } + + #[test] + fn no_context_take() { + // These tests ensure we can create nodes and call execute without a tokio Runtime + // being active. This is a requirement for proper implementation of a Datafusion foreign + // table provider. + let fixture = NoContextTestFixture::new(); + let arc_dasaset = Arc::new(fixture.dataset); + + let input = lance_datagen::gen() + .col(ROW_ID, lance_datagen::array::step::()) + .into_df_exec(RowCount::from(50), BatchCount::from(2)); + + let take = TakeExec::try_new( + arc_dasaset.clone(), + input, + arc_dasaset + .empty_projection() + .union_column("text", OnMissing::Error) + .unwrap(), + ) + .unwrap() + .unwrap(); + + take.execute(0, Arc::new(TaskContext::default())).unwrap(); + } } diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index 2864acde038..23a69ed7d05 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -8,15 +8,17 @@ use std::any::Any; use std::sync::Arc; use arrow_array::RecordBatch; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::{ common::Statistics, execution::context::TaskContext, physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties, - SendableRecordBatchStream, + execution_plan::{Boundedness, EmissionType}, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use futures::StreamExt; #[derive(Debug)] pub struct TestingExec { @@ -29,7 +31,8 @@ impl TestingExec { let properties = PlanProperties::new( EquivalenceProperties::new(batches[0].schema()), Partitioning::RoundRobinBatch(1), - ExecutionMode::Bounded, + EmissionType::Incremental, + Boundedness::Bounded, ); Self { batches, @@ -75,7 +78,9 @@ impl ExecutionPlan for TestingExec { _partition: usize, _context: Arc, ) -> datafusion::error::Result { - todo!() + let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok)); + let stream = RecordBatchStreamAdapter::new(self.schema(), stream.boxed()); + Ok(Box::pin(stream)) } fn statistics(&self) -> datafusion::error::Result { diff --git a/rust/lance/src/io/exec/utils.rs b/rust/lance/src/io/exec/utils.rs index 72b42dbf66b..efe6928e851 100644 --- a/rust/lance/src/io/exec/utils.rs +++ b/rust/lance/src/io/exec/utils.rs @@ -1,12 +1,25 @@ +use lance_datafusion::utils::{ + ExecutionPlanMetricsSetExt, BYTES_READ_METRIC, INDEX_COMPARISONS_METRIC, INDICES_LOADED_METRIC, + IOPS_METRIC, PARTS_LOADED_METRIC, REQUESTS_METRIC, +}; +use lance_index::metrics::MetricsCollector; +use lance_io::scheduler::ScanScheduler; +use lance_table::format::Index; // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use pin_project::pin_project; +use std::borrow::Cow; use std::sync::{Arc, Mutex}; +use std::task::Poll; use arrow::array::AsArray; use arrow_array::{RecordBatch, UInt64Array}; use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, +}; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, }; @@ -16,7 +29,10 @@ use lance_core::utils::futures::{Capacity, SharedStreamExt}; use lance_core::utils::mask::{RowIdMask, RowIdTreeMap}; use lance_core::{Result, ROW_ID}; use lance_index::prefilter::FilterLoader; -use snafu::{location, Location}; +use snafu::location; + +use crate::index::prefilter::DatasetPreFilter; +use crate::Dataset; #[derive(Debug, Clone)] pub enum PreFilterSource { @@ -28,6 +44,31 @@ pub enum PreFilterSource { None, } +pub(crate) fn build_prefilter( + context: Arc, + partition: usize, + prefilter_source: &PreFilterSource, + ds: Arc, + index_meta: &[Index], +) -> Result> { + let prefilter_loader = match &prefilter_source { + PreFilterSource::FilteredRowIds(src_node) => { + let stream = src_node.execute(partition, context)?; + Some(Box::new(FilteredRowIdsToPrefilter(stream)) as Box) + } + PreFilterSource::ScalarIndexQuery(src_node) => { + let stream = src_node.execute(partition, context)?; + Some(Box::new(SelectionVectorToPrefilter(stream)) as Box) + } + PreFilterSource::None => None, + }; + Ok(Arc::new(DatasetPreFilter::new( + ds, + index_meta, + prefilter_loader, + ))) +} + // Utility to convert an input (containing row ids) into a prefilter pub(crate) struct FilteredRowIdsToPrefilter(pub SendableRecordBatchStream); @@ -198,6 +239,69 @@ impl> + Unpin> RecordBatchStream } } +#[pin_project] +pub struct InstrumentedRecordBatchStreamAdapter { + schema: SchemaRef, + + #[pin] + stream: S, + baseline_metrics: BaselineMetrics, + batch_count: Count, +} + +impl InstrumentedRecordBatchStreamAdapter { + pub fn new( + schema: SchemaRef, + stream: S, + partition: usize, + metrics: &ExecutionPlanMetricsSet, + ) -> Self { + let batch_count = Count::new(); + MetricBuilder::new(metrics) + .with_partition(partition) + .build(MetricValue::Count { + name: Cow::Borrowed("output_batches"), + count: batch_count.clone(), + }); + Self { + schema, + stream, + baseline_metrics: BaselineMetrics::new(metrics, partition), + batch_count, + } + } +} + +impl Stream for InstrumentedRecordBatchStreamAdapter +where + S: Stream>, +{ + type Item = DataFusionResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.as_mut().project(); + let timer = this.baseline_metrics.elapsed_compute().timer(); + let poll = this.stream.poll_next(cx); + timer.done(); + if let Poll::Ready(Some(Ok(_))) = &poll { + this.batch_count.add(1); + } + this.baseline_metrics.record_poll(poll) + } +} + +impl RecordBatchStream for InstrumentedRecordBatchStreamAdapter +where + S: Stream>, +{ + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + impl ExecutionPlan for ReplayExec { fn name(&self) -> &str { "ReplayExec" @@ -222,6 +326,12 @@ impl ExecutionPlan for ReplayExec { unimplemented!() } + fn benefits_from_input_partitioning(&self) -> Vec { + // We aren't doing any work here, and it would be a little confusing + // to have multiple replay queues. + vec![false] + } + fn execute( &self, partition: usize, @@ -255,6 +365,61 @@ impl ExecutionPlan for ReplayExec { } } +#[derive(Debug, Clone)] +pub struct IoMetrics { + iops: Count, + requests: Count, + bytes_read: Count, +} + +impl IoMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let iops = metrics.new_count(IOPS_METRIC, partition); + let requests = metrics.new_count(REQUESTS_METRIC, partition); + let bytes_read = metrics.new_count(BYTES_READ_METRIC, partition); + Self { + iops, + requests, + bytes_read, + } + } + + pub fn record_final(&self, scan_scheduler: &ScanScheduler) { + let stats = scan_scheduler.stats(); + self.iops.add(stats.iops as usize); + self.requests.add(stats.requests as usize); + self.bytes_read.add(stats.bytes_read as usize); + } +} + +pub struct IndexMetrics { + indices_loaded: Count, + parts_loaded: Count, + index_comparisons: Count, +} + +impl IndexMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + indices_loaded: metrics.new_count(INDICES_LOADED_METRIC, partition), + parts_loaded: metrics.new_count(PARTS_LOADED_METRIC, partition), + index_comparisons: metrics.new_count(INDEX_COMPARISONS_METRIC, partition), + } + } +} + +impl MetricsCollector for IndexMetrics { + fn record_parts_loaded(&self, num_shards: usize) { + self.parts_loaded.add(num_shards); + } + fn record_index_loads(&self, num_indexes: usize) { + self.indices_loaded.add(num_indexes); + } + fn record_comparisons(&self, num_comparisons: usize) { + self.index_comparisons.add(num_comparisons); + } +} + #[cfg(test)] mod tests { diff --git a/rust/lance/src/session.rs b/rust/lance/src/session.rs index 6a978b0cb28..7b95eb6ca22 100644 --- a/rust/lance/src/session.rs +++ b/rust/lance/src/session.rs @@ -8,7 +8,8 @@ use deepsize::DeepSizeOf; use lance_core::cache::FileMetadataCache; use lance_core::{Error, Result}; use lance_index::IndexType; -use snafu::{location, Location}; +use lance_io::object_store::ObjectStoreRegistry; +use snafu::location; use crate::dataset::{DEFAULT_INDEX_CACHE_SIZE, DEFAULT_METADATA_CACHE_SIZE}; use crate::index::cache::IndexCache; @@ -18,7 +19,7 @@ use self::index_extension::IndexExtension; pub mod index_extension; /// A user session tracks the runtime state. -#[derive(Clone, DeepSizeOf)] +#[derive(Clone)] pub struct Session { /// Cache for opened indices. pub(crate) index_cache: IndexCache, @@ -27,6 +28,20 @@ pub struct Session { pub(crate) file_metadata_cache: FileMetadataCache, pub(crate) index_extensions: HashMap<(IndexType, String), Arc>, + + store_registry: Arc, +} + +impl DeepSizeOf for Session { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + let mut size = 0; + size += self.index_cache.deep_size_of_children(context); + size += self.file_metadata_cache.deep_size_of_children(context); + for ext in self.index_extensions.values() { + size += ext.deep_size_of_children(context); + } + size + } } impl std::fmt::Debug for Session { @@ -34,18 +49,13 @@ impl std::fmt::Debug for Session { f.debug_struct("Session") .field( "index_cache", - &format!( - "IndexCache(items={}, size_bytes={})", - self.index_cache.get_size(), - self.index_cache.deep_size_of() - ), + &format!("IndexCache(items={})", self.index_cache.approx_size(),), ) .field( "file_metadata_cache", &format!( - "FileMetadataCache(items={}, size_bytes={})", - self.file_metadata_cache.size(), - self.file_metadata_cache.deep_size_of() + "FileMetadataCache(items={})", + self.file_metadata_cache.approx_size(), ), ) .field( @@ -62,11 +72,20 @@ impl Session { /// Parameters: /// /// - ***index_cache_size***: the size of the index cache. - pub fn new(index_cache_size: usize, metadata_cache_size: usize) -> Self { + /// - ***metadata_cache_size***: the size of the metadata cache. + /// - ***store_registry***: the object store registry to use when opening + /// datasets. This determines which schemes are available, and also allows + /// re-using object stores. + pub fn new( + index_cache_size: usize, + metadata_cache_size: usize, + store_registry: Arc, + ) -> Self { Self { index_cache: IndexCache::new(index_cache_size), file_metadata_cache: FileMetadataCache::new(metadata_cache_size), index_extensions: HashMap::new(), + store_registry, } } @@ -125,6 +144,17 @@ impl Session { // need the deepsize crate themselves (e.g. to use deep_size_of) self.deep_size_of() as u64 } + + pub fn approx_num_items(&self) -> usize { + self.index_cache.approx_size() + + self.file_metadata_cache.approx_size() + + self.index_extensions.len() + } + + /// Get the object store registry. + pub fn store_registry(&self) -> Arc { + self.store_registry.clone() + } } impl Default for Session { @@ -133,6 +163,7 @@ impl Default for Session { index_cache: IndexCache::new(DEFAULT_INDEX_CACHE_SIZE), file_metadata_cache: FileMetadataCache::new(DEFAULT_METADATA_CACHE_SIZE), index_extensions: HashMap::new(), + store_registry: Arc::new(ObjectStoreRegistry::default()), } } } @@ -151,7 +182,7 @@ mod tests { #[test] fn test_disable_index_cache() { - let no_cache = Session::new(0, 0); + let no_cache = Session::new(0, 0, Default::default()); assert!(no_cache.index_cache.get_vector("abc").is_none()); let no_cache = Arc::new(no_cache); @@ -172,7 +203,7 @@ mod tests { #[test] fn test_basic() { - let session = Session::new(10, 1); + let session = Session::new(10, 1, Default::default()); let session = Arc::new(session); let pq = ProductQuantizer::new( diff --git a/rust/lance/src/session/index_extension.rs b/rust/lance/src/session/index_extension.rs index 2080397eaea..8219a061090 100644 --- a/rust/lance/src/session/index_extension.rs +++ b/rust/lance/src/session/index_extension.rs @@ -67,12 +67,16 @@ mod test { use arrow_array::{RecordBatch, UInt32Array}; use arrow_schema::Schema; + use datafusion::execution::SendableRecordBatchStream; use deepsize::DeepSizeOf; use lance_file::version::LanceFileVersion; use lance_file::writer::{FileWriter, FileWriterOptions}; - use lance_index::vector::ivf::storage::IvfModel; - use lance_index::vector::quantizer::{QuantizationType, Quantizer}; use lance_index::vector::v3::subindex::SubIndexType; + use lance_index::{ + metrics::MetricsCollector, + vector::quantizer::{QuantizationType, Quantizer}, + }; + use lance_index::{metrics::NoOpMetricsCollector, vector::ivf::storage::IvfModel}; use lance_index::{ vector::{hnsw::VECTOR_ID_FIELD, Query}, DatasetIndexExt, Index, IndexMetadata, IndexType, INDEX_FILE_NAME, @@ -108,6 +112,10 @@ mod test { Ok(self) } + async fn prewarm(&self) -> Result<()> { + Ok(()) + } + fn statistics(&self) -> Result { Ok(json!(())) } @@ -123,7 +131,12 @@ mod test { #[async_trait::async_trait] impl VectorIndex for MockIndex { - async fn search(&self, _: &Query, _: Arc) -> Result { + async fn search( + &self, + _: &Query, + _: Arc, + _: &dyn MetricsCollector, + ) -> Result { unimplemented!() } @@ -136,6 +149,7 @@ mod test { _: usize, _: &Query, _: Arc, + _: &dyn MetricsCollector, ) -> Result { unimplemented!() } @@ -161,15 +175,23 @@ mod test { unimplemented!() } + fn num_rows(&self) -> u64 { + unimplemented!() + } + fn row_ids(&self) -> Box> { unimplemented!() } - fn remap(&mut self, _: &HashMap>) -> Result<()> { + async fn remap(&mut self, _: &HashMap>) -> Result<()> { Ok(()) } - fn ivf_model(&self) -> IvfModel { + async fn to_batch_stream(&self, _with_vector: bool) -> Result { + unimplemented!() + } + + fn ivf_model(&self) -> &IvfModel { unimplemented!() } fn quantizer(&self) -> Quantizer { @@ -360,7 +382,7 @@ mod test { // trying to open the index should fail as there is no extension loader assert!(ds_without_extension - .open_vector_index("vec", &index_uuid) + .open_vector_index("vec", &index_uuid, &NoOpMetricsCollector) .await .unwrap_err() .to_string() @@ -368,7 +390,7 @@ mod test { // trying to open the index should succeed with the extension loader let vector_index = ds_with_extension - .open_vector_index("vec", &index_uuid) + .open_vector_index("vec", &index_uuid, &NoOpMetricsCollector) .await .unwrap(); diff --git a/rust/lance/src/utils/future.rs b/rust/lance/src/utils/future.rs index 227bb37f81b..9f06030810e 100644 --- a/rust/lance/src/utils/future.rs +++ b/rust/lance/src/utils/future.rs @@ -3,7 +3,7 @@ use async_cell::sync::AsyncCell; use futures::Future; -use snafu::{location, Location}; +use snafu::location; use std::sync::Arc; /// An async background task whose output can be shared across threads (via cloning) @@ -11,7 +11,7 @@ use std::sync::Arc; /// SharedPrerequisite is very similar to a shared future except: /// * It must be created by spawning a new task (runs in the background) /// * Shared future doesn't support Result. This class handles errors by -/// serializing them to string. +/// serializing them to string. /// * This class can optionally cache the output so that it can be accessed synchronously pub struct SharedPrerequisite(Arc>>); diff --git a/rust/lance/src/utils/test.rs b/rust/lance/src/utils/test.rs index 5f7ef481ff7..324376d2812 100644 --- a/rust/lance/src/utils/test.rs +++ b/rust/lance/src/utils/test.rs @@ -11,9 +11,9 @@ use bytes::Bytes; use futures::stream::BoxStream; use lance_arrow::RecordBatchExt; use lance_core::datatypes::Schema; -use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount}; +use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount}; use lance_file::version::LanceFileVersion; -use lance_io::object_store::{ObjectStoreRegistry, WrappingObjectStore}; +use lance_io::object_store::WrappingObjectStore; use lance_table::format::Fragment; use object_store::path::Path; use object_store::{ @@ -22,6 +22,7 @@ use object_store::{ }; use rand::prelude::SliceRandom; use rand::{Rng, SeedableRng}; +use tempfile::{tempdir, TempDir}; use crate::dataset::fragment::write::FragmentCreateBuilder; use crate::dataset::transaction::Operation; @@ -117,14 +118,13 @@ impl TestDatasetGenerator { config_upsert_values: None, }; - let registry = Arc::new(ObjectStoreRegistry::default()); Dataset::commit( uri, operation, None, Default::default(), None, - registry, + Default::default(), false, ) .await @@ -325,7 +325,13 @@ impl IoTrackingStore { } #[async_trait::async_trait] +#[deny(clippy::missing_trait_methods)] impl ObjectStore for IoTrackingStore { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { + self.record_write(bytes.content_length() as u64); + self.target.put(location, bytes).await + } + async fn put_opts( &self, location: &Path, @@ -336,6 +342,14 @@ impl ObjectStore for IoTrackingStore { self.target.put_opts(location, bytes, opts).await } + async fn put_multipart(&self, location: &Path) -> OSResult> { + let target = self.target.put_multipart(location).await?; + Ok(Box::new(IoTrackingMultipartUpload { + target, + stats: self.stats.clone(), + })) + } + async fn put_multipart_opts( &self, location: &Path, @@ -348,6 +362,15 @@ impl ObjectStore for IoTrackingStore { })) } + async fn get(&self, location: &Path) -> OSResult { + let result = self.target.get(location).await; + if let Ok(result) = &result { + let num_bytes = result.range.end - result.range.start; + self.record_read(num_bytes as u64); + } + result + } + async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { let result = self.target.get_opts(location, options).await; if let Ok(result) = &result { @@ -379,6 +402,7 @@ impl ObjectStore for IoTrackingStore { } async fn delete(&self, location: &Path) -> OSResult<()> { + self.record_write(0); self.target.delete(location).await } @@ -394,6 +418,15 @@ impl ObjectStore for IoTrackingStore { self.target.list(prefix) } + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'_, OSResult> { + self.record_read(0); + self.target.list_with_offset(prefix, offset) + } + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { self.record_read(0); self.target.list_with_delimiter(prefix).await @@ -409,6 +442,11 @@ impl ObjectStore for IoTrackingStore { self.target.rename(from, to).await } + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.record_write(0); + self.target.rename_if_not_exists(from, to).await + } + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { self.record_write(0); self.target.copy_if_not_exists(from, to).await @@ -491,6 +529,36 @@ impl DatagenExt for BatchGeneratorBuilder { } } +pub struct NoContextTestFixture { + _tmp_dir: TempDir, + pub dataset: Dataset, +} + +impl NoContextTestFixture { + pub fn new() -> Self { + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + runtime.block_on(async move { + let tempdir = tempdir().unwrap(); + let tmppath = tempdir.path().to_str().unwrap(); + let dataset = lance_datagen::gen() + .col( + "text", + lance_datagen::array::rand_utf8(ByteCount::from(10), false), + ) + .into_dataset(tmppath, FragmentCount::from(4), FragmentRowCount::from(100)) + .await + .unwrap(); + Self { + dataset, + _tmp_dir: tempdir, + } + }) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/rust/lance/src/utils/tfrecord.rs b/rust/lance/src/utils/tfrecord.rs index a076d728133..92b99a2f234 100644 --- a/rust/lance/src/utils/tfrecord.rs +++ b/rust/lance/src/utils/tfrecord.rs @@ -17,7 +17,7 @@ use datafusion::physical_plan::SendableRecordBatchStream; use futures::{StreamExt, TryStreamExt}; use half::{bf16, f16}; use lance_arrow::bfloat16::{ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY, BFLOAT16_EXT_NAME}; -use prost::Message; +use prost_old::Message; use std::collections::HashMap; use std::sync::Arc; @@ -32,6 +32,20 @@ use tfrecord::protobuf::feature::Kind; use tfrecord::protobuf::{DataType as TensorDataType, TensorProto}; use tfrecord::record_reader::RecordStream; use tfrecord::{Example, Feature}; + +trait OldProstResultExt { + fn map_prost_err(self, location: Location) -> Result; +} + +impl OldProstResultExt for std::result::Result { + fn map_prost_err(self, location: Location) -> Result { + self.map_err(|err| Error::IO { + source: Box::new(err), + location, + }) + } +} + /// Infer the Arrow schema from a TFRecord file. /// /// The featured named by `tensor_features` will be assumed to be binary fields @@ -224,7 +238,7 @@ impl FeatureMeta { } fn extract_tensor(data: &[u8]) -> Result { - let tensor_proto = TensorProto::decode(data)?; + let tensor_proto = TensorProto::decode(data).map_prost_err(location!())?; Ok(FeatureType::Tensor { shape: tensor_proto .tensor_shape @@ -617,7 +631,7 @@ fn convert_fixedshape_tensor( DataType::Float16 => { let mut values = Float16Builder::with_capacity(features.len()); for tensors in tensor_iter { - if let Some(tensors) = tensors? { + if let Some(tensors) = tensors.map_prost_err(location!())? { for tensor in tensors { validate_tensor(&tensor, type_info)?; if tensor.half_val.is_empty() { @@ -645,7 +659,7 @@ fn convert_fixedshape_tensor( let mut values = FixedSizeBinaryBuilder::with_capacity(features.len(), 2); for tensors in tensor_iter { - if let Some(tensors) = tensors? { + if let Some(tensors) = tensors.map_prost_err(location!())? { for tensor in tensors { validate_tensor(&tensor, type_info)?; if tensor.half_val.is_empty() { @@ -673,7 +687,7 @@ fn convert_fixedshape_tensor( DataType::Float32 => { let mut values = Float32Builder::with_capacity(features.len()); for tensors in tensor_iter { - if let Some(tensors) = tensors? { + if let Some(tensors) = tensors.map_prost_err(location!())? { for tensor in tensors { validate_tensor(&tensor, type_info)?; if tensor.float_val.is_empty() { @@ -695,7 +709,7 @@ fn convert_fixedshape_tensor( DataType::Float64 => { let mut values = Float64Builder::with_capacity(features.len()); for tensors in tensor_iter { - if let Some(tensors) = tensors? { + if let Some(tensors) = tensors.map_prost_err(location!())? { for tensor in tensors { validate_tensor(&tensor, type_info)?; if tensor.float_val.is_empty() { diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/bed6140c-b15a-454e-83a4-d66520397899/bitmap_page_lookup.lance b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/bed6140c-b15a-454e-83a4-d66520397899/bitmap_page_lookup.lance new file mode 100644 index 00000000000..5b3983fead5 Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/bed6140c-b15a-454e-83a4-d66520397899/bitmap_page_lookup.lance differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_data.lance b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_data.lance new file mode 100644 index 00000000000..d97d872a3fe Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_data.lance differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_lookup.lance b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_lookup.lance new file mode 100644 index 00000000000..deeb36ca9a9 Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_indices/e034c4d8-77cd-422c-8855-209eed8deff8/page_lookup.lance differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/0-ca14443d-4119-474d-a32d-ae6c59288e9a.txn b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/0-ca14443d-4119-474d-a32d-ae6c59288e9a.txn new file mode 100644 index 00000000000..d880fea9082 Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/0-ca14443d-4119-474d-a32d-ae6c59288e9a.txn differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/1-6c1bfc70-d75f-4b58-84ec-aee73e2389d6.txn b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/1-6c1bfc70-d75f-4b58-84ec-aee73e2389d6.txn new file mode 100644 index 00000000000..8575b67ce2b Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/1-6c1bfc70-d75f-4b58-84ec-aee73e2389d6.txn differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/2-70cf21e4-8f6d-4d41-b303-3dc1ee959c0b.txn b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/2-70cf21e4-8f6d-4d41-b303-3dc1ee959c0b.txn new file mode 100644 index 00000000000..97aed3d6daf Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_transactions/2-70cf21e4-8f6d-4d41-b303-3dc1ee959c0b.txn differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/1.manifest b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/1.manifest new file mode 100644 index 00000000000..4b8b0703d6a Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/1.manifest differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/2.manifest b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/2.manifest new file mode 100644 index 00000000000..f92dab11396 Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/2.manifest differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/3.manifest b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/3.manifest new file mode 100644 index 00000000000..5f747931c41 Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/_versions/3.manifest differ diff --git a/test_data/v0.20.0/old_btree_bitmap_indices.lance/data/1f29c4b8-24ba-4f50-8d07-3b0b5c1b4f3f.lance b/test_data/v0.20.0/old_btree_bitmap_indices.lance/data/1f29c4b8-24ba-4f50-8d07-3b0b5c1b4f3f.lance new file mode 100644 index 00000000000..e6e9d742cfd Binary files /dev/null and b/test_data/v0.20.0/old_btree_bitmap_indices.lance/data/1f29c4b8-24ba-4f50-8d07-3b0b5c1b4f3f.lance differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/ca9b1111-abfc-4fde-b4cc-8e667b84e65d/index.idx b/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/ca9b1111-abfc-4fde-b4cc-8e667b84e65d/index.idx new file mode 100644 index 00000000000..a5e08e4bbc7 Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/ca9b1111-abfc-4fde-b4cc-8e667b84e65d/index.idx differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/dc833a6e-a710-48aa-af24-9ab80f30700c/index.idx b/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/dc833a6e-a710-48aa-af24-9ab80f30700c/index.idx new file mode 100644 index 00000000000..019ec584ecb Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/_indices/dc833a6e-a710-48aa-af24-9ab80f30700c/index.idx differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/_transactions/3-f68af88b-ea42-4fec-9feb-2b5bb3f48223.txn b/test_data/v0.21.0/bad_index_fragment_bitmap/_transactions/3-f68af88b-ea42-4fec-9feb-2b5bb3f48223.txn new file mode 100644 index 00000000000..c8a4c7c1fc4 Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/_transactions/3-f68af88b-ea42-4fec-9feb-2b5bb3f48223.txn differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/_versions/4.manifest b/test_data/v0.21.0/bad_index_fragment_bitmap/_versions/4.manifest new file mode 100644 index 00000000000..34ff0a18924 Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/_versions/4.manifest differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/data/0e45e8ed-1d98-4e07-a4a6-67ca3d194291.lance b/test_data/v0.21.0/bad_index_fragment_bitmap/data/0e45e8ed-1d98-4e07-a4a6-67ca3d194291.lance new file mode 100644 index 00000000000..7bcaf3cacdf Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/data/0e45e8ed-1d98-4e07-a4a6-67ca3d194291.lance differ diff --git a/test_data/v0.21.0/bad_index_fragment_bitmap/data/c042e881-07a6-4a65-96b9-c3f31ea3bb47.lance b/test_data/v0.21.0/bad_index_fragment_bitmap/data/c042e881-07a6-4a65-96b9-c3f31ea3bb47.lance new file mode 100644 index 00000000000..783190b834b Binary files /dev/null and b/test_data/v0.21.0/bad_index_fragment_bitmap/data/c042e881-07a6-4a65-96b9-c3f31ea3bb47.lance differ diff --git a/test_data/v0.21.0/datagen.py b/test_data/v0.21.0/datagen.py new file mode 100644 index 00000000000..f0d2edcaf18 --- /dev/null +++ b/test_data/v0.21.0/datagen.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from datetime import timedelta + +import lance +import pyarrow as pa +import pyarrow.compute as pc + +# To generate the test file, we should be running this version of lance +assert lance.__version__ == "0.21.0" + +data = pa.table( + { + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(16 * 256).cast(pa.float32()), 16 + ) + } +) +ds = lance.write_dataset(data, "bad_index_fragment_bitmap") +ds.create_index("vector", index_type="IVF_PQ", num_partitions=1, num_sub_vectors=1) + +data2 = pa.table( + { + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(16 * 32).cast(pa.float32()), 16 + ) + } +) +ds.insert(data2) +ds.optimize.optimize_indices(num_indices_to_merge=0) + +ds.cleanup_old_versions(older_than=timedelta(0)) + +indices = ds.list_indices() +assert len(indices) == 2 +# There is overlap in fragment_ids, which is not allowed +assert indices[0]["fragment_ids"] == {0} +assert indices[1]["fragment_ids"] == {0, 1}